TensorIterator.h

#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
#endif

namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//       .add_output(output)
//       .add_input(input)
//       .build()
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs,
// the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
//
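// For illustration only (the tensor names `out` and `self` are placeholders,
// not part of this header), a config built in the required order:
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)   // all outputs first...
//       .add_owned_input(self)   // ...then the inputs
//       .build();
//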
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
// "computation dtype" where all inputs are cast to one dtype, the
// operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
// and it computes one using type promotion rules on its inputs if
// promote_inputs_to_common_dtype_ is true. Attempting to query
// a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.
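//
// As a hedged sketch (tensor names are placeholders), opting into common
// dtype computation and querying the result:
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)
//       .add_owned_input(a)   // e.g. an int tensor
//       .add_owned_input(b)   // e.g. a float tensor
//       .promote_inputs_to_common_dtype(true)
//       .build();
//   ScalarType dtype = iter.common_dtype();  // float, per promotion rules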

namespace at {
namespace internal {

// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
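
// Illustrative only, assuming at::parallel_for from ATen/Parallel.h:
// GRAIN_SIZE is the usual grain-size argument when chunking element-wise
// work across threads.
//
//   at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
//       [&](int64_t begin, int64_t end) {
//         for (int64_t i = begin; i < end; ++i) {
//           // per-element work
//         }
//       });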

// Storage for a non-owning Tensor, without needing to include Tensor.h
class TORCH_API OpaqueOptionalTensorRef {
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_;

 public:
  OpaqueOptionalTensorRef();
  ~OpaqueOptionalTensorRef();

  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  const Tensor& getTensor() const;
};
} // namespace internal

struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      target_dtype = t->scalar_type();
      current_dtype = target_dtype;
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// Stride after broadcasting. The stride is in bytes, not number of
  /// elements.
  StrideVector stride_bytes;

  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
  c10::optional<Device> device = c10::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;
  // Caches the dtype of the tensor, because scalar_type is an expensive
  // operation. If the dtype of the tensor is changed (e.g. as a result of
  // type promotion or in allocate_outputs), this value should be changed too.
  ScalarType current_dtype = ScalarType::Undefined;

  bool is_device_defined() const {
    return device.has_value();
  }
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  TensorOptions options() const {
    return TensorOptions(target_dtype).device(device);
  }

  /// The data pointer. This may be different from tensor->data_ptr() if the
  /// iterator is split.
  void* data = nullptr;

  bool is_output = false;
  bool will_resize = false;
  bool is_read_write = false;

  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Note that the strides, data pointer, and
  /// other attributes may differ due to dimension reordering and
  /// coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // Save the original tensor operand in cases when an output is modified
  // (e.g. if dtype is changed)
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Set tensor to a new value, and store the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Move original_tensor back into tensor; exchange_tensor must have been
  // called before.
  void restore_original_tensor();

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(c10::in_place);

  // We store TensorBase visibly in the header to allow inline access.
  // However, we sometimes need a genuine `const Tensor &` for the
  // TensorIterator API. So, we also store a non-owning `Tensor`
  // object in these `_storage_` variables.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

struct SplitUntil32Bit;

enum class FastSetupType : uint8_t {
  NONE,
  CONTIGUOUS,
  CHANNELS_LAST,
  NON_OVERLAPPING_DENSE
};

class TensorIteratorConfig;
struct TensorIterator;

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The inner-loop function operates on the fastest moving dimension. It
  // implements element-wise operations in terms of 1-d strided tensors.
  //
  // Arguments:
  //   data: data pointers for each operand (length `ntensors`)
  //   strides: stride for each operand (length `ntensors`)
  //   size: size of inner loop
  //
  // The `size` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return shape_.size();
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return operands_.size();
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// Number of elements in the output operand. This is the same as numel()
  /// for operations that are not reductions.
  int64_t num_output_elements() const;

  /// Number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int arg) const;
  ScalarType dtype(int arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int arg = 0) const {
    return operands_[arg].device.value();
  }
  DeviceType device_type(int arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int arg) const {
    return elementSize(dtype(arg));
  }
  bool is_scalar(int arg) const;
  bool is_cpu_scalar(int arg) const;

  const TensorBase& tensor_base(int arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int arg) const {
    return operands_[arg].tensor();
  }

  const TensorBase& output_base(int arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor_base(arg);
  }
  const Tensor& output(int arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor(arg);
  }
  const TensorBase& input_base(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original
  void unsafe_replace_operand(int arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  template <typename T>
  T scalar_value(int arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }
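
  // A hedged sketch of the CUDA-side pattern these helpers enable (the lambda
  // and gpu_kernel call are illustrative, not prescribed by this header): a
  // CPU scalar operand is read out with scalar_value and then removed so the
  // remaining operands can be iterated on the device.
  //
  //   if (iter.is_cpu_scalar(2)) {
  //     auto b = iter.scalar_value<float>(2);
  //     iter.remove_operand(2);
  //     gpu_kernel(iter, [b] GPU_LAMBDA(float a) -> float { return a + b; });
  //   }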

 private:
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);
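
  // A minimal sketch of driving for_each with a 1-d loop, assuming the
  // iterator was built with one float output and one float input (strides
  // are in bytes, hence the char* arithmetic):
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     for (int64_t i = 0; i < n; i++) {
  //       float* out = reinterpret_cast<float*>(data[0] + i * strides[0]);
  //       const float* in =
  //           reinterpret_cast<const float*>(data[1] + i * strides[1]);
  //       *out = *in + 1.0f;
  //     }
  //   });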

  void parallel_reduce(loop2d_t loop);

  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;

  /// Create a strides array for a Tensor with the shape of this iterator. The
  /// parameter `element_size` specifies the size of the Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapplies the same re-ordering as reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int arg, IntArrayRef strides) {
    operands_[arg].stride_bytes = std::move(strides);
  }
  void _unsafe_set_arg_data(const int arg, void* data) {
    operands_[arg].data = data;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iterable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;
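
  // A hedged sketch of the usual GPU dispatch pattern around these two
  // methods (the gpu_kernel call and functor `f` are illustrative):
  //
  //   if (!iter.can_use_32bit_indexing()) {
  //     for (auto& sub_iter : iter.with_32bit_indexing()) {
  //       gpu_kernel(sub_iter, f);
  //     }
  //     return;
  //   }
  //   // ... otherwise launch directly on `iter` with 32-bit indexing ...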

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }
    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic)            \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);

#undef TORCH_DISALLOW_TEMPORARIES
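
  // A hedged sketch of how a structured kernel's meta function typically
  // drives one of these build_* helpers; the TORCH_META_FUNC2 macro and the
  // alpha argument come from the structured-kernels machinery, not from this
  // header, and real meta functions may do additional checks:
  //
  //   TORCH_META_FUNC2(add, Tensor)
  //   (const Tensor& self, const Tensor& other, const Scalar& alpha) {
  //     build_borrowing_binary_op(maybe_get_output(), self, other);
  //   }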

 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the user.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over. For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by the user, otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]). It is used by
  /// invert_perm to undo the permutation. After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that maps computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build())
  /// been called? This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether the iteration order must be fixed. This disables dimension
  /// permuting and also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor. Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs. The
  /// outputs MUST come first in the operands_ list. There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;

  /// Whether or not all operands are 0d; this affects type promotion.
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying the dtype in which
  /// TensorIterator will do its internal computation. Typically,
  /// this matches the dtype of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};
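
// As a hedged sketch (tensor names are placeholders), the factory helpers
// above pair with the kernel helpers from Loops.h / Loops.cuh:
//
//   auto iter = TensorIterator::borrowing_binary_op(out, a, b);
//   cpu_kernel(iter, [](float x, float y) { return x * y; });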

class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An
  // error is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)),
  // then checking memory overlap is BC-breaking. Please don't check memory
  // overlap in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default.
  // If true, checks that all inputs and defined outputs have the same dtype.
  // Setting either of promote_inputs_to_common_dtype_
  // or cast_common_dtype_to_outputs_ to true will set
  // check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default.
  // If true, all operands must be on the same device, with the possible
  // exception of CPU scalars, which can be passed to some CUDA kernels
  // as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and
  // canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use it if the correctness of your kernel depends
  // on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default.
  // If true, the iterator's "common dtype" is always computed (see the
  // [Common Dtype Computation] note) and, on the CPU, temporary copies of
  // the inputs in the common dtype are passed as the actual inputs to
  // the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, promote_inputs_to_common_dtype_ must also be true.
  // If true and the iterator's "common dtype" is an integral type (including
  // bool), it is changed to the default float scalar type.
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    TORCH_INTERNAL_ASSERT(
        !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_);
    return *this;
  }

  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and, on the CPU, temporary
  // copies of the outputs are passed as the actual output to the operation.
  // These temporaries are then copied to the original outputs after
  // the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }
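
  // As a hedged illustration (tensor names are placeholders), a config that
  // computes in a common dtype and casts the result back to the output's
  // dtype, roughly the combination the build_* helpers on TensorIteratorBase
  // set up:
  //
  //   auto iter = TensorIteratorConfig()
  //       .add_owned_output(out)
  //       .add_owned_input(a)
  //       .add_owned_input(b)
  //       .promote_inputs_to_common_dtype(true)
  //       .cast_common_dtype_to_outputs(true)
  //       .enforce_safe_casting_to_output(true)
  //       .build();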

  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);

  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  c10::optional<DimVector> static_shape_ = c10::nullopt;
  c10::optional<ScalarType> static_dtype_ = c10::nullopt;
  c10::optional<Device> static_device_ = c10::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;
};

/// A container-like struct that acts as if it contains splits of a
/// TensorIterator that can use 32-bit indexing. Taken together the splits
/// cover the original TensorIterator.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();
    bool operator==(const iterator& other) const {
      // two iterators are equal if they are the same object or they're both
      // empty
      return this == &other || (vec.empty() && other.vec.empty());
    }
    // needed for C++11 range-based for loop
    bool operator!=(const iterator& other) const {
      return !(*this == other);
    }
    /// stack of TensorIterators to be split
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  const TensorIteratorBase& iter;
};

} // namespace at

C10_CLANG_DIAGNOSTIC_POP()