TensorUtils.h 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #pragma once
  2. #include <ATen/DimVector.h>
  3. #include <ATen/EmptyTensor.h>
  4. #include <ATen/Tensor.h>
  5. #include <ATen/TensorGeometry.h>
  6. #include <ATen/Utils.h>
  7. #include <utility>
  // These functions live here rather than in Utils.h because this file
  // depends on Tensor.h.

  // TORCH_CHECK_TENSOR_ALL(cond, msg...)
  //
  // Tensor-valued variant of TORCH_CHECK. `cond` is reduced with
  // `_is_all_true()` and the result is read back as a bool via `item<bool>()`.
  // The remaining arguments are forwarded unchanged to TORCH_CHECK.
  //
  // The expansion deliberately has no trailing semicolon. The caller supplies
  // it, so `TORCH_CHECK_TENSOR_ALL(x, "msg");` does not expand to an extra
  // empty statement.
  //
  // NOTE(review): `item()` presumably copies the reduced value to the host,
  // which may force a device synchronization for accelerator tensors. Confirm
  // before using this on hot paths.
  #define TORCH_CHECK_TENSOR_ALL(cond, ...) \
    TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__)
  11. namespace at {
  12. // The following are utility functions for checking that arguments
  13. // make sense. These are particularly useful for native functions,
  14. // which do NO argument checking by default.
  15. struct TORCH_API TensorArg {
  16. const Tensor& tensor;
  17. const char* name;
  18. int pos; // 1-indexed
  19. TensorArg(const Tensor& tensor, const char* name, int pos)
  20. : tensor(tensor), name(name), pos(pos) {}
  21. // Try to mitigate any possibility of dangling reference to temporaries.
  22. TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
  23. const Tensor* operator->() const {
  24. return &tensor;
  25. }
  26. const Tensor& operator*() const {
  27. return tensor;
  28. }
  29. };
  30. struct TORCH_API TensorGeometryArg {
  31. TensorGeometry tensor;
  32. const char* name;
  33. int pos; // 1-indexed
  34. /* implicit */ TensorGeometryArg(TensorArg arg)
  35. : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
  36. TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
  37. : tensor(std::move(tensor)), name(name), pos(pos) {}
  38. const TensorGeometry* operator->() const {
  39. return &tensor;
  40. }
  41. const TensorGeometry& operator*() const {
  42. return tensor;
  43. }
  44. };
  45. // A string describing which function did checks on its input
  46. // arguments.
  47. // TODO: Consider generalizing this into a call stack.
  48. using CheckedFrom = const char*;
  49. // The undefined convention: singular operators assume their arguments
  50. // are defined, but functions which take multiple tensors will
  51. // implicitly filter out undefined tensors (to make it easier to perform
  52. // tests which should apply if the tensor is defined, and should not
  53. // otherwise.)
  54. //
  55. // NB: This means that the n-ary operators take lists of TensorArg,
  56. // not TensorGeometryArg, because the Tensor to TensorGeometry
  57. // conversion will blow up if you have undefined tensors.
  58. TORCH_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t);
  59. TORCH_API void checkDim(
  60. CheckedFrom c,
  61. const Tensor& tensor,
  62. const char* name,
  63. int pos, // 1-indexed
  64. int64_t dim);
  65. TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
  66. // NB: this is an inclusive-exclusive range
  67. TORCH_API void checkDimRange(
  68. CheckedFrom c,
  69. const TensorGeometryArg& t,
  70. int64_t dim_start,
  71. int64_t dim_end);
  72. TORCH_API void checkSameDim(
  73. CheckedFrom c,
  74. const TensorGeometryArg& t1,
  75. const TensorGeometryArg& t2);
  76. TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
  77. TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
  78. TORCH_API void checkSize(
  79. CheckedFrom c,
  80. const TensorGeometryArg& t,
  81. IntArrayRef sizes);
  82. TORCH_API void checkSize_symint(
  83. CheckedFrom c,
  84. const TensorGeometryArg& t,
  85. c10::SymIntArrayRef sizes);
  86. TORCH_API void checkSize(
  87. CheckedFrom c,
  88. const TensorGeometryArg& t,
  89. int64_t dim,
  90. int64_t size);
  91. TORCH_API void checkSize_symint(
  92. CheckedFrom c,
  93. const TensorGeometryArg& t,
  94. int64_t dim,
  95. c10::SymInt size);
  96. TORCH_API void checkNumel(
  97. CheckedFrom c,
  98. const TensorGeometryArg& t,
  99. int64_t numel);
  100. TORCH_API void checkSameNumel(
  101. CheckedFrom c,
  102. const TensorGeometryArg& t1,
  103. const TensorGeometryArg& t2);
  104. TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
  105. TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
  106. TORCH_API void checkScalarTypes(
  107. CheckedFrom c,
  108. const TensorArg& t,
  109. at::ArrayRef<ScalarType> l);
  110. TORCH_API void checkSameGPU(
  111. CheckedFrom c,
  112. const TensorArg& t1,
  113. const TensorArg& t2);
  114. TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
  115. TORCH_API void checkSameType(
  116. CheckedFrom c,
  117. const TensorArg& t1,
  118. const TensorArg& t2);
  119. TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
  120. TORCH_API void checkSameSize(
  121. CheckedFrom c,
  122. const TensorArg& t1,
  123. const TensorArg& t2);
  124. TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
  125. TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
  126. // FixMe: does TensorArg slow things down?
  127. TORCH_API void checkBackend(
  128. CheckedFrom c,
  129. at::ArrayRef<Tensor> t,
  130. at::Backend backend);
  131. TORCH_API void checkDeviceType(
  132. CheckedFrom c,
  133. at::ArrayRef<Tensor> tensors,
  134. at::DeviceType device_type);
  135. TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
  136. TORCH_API void checkLayout(
  137. CheckedFrom c,
  138. at::ArrayRef<Tensor> tensors,
  139. at::Layout layout);
  140. // Methods for getting data_ptr if tensor is defined
  141. TORCH_API void* maybe_data_ptr(const Tensor& tensor);
  142. TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
  143. TORCH_API void check_dim_size(
  144. const Tensor& tensor,
  145. int64_t dim,
  146. int64_t dim_size,
  147. int64_t size);
  148. namespace detail {
  149. TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
  150. TORCH_API c10::optional<std::vector<int64_t>> computeStride(
  151. IntArrayRef oldshape,
  152. IntArrayRef oldstride,
  153. IntArrayRef newshape);
  154. TORCH_API c10::optional<SymDimVector> computeStride(
  155. c10::SymIntArrayRef oldshape,
  156. c10::SymIntArrayRef oldstride,
  157. c10::SymIntArrayRef newshape);
  158. TORCH_API c10::optional<DimVector> computeStride(
  159. IntArrayRef oldshape,
  160. IntArrayRef oldstride,
  161. const DimVector& newshape);
  162. } // namespace detail
  163. } // namespace at