#pragma once

#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace at {

/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension is not
 * equal to the size of the tensor defined by the outer dimensions. Let's call
 * this outer (contiguous) tensor A. Note that if the Tensor is contiguous,
 * then A is equal to the entire Tensor. Let's call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension.
 * For example, if B is a 2x2 matrix, then we do:
 *
 * B[0][0]
 * B[0][1]
 * B[1][0]
 * B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. we compute the offset into the storage as we would normally
 * for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for sizeof(A) iterations and
 * perform the operation, without having to follow the order described by the
 * strides of A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3
 * tensor, then the first two dimensions can be merged for the purposes of
 * APPLY, reducing the number of nested loops.
 */
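
/*
 * Worked example (an editorial illustration of step 3, not part of the
 * original strategy description): for the 3x3x3x3 tensor narrowed from a
 * 3x3x4x3 tensor, the sizes are {3, 3, 3, 3} and the strides are
 * {36, 12, 3, 1}. Dimensions 0 and 1 are contiguous with each other
 * (36 == 3 * 12), as are dimensions 2 and 3 (3 == 3 * 1), so the
 * collapse_dims call used by the iterators below is expected to merge them:
 *
 *   int64_t sizes[4] = {3, 3, 3, 3};
 *   int64_t strides[4] = {36, 12, 3, 1};
 *   // collapse_dims(sizes, strides, 4) should leave 2 collapsed dims:
 *   // sizes -> {9, 9}, strides -> {12, 1}
 *
 * turning four nested loops into two.
 */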

// Returns a view of the tensor with dimensions permuted into decreasing
// stride order (outermost-in-memory first).
inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
    return strides[i1] > strides[i2];
  });
  Tensor tensor = tensor_.permute(indices);
  return tensor;
}

// Iterator state for a tensor with at most N (collapsed) dimensions; sizes,
// strides, and per-dimension counters live in fixed-size arrays.
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = NULL;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    (void)sort_strides; // Suppress unused variable warning
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};

// Iterator state for tensors of arbitrary dimensionality; sizes, strides, and
// counters are heap-allocated vectors.
template <typename T>
struct strided_tensor_iter {
 public:
  T* data_ = NULL;
  int64_t dim_;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};

inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty())
    return true;
  int64_t all_numel = tensors[0].numel();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].numel() != all_numel)
      return false;
  }
  return true;
}

inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].numel()
      << " elements respectively";
  return oss.str();
}

inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}

inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t dim = 0;
  for (auto& t : tensors)
    dim = std::max(dim, t.ndimension());
  return dim;
}

// Advance each iterator by `size` elements along its innermost (collapsed)
// dimension.
inline void iterate(int64_t /*size*/) {}

template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

// True while every iterator still has elements remaining in its innermost
// dimension.
inline bool iterate_continue() {
  return true;
}

template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

// Largest step that can be taken before some iterator's innermost dimension
// overflows.
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

// Carry-propagate counters that have reached the end of their dimension:
// reset the overflowed counter, step the next outer one, and rewind the data
// pointer accordingly.
inline void iterate_overflow() {}

template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

// Skip each iterator ahead by a flat element offset, decomposing the offset
// into per-dimension counters (mixed-radix, innermost dimension first).
inline void forward(int64_t /*offset*/) {}

template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}
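
/*
 * Worked example (editorial illustration of forward): with collapsed sizes
 * {2, 3, 4} and offset 17, the loop computes, innermost dimension first,
 * 17 % 4 == 1 (then 17 / 4 == 4), 4 % 3 == 1 (then 4 / 3 == 1), and
 * 1 % 2 == 1, so the counters become {1, 1, 1}. For a contiguous layout with
 * strides {12, 4, 1} that is element 1 * 12 + 1 * 4 + 1 * 1 == 17, i.e. the
 * flat offset we started from.
 */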

inline int64_t max_dim() {
  return 0;
}

template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

inline void apply_op() {}

template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}

/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes one scalar
  reference per given tensor, in the same order as the tensors (apply_op
  dereferences each iterator's data pointer before invoking op). For example,
  to compute a = b * c, op would be of the form:

    [](scalar1& a_val, const scalar2& b_val, const scalar3& c_val) {
      a_val = b_val * c_val;
    };
*/

template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  if (!_apply_preamble({tensor1, tensor2}))
    return;
  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
  }
}
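
/*
 * Usage sketch (editorial illustration; `scaled_copy_` is a hypothetical
 * helper, not part of ATen): write alpha * src into dst element by element,
 * assuming both tensors hold float and have equal numel.
 *
 *   inline void scaled_copy_(Tensor& dst, Tensor src, float alpha) {
 *     CPU_tensor_apply2<float, float>(
 *         dst, src, [alpha](float& d, const float& s) { d = s * alpha; });
 *   }
 */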

template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
  }
}
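
/*
 * Usage sketch (editorial illustration; `mul_out_` is a hypothetical helper,
 * not part of ATen): the a = b * c example from the comment above, assuming
 * all three tensors hold double and have equal numel.
 *
 *   inline void mul_out_(Tensor& a, Tensor b, Tensor c) {
 *     CPU_tensor_apply3<double, double, double>(
 *         a, b, c,
 *         [](double& a_val, const double& b_val, const double& c_val) {
 *           a_val = b_val * c_val;
 *         });
 *   }
 */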

template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3),
        strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
  }
}

} // namespace at