ComplexHelper.h 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <c10/util/irange.h>
  4. #ifndef AT_PER_OPERATOR_HEADERS
  5. #include <ATen/NativeFunctions.h>
  6. #else
  7. #include <ATen/ops/view_as_real_native.h>
  8. #include <ATen/ops/view_as_complex_native.h>
  9. #include <utility>
  10. #endif
  11. // WARNING: this header contains non-inline functions and should be only
  12. // included from ONE cpp file
  13. namespace at { namespace native {
  14. // View tensor with new dtype, storage offset, sizes and strides
  15. inline Tensor view_tensor(
  16. const Tensor &tensor, ScalarType dtype,
  17. c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
  18. Storage storage = tensor.storage();
  19. auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
  20. auto new_tensor = detail::make_tensor<TensorImpl>(
  21. c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
  22. auto * impl = new_tensor.unsafeGetTensorImpl();
  23. impl->set_sizes_and_strides(sizes, strides, offset);
  24. return new_tensor;
  25. }
  26. inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
  27. SymDimVector res(oldstride.size() + 1);
  28. for (const auto i : c10::irange(oldstride.size())) {
  29. res[i] = oldstride[i] * 2;
  30. }
  31. res.back() = 1;
  32. return res;
  33. }
  34. Tensor _view_as_real_physical(const Tensor& self) {
  35. TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
  36. auto old_sizes = self.sym_sizes();
  37. SymDimVector new_sizes(old_sizes.size() + 1);
  38. std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
  39. // last dimension will always have two elements containing the real and imag vals
  40. new_sizes.back() = 2;
  41. auto new_strides = computeStrideForViewAsReal(self.sym_strides());
  42. auto new_storage_offset = self.sym_storage_offset() * 2;
  43. const auto float_type = c10::toRealValueType(self.scalar_type());
  44. auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
  45. return real_tensor;
  46. }
  47. // expects as input a complex tensor and returns back a tensor
  48. // with corresponding real dtype containing the complex values
  49. // in the last two dimensions
  50. Tensor view_as_real(const Tensor& self) {
  51. TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
  52. return _view_as_real_physical(self);
  53. }
  54. inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
  55. const int64_t dim = oldstride.size();
  56. TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");
  57. SymDimVector res(dim - 1);
  58. for (const auto i : c10::irange(res.size())) {
  59. TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
  60. res[i] = oldstride[i] / 2;
  61. }
  62. return res;
  63. }
  64. // expects as input a float or double tensor with last dimension of size 2
  65. // and returns back a tensor with corresponding complex dtype
  66. Tensor view_as_complex(const Tensor& self) {
  67. TORCH_CHECK(
  68. self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
  69. "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
  70. auto old_sizes = self.sym_sizes();
  71. TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
  72. TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
  73. SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
  74. const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
  75. const auto complex_type = c10::toComplexType(self.scalar_type());
  76. TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
  77. const auto new_storage_offset = self.sym_storage_offset() / 2;
  78. return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
  79. }
  80. }} // namespace at::native