ReduceOpsUtils.h

#pragma once

#include <limits>
#include <numeric>
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/NonEmptyUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at { namespace native {

// Maximum and minimum possible scalar values, including infinities
template <typename scalar_t>
constexpr scalar_t upper_bound() {
  using lim = std::numeric_limits<scalar_t>;
  return lim::has_infinity ? lim::infinity() : lim::max();
}

template <typename scalar_t>
constexpr scalar_t lower_bound() {
  using lim = std::numeric_limits<scalar_t>;
  return lim::has_infinity ? -lim::infinity() : lim::lowest();
}

static inline Tensor restride_dim(
    const Tensor& src, int64_t dim,
    IntArrayRef replacement_shape
) {
  auto strides = ensure_nonempty_vec(src.strides().vec());
  strides[dim] = 0;
  return src.as_strided(replacement_shape, strides);
}

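// _dimreduce_setup resizes `result` to self's shape with size 1 in the reduced
// dimension (the keepdim layout); callers squeeze that dimension afterwards when
// keepdim is false. E.g. a (2, 3, 4) input reduced over dim=1 gives a (2, 1, 4) result.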
inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
                             int64_t dim) {
  IntArrayRef self_sizes = self.sizes();
  std::vector<int64_t> result_sizes;
  result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
  result_sizes[dim] = 1;
  result.resize_(result_sizes);
}

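// Fast paths for reductions that can be answered without running a kernel: a
// zero-dim, single-element input is copied straight into `result`; when an
// identity is available, an empty input is filled with `ident` (e.g. 0 for sum).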
inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
                                      const Scalar& ident, int64_t dim, bool keepdim) {
  if (self.numel() == 1 && self.ndimension() == 0) {
    result.resize_({});
    result.fill_(self);
    return true;
  }
  // Return identity
  if (self.numel() == 0) {
    _dimreduce_setup(result, self, dim);
    result.fill_(ident);
    if (!keepdim) result.squeeze_(dim);
    return true;
  }
  return false;
}

inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
                                               int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
  if (self.numel() == 1 && self.ndimension() == 0) {
    result.resize_({});
    result.fill_(self);
    return true;
  }
  return false;
}

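// _allreduce_return_trivial is the all-reduce counterpart: an empty input yields a
// 0-dim tensor holding the identity; otherwise c10::nullopt tells the caller to
// perform the real reduction.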
inline c10::optional<Tensor> _allreduce_return_trivial(
    const Tensor& self,
    const Scalar& ident) {
  // Return identity
  if (self.numel() == 0) {
    return at::scalar_tensor(ident, self.options());
  }
  return c10::nullopt;
}

#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
{ \
  TORCH_CHECK(\
      out.option() == self.option(),\
      "expected ", #option, " ",\
      self.option(),\
      " but found ", out.option())\
}

static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
  OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
  OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
  OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
}

static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
  ScalarType scalarType = self.scalar_type();
  ScalarType upcast_scalarType =
      dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
  return self.toType(upcast_scalarType);
}

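// Example (informal): sum-like ops promote integral (and bool) inputs to int64 by
// default, so
//   integer_upcast(t_int32, c10::nullopt).scalar_type()        == ScalarType::Long
//   integer_upcast(t_float, c10::nullopt).scalar_type()        == ScalarType::Float
//   integer_upcast(t_int32, ScalarType::Int).scalar_type()     == ScalarType::Int
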
using DimMask = TensorIterator::DimMask;

static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
  if (opt_dims.has_value()) {
    return DimVector(opt_dims.value());
  } else {
    std::vector<int64_t> all_dims(ndim);
    std::iota(all_dims.begin(), all_dims.end(), 0);
    return DimVector(all_dims);
  }
}

static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
  DimMask mask;
  if (opt_dims.has_value()) {
    auto dims = opt_dims.value();
    if (dims.empty()) {
      mask = DimMask().flip();
    } else {
      mask = at::dim_list_to_bitset(dims, ndim);
    }
  } else {
    mask = DimMask().flip();
  }
  return mask;
}

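// Example (informal), for a 3-d input:
//   make_dim_vector(IntArrayRef{1}, 3)  -> {1}
//   make_dim_vector(c10::nullopt, 3)    -> {0, 1, 2}
//   make_dim_mask(IntArrayRef{1}, 3)    -> mask with only bit 1 set
//   make_dim_mask(IntArrayRef{}, 3) and make_dim_mask(c10::nullopt, 3)
//                                       -> every bit set (all-reduce)
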
inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
  auto shape = DimVector(self.sizes());
  for (int dim = shape.size() - 1; dim >= 0; dim--) {
    if (mask[dim]) {
      if (keepdim) {
        shape[dim] = 1;
      } else {
        shape.erase(shape.begin() + dim);
      }
    }
  }
  return shape;
}

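// Example (informal): for self.sizes() == (2, 3, 4) and a mask selecting dim 1,
//   shape_from_dim_mask(self, mask, /*keepdim=*/true)  -> {2, 1, 4}
//   shape_from_dim_mask(self, mask, /*keepdim=*/false) -> {2, 4}
// Iterating from the last dim keeps earlier indices valid while erasing.
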
static void resize_reduction_result(
    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
    ScalarType /*dtype*/)
{
  auto shape = shape_from_dim_mask(self, mask, keepdim);
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  at::native::resize_output(result, shape);
}

inline Tensor create_reduction_result(
    const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
) {
  DimMask mask = make_dim_mask(dim, self.dim());
  auto shape = shape_from_dim_mask(self, mask, keepdim);
  return at::empty(shape, self.options().dtype(dtype));
}

static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
  if (keepdim) {
    return result;
  }
  auto shape = DimVector(result.sizes());
  auto stride = DimVector(result.strides());
  for (const auto dim : c10::irange(ndim)) {
    if (mask[dim]) {
      shape.insert(shape.begin() + dim, 1);
      stride.insert(stride.begin() + dim, 0);
    }
  }
  return result.as_strided(shape, stride);
}

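// review_reduce_result undoes the keepdim=false squeeze at the view level: it
// re-inserts the reduced dimensions with size 1 and stride 0 so the (already
// resized) output broadcasts against `self` inside TensorIterator::reduce_op.
// E.g. a {2, 4} result of reducing dim 1 of a {2, 3, 4} input is viewed as
// {2, 1, 4} with stride 0 in the middle dimension.
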
static TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim_opt,
    bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
      !result.defined() || result.scalar_type() == out_dtype,
      name, ": provided dtype must match dtype of result. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(out_dtype),
      ".");
  // dim={} performs an all-reduce, same as dim=None
  IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  auto mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result, self, mask, keepdim, out_dtype);
  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}

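// Hedged usage sketch (hypothetical my_sum_out and sum_stub, not part of this
// header): an out-variant reduction typically builds the iterator with one of the
// make_reduction overloads and hands it to a dispatch stub:
//
//   Tensor& my_sum_out(const Tensor& self, IntArrayRef dim, bool keepdim,
//                      c10::optional<ScalarType> opt_dtype, Tensor& result) {
//     ScalarType dtype = get_dtype_from_result(result, opt_dtype);
//     auto iter = make_reduction("my_sum", result, self, dim, keepdim, dtype);
//     if (iter.numel() == 0) {
//       result.zero_();                      // identity for an empty reduction
//     } else {
//       sum_stub(iter.device_type(), iter);  // hypothetical dispatch stub
//     }
//     return result;
//   }
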
static C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) {
  // Special case for type promotion in mixed precision; improves computational
  // efficiency.
  // We don't generalize this to common mismatched input/output types to avoid a
  // cross product of templated kernel launches.
  const bool gpu_lowp_to_f32 = (
      self.is_cuda() && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
                                  : self.is_complex() ? c10::toComplexType(out_dtype)
                                                      : out_dtype;
  return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
}

static TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
    ScalarType dtype2) {
  // check that result type and dtype match if provided
  TORCH_CHECK(
      (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
      name, ": provided dtype must match dtype of result. Got ",
      toString(result1.scalar_type()), toString(result2.scalar_type()),
      " and ",
      toString(dtype1), toString(dtype2),
      ".");
  // dim={} performs an all-reduce, same as dim=None
  auto dim = dim_opt.value_or(IntArrayRef{});
  int64_t ndim = self.dim();
  DimMask mask = make_dim_mask(dim, ndim);
  resize_reduction_result(result1, self, mask, keepdim, dtype1);
  auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
  resize_reduction_result(result2, self, mask, keepdim, dtype2);
  auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
  namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
  // special case for type promotion in mixed precision, improves computational
  // efficiency.
  // We don't generalize this to common mismatched input/output types to avoid cross
  // product of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}

static C10_UNUSED TensorIterator make_reduction(
    const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
    at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) {
  return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
}

static void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
  if (self.ndimension() == 0) {
    TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
                      ": Expected reduction dim -1 or 0 for scalar but got ", dim);
  }
  else {
    TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
                      ": Expected reduction dim ", dim, " to have non-zero size.");
  }
}

static void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
  TORCH_CHECK(
      !dim.empty(),
      fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
      "Specify the reduction dim with the 'dim' argument.");
  for (const int64_t d : dim) {
    zero_numel_check_dims(self, d, fn_name);
  }
}

static std::vector<int64_t> get_zero_numel_tensor_size(
    const Tensor& self,
    const int64_t dim,
    const bool keepdim,
    const char* fn_name) {
  TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0.");
  zero_numel_check_dims(self, dim, fn_name);
  std::vector<int64_t> sizes;
  if (keepdim) {
    sizes = self.sizes().vec();
    sizes[dim] = 1;
  }
  else {
    for (const auto d : c10::irange(self.dim())) {
      if (d != dim) {
        sizes.push_back(self.sizes()[d]);
      }
    }
  }
  return sizes;
}

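// Example (informal): for a (0, 5) input reduced over dim=1 (dim 1 has non-zero
// size, so the check passes),
//   get_zero_numel_tensor_size(self, 1, /*keepdim=*/true,  "min") -> {0, 1}
//   get_zero_numel_tensor_size(self, 1, /*keepdim=*/false, "min") -> {0}
// Reducing over dim=0 would raise, since that dimension has zero size.
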
// Resize the result tensor and indices when result.numel() == 0 depending on values of
// dim and keepdim for returning tensors containing reduction results.
// This function should be called when you are reducing a zero-numel tensor and want to
// resize the output and return it. This function exists for resizing zero-numel
// tensors when the size of the reduction dimension is non-zero.
static C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices,
                                                const Tensor& self, const int64_t dim,
                                                const bool keepdim, const char *fn_name) {
  auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
  at::native::resize_output(result, sizes);
  at::native::resize_output(result_indices, sizes);
}

inline ScalarType get_dtype_from_self(
    const Tensor& self,
    const c10::optional<ScalarType>& dtype,
    bool promote_integers) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  ScalarType src_type = self.scalar_type();
  if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
    return kLong;
  }
  return src_type;
}

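// Example (informal):
//   get_dtype_from_self(t_int32, c10::nullopt, /*promote_integers=*/true)  -> kLong
//   get_dtype_from_self(t_int32, c10::nullopt, /*promote_integers=*/false) -> kInt
//   get_dtype_from_self(t_float, kDouble,      /*promote_integers=*/true)  -> kDouble
// An explicit dtype always wins; integer promotion only applies when none is given.
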
inline ScalarType get_dtype_from_result(Tensor& result, c10::optional<ScalarType> dtype) {
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  if (dtype.has_value()) {
    return dtype.value();
  } else {
    return result.scalar_type();
  }
}

} // namespace native

namespace meta {

static C10_UNUSED DimVector get_reduction_shape(
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim) {
  auto mask = native::make_dim_mask(dims, self.dim());
  return native::shape_from_dim_mask(self, mask, keepdim);
}

static void resize_reduction(
    impl::MetaBase& meta,
    const Tensor& self,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype) {
  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim);
  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(), self, dims_, keepdim);
}

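// Hedged sketch (hypothetical op name, assuming the structured-kernel
// TORCH_META_FUNC machinery): a reduction's meta function typically just picks an
// output dtype and resizes the output here, e.g.
//
//   TORCH_META_FUNC(my_reduce)
//   (const Tensor& self, OptionalIntArrayRef dim, bool keepdim,
//    c10::optional<ScalarType> opt_dtype) {
//     // one reasonable dtype choice; real ops may also consult maybe_get_output()
//     auto out_dtype = at::native::get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
//     resize_reduction(*this, self, dim, keepdim, out_dtype);
//   }
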
static void resize_reduction_with_indices(
    impl::MetaBase& meta,
    const Tensor& self,
    IntArrayRef dims,
    bool keepdim,
    ScalarType out_dtype) {
  DimVector dims_(dims);
  maybe_wrap_dims(dims_, self.dim());
  auto shape = get_reduction_shape(self, dims_, keepdim);
  meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(0), self, dims_, keepdim);
  namedinference::propagate_names_for_reduction(
      meta.maybe_get_output(1), self, dims_, keepdim);
}

static TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType in_dtype) {
  int64_t ndim = self.dim();
  auto mask = at::native::make_dim_mask(opt_dims, ndim);
  auto viewed_result =
      at::native::review_reduce_result(result, ndim, mask, keepdim);
  if (self.scalar_type() == in_dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
}

static TensorIterator make_reduction(
    const Tensor& self,
    const Tensor& result1,
    const Tensor& result2,
    IntArrayRef dims,
    bool keepdim,
    ScalarType dtype1,
    ScalarType /*dtype2*/) {
  int64_t ndim = self.dim();
  auto mask = at::native::make_dim_mask(dims, ndim);
  auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
  auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
  // special case for type promotion in mixed precision, improves computational efficiency.
  // We don't generalize this to common mismatched input/output types to avoid cross product
  // of templated kernel launches.
  if (self.scalar_type() == dtype1 ||
      (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  }
  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
}

static C10_UNUSED TensorIterator make_reduction_from_out_ty(
    const Tensor& self,
    const Tensor& result,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    ScalarType out_dtype) {
  // Special case for type promotion in mixed precision; improves computational
  // efficiency.
  // We don't generalize this to common mismatched input/output types to avoid a
  // cross product of templated kernel launches.
  const bool gpu_lowp_to_f32 =
      (self.is_cuda() &&
       (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
       out_dtype == kFloat);
  auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
}

} // namespace meta
} // namespace at