tensor.py

from functools import reduce

import torch
import torch._utils
from ..function import Function


class Type(Function):
    # Casts a tensor to ``dest_type`` while keeping it in the autograd graph.

    @staticmethod
    def forward(ctx, i, dest_type):
        # Remember the input's original type and device so the gradient can be
        # cast back in backward.
        ctx.input_type = type(i)
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
        return i.type(dest_type)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.input_device == -1:
            return grad_output.type(ctx.input_type), None
        else:
            # Cast back on the device the input originally lived on.
            with torch.cuda.device(ctx.input_device):
                return grad_output.type(ctx.input_type), None
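

# A hypothetical usage sketch (not part of the original file): Type is applied
# like any other autograd Function, e.g. ``casted = Type.apply(x, 'torch.DoubleTensor')``,
# and the backward above casts the incoming gradient back to the concrete
# tensor class recorded in ``ctx.input_type``, on the device the input
# originally lived on. The name ``x`` and the destination type string here are
# purely illustrative.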


# TODO: deprecate this
class Resize(Function):
    # Reshapes a tensor to ``sizes`` while preserving the number of elements.

    @staticmethod
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        ctx.numel = reduce(lambda x, y: x * y, sizes, 1)
        if tensor.numel() != ctx.numel:
            raise RuntimeError(("requested resize to {} ({} elements in total), "
                                "but the given tensor has a size of {} ({} elements). "
                                "autograd's resize can only change the shape of a given "
                                "tensor, while preserving the number of elements. ").format(
                'x'.join(map(str, sizes)), ctx.numel,
                'x'.join(map(str, tensor.size())), tensor.numel()))
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            result = tensor.new(tensor).contiguous().view(*sizes)
            return result
        else:
            return tensor.contiguous().view(*sizes)

    @staticmethod
    def backward(ctx, grad_output):
        assert grad_output.numel() == ctx.numel
        # Restore the gradient to the input's original shape; the ``sizes``
        # argument gets no gradient.
        return grad_output.contiguous().view(ctx.input_sizes), None
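

# A minimal usage sketch, not part of the original file: the helper below is a
# hypothetical addition that exercises Resize end to end; the tensor shape,
# target sizes, and names are assumptions for illustration only.
def _resize_usage_sketch():
    x = torch.randn(6, requires_grad=True)
    # Reshape to (2, 3) through autograd.
    y = Resize.apply(x, (2, 3))
    y.sum().backward()
    # Resize.backward returns the gradient in x's original shape (and None for
    # the sizes argument).
    assert x.grad.shape == x.shape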