ExpandUtils.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. #pragma once
  2. #ifndef AT_PER_OPERATOR_HEADERS
  3. #include <ATen/Functions.h>
  4. #else
  5. #include <ATen/ops/view.h>
  6. #include <ATen/ops/view_copy.h>
  7. #endif
  8. #include <ATen/Tensor.h>
  9. #include <ATen/core/DimVector.h>
  10. #include <c10/util/Exception.h>
  11. #include <c10/util/MaybeOwned.h>
  12. #include <c10/util/irange.h>
  13. #include <functional>
  14. #include <sstream>
  15. #include <tuple>
  16. #include <utility>
  17. namespace at {
// Shape-inference helpers; declared here and defined out of line.
//
// Computes a size that both `a` and `b` can be expanded to (expand_outplace in
// this header expands both of its inputs to the result of
// infer_size_dimvector).
// NOTE(review): presumably raises an error when `a` and `b` are not
// broadcast-compatible (are_expandable notes it performs "a similar check");
// confirm against the definition.
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
// Presumably the same computation as infer_size, returned as a DimVector
// instead of a std::vector -- TODO confirm.
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
// Variant taking SymInt sizes and returning a SymDimVector.
TORCH_API SymDimVector
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
  22. // Named type instead of a pair/tuple so that we can be sure to
  23. // construct the vectors in place and get NRVO.
  24. template <typename Container>
  25. struct InferExpandGeometryResult {
  26. Container sizes;
  27. Container strides;
  28. explicit InferExpandGeometryResult(size_t ndim)
  29. : sizes(ndim), strides(ndim) {}
  30. explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
  31. : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
  32. };
// Expand-geometry helpers; declared here and defined out of line.
//
// NOTE(review): presumably computes the (sizes, strides) that a tensor with
// `tensor_sizes` / `tensor_strides` would have after being expanded to
// `sizes`, returned as a tuple in that order -- confirm against the definition.
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);
// Same inputs as inferExpandGeometry; the result is delivered as an
// InferExpandGeometryResult<DimVector> with `sizes` and `strides` members.
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);
// NOTE(review): judging by the name, returns strides for a dense layout of
// `tensor_sizes`, informed by `tensor_strides` -- confirm against the
// definition.
TORCH_API std::vector<int64_t> infer_dense_strides(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides);
  45. // True if input shapes are expandable
  46. // NOTE: infer_size did a similar check, please keep them sync if change is
  47. // needed
  48. inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
  49. size_t ndim1 = shape1.size();
  50. size_t ndim2 = shape2.size();
  51. size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
  52. for (int64_t i = ndim - 1; i >= 0; --i) {
  53. if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
  54. shape2[ndim2] == 1) {
  55. continue;
  56. }
  57. return false;
  58. }
  59. return true;
  60. }
  61. // avoid copy-construction of Tensor by using a reference_wrapper.
  62. inline void check_defined(
  63. std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
  64. const char* api_name) {
  65. for (auto& t : tensors) {
  66. if (!t.get().defined()) {
  67. AT_ERROR(api_name, "(...) called with an undefined Tensor");
  68. }
  69. }
  70. }
  71. // NOTE [ ExpandUtils Borrowing ]
  72. //
  73. // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
  74. // expansion may not actually be needed, in which case we can improve
  75. // efficiency by returning
  76. // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
  77. // that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
  78. // must not outlive the original `Tensor` object that `to_expand`
  79. // referred to! The deleted rvalue reference overloads of these
  80. // functions help with this by preventing trivial use of a temporary
  81. // resulting from a function call, but it is still possible to make a
  82. // mistake.
  83. inline c10::MaybeOwned<Tensor> expand_inplace(
  84. const Tensor& tensor,
  85. const Tensor& to_expand) {
  86. if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
  87. return c10::MaybeOwned<Tensor>::borrowed(to_expand);
  88. }
  89. return c10::MaybeOwned<Tensor>::owned(
  90. to_expand.expand_symint(tensor.sym_sizes()));
  91. }
  92. inline c10::MaybeOwned<Tensor> expand_inplace(
  93. const Tensor& tensor,
  94. Tensor&& to_expand) = delete;
  95. inline c10::MaybeOwned<Tensor> expand_inplace(
  96. const Tensor& tensor,
  97. const Tensor& to_expand,
  98. const char* api_name) {
  99. check_defined({tensor, to_expand}, api_name);
  100. return expand_inplace(tensor, to_expand);
  101. }
  102. inline c10::MaybeOwned<Tensor> expand_inplace(
  103. const Tensor& tensor,
  104. Tensor&& to_expand,
  105. const char* api_name) = delete;
  106. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  107. expand_inplace(
  108. const Tensor& tensor,
  109. const Tensor& to_expand1,
  110. const Tensor& to_expand2) {
  111. if (tensor.sizes().equals(to_expand1.sizes()) &&
  112. tensor.sizes().equals((to_expand2.sizes()))) {
  113. return std::make_tuple(
  114. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  115. c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  116. }
  117. return std::make_tuple(
  118. c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
  119. c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
  120. }
  121. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  122. expand_inplace(
  123. const Tensor& tensor,
  124. Tensor&& to_expand1,
  125. const Tensor& to_expand2) = delete;
  126. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  127. expand_inplace(
  128. const Tensor& tensor,
  129. const Tensor& to_expand1,
  130. Tensor&& to_expand2) = delete;
  131. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  132. expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
  133. delete;
  134. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  135. expand_inplace(
  136. const Tensor& tensor,
  137. const Tensor& to_expand1,
  138. const Tensor& to_expand2,
  139. const char* api_name) {
  140. check_defined({tensor, to_expand1, to_expand2}, api_name);
  141. return expand_inplace(tensor, to_expand1, to_expand2);
  142. }
  143. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  144. expand_inplace(
  145. const Tensor& tensor,
  146. Tensor&& to_expand1,
  147. const Tensor& to_expand2,
  148. const char* api_name) = delete;
  149. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  150. expand_inplace(
  151. const Tensor& tensor,
  152. const Tensor& to_expand1,
  153. Tensor&& to_expand2,
  154. const char* api_name) = delete;
  155. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  156. expand_inplace(
  157. const Tensor& tensor,
  158. Tensor&& to_expand1,
  159. Tensor&& to_expand2,
  160. const char* api_name) = delete;
  161. // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
  162. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  163. expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
  164. if (to_expand1.sizes().equals(to_expand2.sizes())) {
  165. return std::make_tuple(
  166. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  167. c10::MaybeOwned<Tensor>::borrowed(to_expand2));
  168. }
  169. auto expanded_size =
  170. infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
  171. return std::make_tuple(
  172. c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
  173. c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)));
  174. }
  175. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  176. expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
  177. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  178. expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
  179. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  180. expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
  181. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  182. expand_outplace(
  183. const Tensor& to_expand1,
  184. const Tensor& to_expand2,
  185. const char* api_name) {
  186. check_defined({to_expand1, to_expand2}, api_name);
  187. return expand_outplace(to_expand1, to_expand2);
  188. }
  189. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  190. expand_outplace(
  191. Tensor&& to_expand1,
  192. const Tensor& to_expand2,
  193. const char* api_name) = delete;
  194. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  195. expand_outplace(
  196. const Tensor& to_expand1,
  197. Tensor&& to_expand2,
  198. const char* api_name) = delete;
  199. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
  200. expand_outplace(
  201. Tensor&& to_expand1,
  202. Tensor&& to_expand2,
  203. const char* api_name) = delete;
  204. inline std::tuple<
  205. c10::MaybeOwned<Tensor>,
  206. c10::MaybeOwned<Tensor>,
  207. c10::MaybeOwned<Tensor>>
  208. expand_outplace(
  209. const Tensor& to_expand1,
  210. const Tensor& to_expand2,
  211. const Tensor& to_expand3) {
  212. if (to_expand1.sizes().equals(to_expand2.sizes()) &&
  213. to_expand1.sizes().equals(to_expand3.sizes())) {
  214. return std::make_tuple(
  215. c10::MaybeOwned<Tensor>::borrowed(to_expand1),
  216. c10::MaybeOwned<Tensor>::borrowed(to_expand2),
  217. c10::MaybeOwned<Tensor>::borrowed(to_expand3));
  218. }
  219. auto expanded_size12 =
  220. infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
  221. auto expanded_size =
  222. infer_size_dimvector(expanded_size12, to_expand3.sizes());
  223. return std::make_tuple(
  224. c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
  225. c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
  226. c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
  227. }
  228. inline std::tuple<
  229. c10::MaybeOwned<Tensor>,
  230. c10::MaybeOwned<Tensor>,
  231. c10::MaybeOwned<Tensor>>
  232. expand_outplace(
  233. Tensor&& to_expand1,
  234. const Tensor& to_expand2,
  235. const Tensor& to_expand3) = delete;
  236. inline std::tuple<
  237. c10::MaybeOwned<Tensor>,
  238. c10::MaybeOwned<Tensor>,
  239. c10::MaybeOwned<Tensor>>
  240. expand_outplace(
  241. const Tensor& to_expand1,
  242. Tensor&& to_expand2,
  243. const Tensor& to_expand3) = delete;
  244. inline std::tuple<
  245. c10::MaybeOwned<Tensor>,
  246. c10::MaybeOwned<Tensor>,
  247. c10::MaybeOwned<Tensor>>
  248. expand_outplace(
  249. Tensor&& to_expand1,
  250. Tensor&& to_expand2,
  251. const Tensor& to_expand3) = delete;
  252. inline std::tuple<
  253. c10::MaybeOwned<Tensor>,
  254. c10::MaybeOwned<Tensor>,
  255. c10::MaybeOwned<Tensor>>
  256. expand_outplace(
  257. const Tensor& to_expand1,
  258. const Tensor& to_expand2,
  259. Tensor&& to_expand3) = delete;
  260. inline std::tuple<
  261. c10::MaybeOwned<Tensor>,
  262. c10::MaybeOwned<Tensor>,
  263. c10::MaybeOwned<Tensor>>
  264. expand_outplace(
  265. Tensor&& to_expand1,
  266. const Tensor& to_expand2,
  267. Tensor&& to_expand3) = delete;
  268. inline std::tuple<
  269. c10::MaybeOwned<Tensor>,
  270. c10::MaybeOwned<Tensor>,
  271. c10::MaybeOwned<Tensor>>
  272. expand_outplace(
  273. const Tensor& to_expand1,
  274. Tensor&& to_expand2,
  275. Tensor&& to_expand3) = delete;
  276. inline std::tuple<
  277. c10::MaybeOwned<Tensor>,
  278. c10::MaybeOwned<Tensor>,
  279. c10::MaybeOwned<Tensor>>
  280. expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
  281. delete;
  282. inline std::tuple<
  283. c10::MaybeOwned<Tensor>,
  284. c10::MaybeOwned<Tensor>,
  285. c10::MaybeOwned<Tensor>>
  286. expand_outplace(
  287. const Tensor& to_expand1,
  288. const Tensor& to_expand2,
  289. const Tensor& to_expand3,
  290. const char* api_name) {
  291. check_defined({to_expand1, to_expand2, to_expand3}, api_name);
  292. return expand_outplace(to_expand1, to_expand2, to_expand3);
  293. }
  294. inline std::tuple<
  295. c10::MaybeOwned<Tensor>,
  296. c10::MaybeOwned<Tensor>,
  297. c10::MaybeOwned<Tensor>>
  298. expand_outplace(
  299. Tensor&& to_expand1,
  300. const Tensor& to_expand2,
  301. const Tensor& to_expand3,
  302. const char* api_name) = delete;
  303. inline std::tuple<
  304. c10::MaybeOwned<Tensor>,
  305. c10::MaybeOwned<Tensor>,
  306. c10::MaybeOwned<Tensor>>
  307. expand_outplace(
  308. const Tensor& to_expand1,
  309. Tensor&& to_expand2,
  310. const Tensor& to_expand3,
  311. const char* api_name) = delete;
  312. inline std::tuple<
  313. c10::MaybeOwned<Tensor>,
  314. c10::MaybeOwned<Tensor>,
  315. c10::MaybeOwned<Tensor>>
  316. expand_outplace(
  317. Tensor&& to_expand1,
  318. Tensor&& to_expand2,
  319. const Tensor& to_expand3,
  320. const char* api_name) = delete;
  321. inline std::tuple<
  322. c10::MaybeOwned<Tensor>,
  323. c10::MaybeOwned<Tensor>,
  324. c10::MaybeOwned<Tensor>>
  325. expand_outplace(
  326. const Tensor& to_expand1,
  327. const Tensor& to_expand2,
  328. Tensor&& to_expand3,
  329. const char* api_name) = delete;
  330. inline std::tuple<
  331. c10::MaybeOwned<Tensor>,
  332. c10::MaybeOwned<Tensor>,
  333. c10::MaybeOwned<Tensor>>
  334. expand_outplace(
  335. Tensor&& to_expand1,
  336. const Tensor& to_expand2,
  337. Tensor&& to_expand3,
  338. const char* api_name) = delete;
  339. inline std::tuple<
  340. c10::MaybeOwned<Tensor>,
  341. c10::MaybeOwned<Tensor>,
  342. c10::MaybeOwned<Tensor>>
  343. expand_outplace(
  344. const Tensor& to_expand1,
  345. Tensor&& to_expand2,
  346. Tensor&& to_expand3,
  347. const char* api_name) = delete;
  348. inline std::tuple<
  349. c10::MaybeOwned<Tensor>,
  350. c10::MaybeOwned<Tensor>,
  351. c10::MaybeOwned<Tensor>>
  352. expand_outplace(
  353. Tensor&& to_expand1,
  354. Tensor&& to_expand2,
  355. Tensor&& to_expand3,
  356. const char* api_name) = delete;
  357. inline c10::MaybeOwned<Tensor> expand_size(
  358. const Tensor& to_expand,
  359. IntArrayRef sizes) {
  360. if (to_expand.sizes().equals(sizes)) {
  361. return c10::MaybeOwned<Tensor>::borrowed(to_expand);
  362. }
  363. return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
  364. }
  365. inline c10::MaybeOwned<Tensor> expand_size(
  366. Tensor&& to_expand,
  367. IntArrayRef sizes) = delete;
  368. inline c10::MaybeOwned<Tensor> expand_size(
  369. const Tensor& to_expand,
  370. IntArrayRef sizes,
  371. const char* api_name) {
  372. check_defined({to_expand}, api_name);
  373. return expand_size(to_expand, sizes);
  374. }
  375. inline c10::MaybeOwned<Tensor> expand_size(
  376. Tensor&& to_expand,
  377. IntArrayRef sizes,
  378. const char* api_name) = delete;
  379. inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
  380. // expands a list of Tensors; ignores undefined (null) tensors
  381. bool first = true;
  382. DimVector sizes;
  383. for (const auto i : c10::irange(to_expand.size())) {
  384. if (!to_expand[i].defined()) {
  385. continue;
  386. } else if (first) {
  387. sizes = to_expand[i].sizes();
  388. first = false;
  389. } else {
  390. sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
  391. }
  392. }
  393. std::vector<Tensor> result(to_expand.size());
  394. for (const auto i : c10::irange(to_expand.size())) {
  395. if (!to_expand[i].defined()) {
  396. continue;
  397. } else if (to_expand[i].sizes().equals(sizes)) {
  398. result[i] = to_expand[i];
  399. } else {
  400. result[i] = to_expand[i].expand(sizes);
  401. }
  402. }
  403. return result;
  404. }
  405. template <typename T>
  406. inline Tensor _sum_to(
  407. Tensor tensor,
  408. const c10::ArrayRef<T> shape,
  409. bool always_return_non_view = false) {
  410. if (shape.size() == 0) {
  411. return tensor.sum();
  412. }
  413. auto sizes = at::symint::sizes<T>(tensor);
  414. c10::SmallVector<int64_t, 8> reduce_dims;
  415. const int64_t leading_dims = sizes.size() - shape.size();
  416. for (const auto i : c10::irange(leading_dims)) {
  417. reduce_dims.push_back(i);
  418. }
  419. for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
  420. if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
  421. reduce_dims.push_back(i);
  422. }
  423. }
  424. if (!reduce_dims.empty()) {
  425. tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
  426. }
  427. if (always_return_non_view) {
  428. // This is only actually used by the functionalization pass.
  429. // We want to be able to guarantee that this function doesn't return a view
  430. // of the input.
  431. return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
  432. : tensor.clone();
  433. } else {
  434. return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
  435. }
  436. }
  437. inline Tensor sum_to(
  438. Tensor tensor,
  439. const c10::SymIntArrayRef shape,
  440. bool always_return_non_view = false) {
  441. return _sum_to(std::move(tensor), shape, always_return_non_view);
  442. }
  443. // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
  444. // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
  445. inline Tensor sum_to(
  446. Tensor tensor,
  447. const IntArrayRef shape,
  448. bool always_return_non_view = false) {
  449. return _sum_to(std::move(tensor), shape, always_return_non_view);
  450. }
  451. static inline bool is_expandable_to(
  452. SymIntArrayRef shape,
  453. c10::SymIntArrayRef desired) {
  454. size_t ndim = shape.size();
  455. size_t target_dim = desired.size();
  456. if (ndim > target_dim) {
  457. return false;
  458. }
  459. for (const auto i : c10::irange(ndim)) {
  460. const auto& size = shape[ndim - i - 1];
  461. const auto& target = desired[target_dim - i - 1];
  462. if (size != target && size != 1) {
  463. return false;
  464. }
  465. }
  466. return true;
  467. }
  468. static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
  469. auto sym_shape = c10::SymIntArrayRef(
  470. reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
  471. auto sym_desired = c10::SymIntArrayRef(
  472. reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
  473. return is_expandable_to(sym_shape, sym_desired);
  474. }
  475. } // namespace at