FunctionalTensorWrapper.h

#pragma once

#include <ATen/ArrayRef.h>
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/List.h>
#include <ATen/core/boxing/BoxedKernel.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <c10/core/DispatchKey.h>

namespace at {

// Note [Functionalization Pass In Core]
// The Functionalization pass is used to remove aliasing from a pytorch
// program.
//
// This is useful for backends that don't support aliasing, like XLA and
// Vulkan. It's also necessary in order to remove mutation from a program,
// which is needed in functorch.
//
// Consider this program:
//
// a = torch.ones(...)
// b = a.view(...)
// b.add_(1)
//
// In this program, b is meant to alias with a due to the use of view(). At
// the end of the program, both a and b are full of 2's. However, backends
// that don't support aliasing aren't able to correctly implement the view()
// operator. Instead, they can opt into the Functionalization pass, which will
// sit between the user and the backend, and provide the necessary aliasing
// logic.
//
// The functionalization pass will turn the above program into a slightly
// different program that has the same semantics, transparently to the user,
// that backends like XLA/Vulkan are able to implement:
//
// a = torch.ones(...)
// b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
//                       # XLA/Vulkan can implement this!
// b.add_(1)
// a.add_(1)  # Our functionalization pass machinery knows that a and b are
//            # aliased - it applies b's mutation to a too.
//
// So, how does the functionalization pass keep track of which tensors are
// aliased? The pass works by wrapping EVERY tensor in the program inside of a
// FunctionalTensorWrapper, which knows about its alias'd tensors.
//
// See Note [Functionalization: Alias Removal] for details on the aliasing
// machinery.
// See Note [Functionalization: Mutation Removal] for details on mutation
// removal.
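//
// For illustration only (a rough caller-side sketch, not code from this
// header): a round trip through the impl:: helpers declared further down
// might look like
//
//   at::Tensor wrapped = at::functionalization::impl::to_functional_tensor(t);
//   // ... run the program on `wrapped`; views and mutations are tracked
//   // by the FunctionalTensorWrapper ...
//   at::functionalization::impl::sync(wrapped);  // flush pending updates
//   at::Tensor result =
//       at::functionalization::impl::from_functional_tensor(wrapped);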
struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  explicit FunctionalTensorWrapper(const Tensor& value);
  // Additional constructor to create a FunctionalTensorWrapper directly from
  // an underlying tensor that was created from a view. For example, the code
  // b = a.view1() will generate a constructor call to
  // FunctionalTensorWrapper(b, a, view1_meta).
  explicit FunctionalTensorWrapper(
      const Tensor& view_value,
      const FunctionalTensorWrapper* base,
      functionalization::ViewMeta meta);
  // Get the underlying, actual tensor, that doesn't know anything about
  // functionalization.
  const Tensor& value() const {
    return value_;
  }

  // The concept of "level" is only ever important to functorch; it's exposed
  // here as more of a hook for functorch to use.
  int64_t level() const {
    return level_;
  }
  void set_level(int64_t level) {
    level_ = level;
  }
  // Syncs the underlying tensor with its alias, if it's out of date. This
  // involves two steps:
  // 1) Apply any pending updates/mutations to the alias
  // 2) Replay the views (if any) to regenerate the current tensor off of the
  //    updated alias.
  void sync_();
  // Performs step (2) of the sync: replaying the views to regenerate the
  // current tensor off of its base. This is its own public API because it's
  // needed by inplace-view ops like transpose_.
  // See Note [Functionalization Pass - Inplace View Ops]
  void regenerate_from_base();
  // Performs step (1) of the sync: applying any pending updates to the alias.
  // This is its own public API because it's needed by functorch. functorch
  // wants to make sure that all input tensors to a functionalized program
  // have been properly synced so it can properly propagate mutations to
  // inputs. It can't just call sync_(), because the FunctionalTensorWrapper
  // will look like it has no aliases and sync_ will be a no-op. We use the
  // reference count on storage_ to determine if the wrapper is aliased, and
  // by the time functorch is ready to propagate updates to inputs, any
  // intermediate views of the input created by the program will have been
  // deallocated. This function also returns whether or not the base actually
  // had any updates to apply.
  bool apply_updates();
  // Takes the current state of value_ and snapshots it, sending it as a
  // pending update to the alias.
  void commit_update();
  // When any tensor is mutated, the tensor increments its alias's
  // "generation". Separately, each tensor maintains its own "generation"
  // counter, which is used to determine if it's up-to-date with its alias.
  // The act of syncing a tensor will set a tensor's generation equal to its
  // alias's generation.
  bool is_up_to_date() const;
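  // For example (an informal sketch of the generation counters described
  // above):
  //
  //   b = a.view(...)    // b and a share an alias; generations start equal
  //   b.add_(1)          // mutating b bumps the alias's generation, so a's
  //                      // own counter is now behind
  //   a.is_up_to_date()  // false: a's generation != the alias's generation
  //   a.sync_()          // applies b's update and sets a's generation equal
  //                      // to the alias's
  //   a.is_up_to_date()  // true
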
  // Freezes the storage of this tensor, preventing subsequent mutations.
  void freeze_storage() const;
  // Every FunctionalTensorWrapper contains a vector of ViewMeta objects
  // describing the series of view ops that ran to generate the current tensor
  // from the base tensor. This method is used by inplace-view ops like
  // transpose_. It appends a ViewMeta to the existing stack, and refreshes
  // the tensor by replaying the views off of the alias.
  void mutate_view_meta(at::functionalization::ViewMeta meta);
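  // E.g. (an informal sketch): a.transpose_(0, 1) appends a transpose
  // ViewMeta to a's view_metas_ stack and then replays the whole stack off
  // of the alias to refresh a, rather than mutating a's data directly.
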
  // The functionalization pass can be used to remove mutations.
  // It does so by replacing any mutation op with its corresponding
  // out-of-place op, followed by a call to replace_(). e.g.:
  //
  // a.add_(1)
  //
  // will turn into:
  //
  // tmp = a.add(1)
  // a.replace_(tmp)
  //
  // replace_() swaps out the wrapped tensor, value_, with tmp.
  void replace_(const Tensor& other);
  // See Note [resize_() in functionalization pass]
  void maybe_replace_storage(const Tensor& other);

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;

  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;

  ~FunctionalTensorWrapper() override = default;
  // FunctionalTensorWrapper overrides all custom size/stride functions,
  // so that if the inner tensor has a custom implementation
  // we make sure to call that implementation.
  at::IntArrayRef sizes_custom() const override;
  at::IntArrayRef strides_custom() const override;
  int64_t dim_custom() const override;
  int64_t numel_custom() const override;
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  c10::SymIntArrayRef sym_strides_custom() const override;
  c10::SymInt sym_storage_offset_custom() const override;
  c10::Device device_custom() const override;
 private:
  const char* tensorimpl_type_name() const override;
  void set_constructor_metadata();
  functionalization::FunctionalStorageImpl* functional_storage_impl() const;

  // This is used to re-implement shallow_copy_and_detach for
  // FunctionalTensorWrapper. The implementation is identical, but we just
  // need to return a subclass instead of a plain TensorImpl.
  // TODO: maybe it's possible to arrange for that to happen automatically
  // without an override here?
  template <typename VariableVersion>
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
      VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const;

  // Note that value is not taken by reference: internally, the wrapper will
  // change the value tensor that it points to over time.
  Tensor value_;
  int64_t level_;
  size_t generation_ = 0;
  std::vector<at::functionalization::ViewMeta> view_metas_;
};

// Utility functions for the functionalization pass.

namespace functionalization {
namespace impl {

TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
    const Tensor& tensor) {
  auto functional_impl =
      static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  return functional_impl;
}
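
// The cast above is unchecked, so callers are expected to validate the input
// first. A typical call-site sketch (illustrative only):
//
//   if (at::functionalization::impl::isFunctionalTensor(t)) {
//     auto* wrapper =
//         at::functionalization::impl::unsafeGetFunctionalWrapper(t);
//     wrapper->sync_();
//   }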

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const c10::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API bool isFunctionalTensor(ITensorListRef list);

TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
TORCH_API c10::optional<Tensor> to_functional_tensor(
    const c10::optional<Tensor>& tensor);
TORCH_API c10::List<c10::optional<Tensor>> to_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);

TORCH_API void freeze_functional_tensor(const Tensor& tensor);

TORCH_API Tensor
from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
TORCH_API c10::optional<Tensor> from_functional_tensor(
    const c10::optional<Tensor>& t,
    bool assert_functional = true);
TORCH_API c10::List<c10::optional<Tensor>> from_functional_tensor(
    const c10::List<c10::optional<Tensor>>& t_list);
TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);

TORCH_API void sync(const at::Tensor& t);
TORCH_API void sync(const c10::optional<Tensor>& t);
TORCH_API void sync(const c10::List<c10::optional<Tensor>> t_list);
TORCH_API void sync(ITensorListRef t_list);

TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
TORCH_API void replace_(
    const ITensorListRef functional_tensor,
    ITensorListRef other);

TORCH_API void commit_update(const Tensor& functional_tensor);
TORCH_API void commit_update(ITensorListRef functional_tensor);
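
// Taken together, a functionalized kernel for a mutating op roughly follows
// this pattern (an illustrative sketch, not a real kernel): unwrap the
// input, run the out-of-place variant, then write the result back through
// the wrapper:
//
//   // Hypothetical kernel body for a.add_(1):
//   auto a_inner = from_functional_tensor(a);
//   auto tmp = a_inner.add(1);  // out-of-place counterpart of add_
//   replace_(a, tmp);           // swap the wrapped value_ with tmp
//   commit_update(a);           // queue the update for a's alias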

Tensor create_functional_tensor_with_view_meta(
    const Tensor& view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta,
    int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
    ITensorListRef view_to_wrap,
    const Tensor& base,
    functionalization::ViewMeta meta);

void mutate_view_meta(const Tensor& self, functionalization::ViewMeta meta);

void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(
    const std::vector<Tensor>& outs,
    const std::vector<Tensor>& meta_outs);

// ~~~~~ TLS used in functionalization ~~~~~
TORCH_API bool getFunctionalizationReapplyViewsTLS();
TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);

class TORCH_API FunctionalizationReapplyViewsGuard {
 public:
  FunctionalizationReapplyViewsGuard(bool reapply_views)
      : prev_(getFunctionalizationReapplyViewsTLS()) {
    setFunctionalizationReapplyViewsTLS(reapply_views);
  }
  ~FunctionalizationReapplyViewsGuard() {
    setFunctionalizationReapplyViewsTLS(prev_);
  }

  FunctionalizationReapplyViewsGuard(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard operator=(
      const FunctionalizationReapplyViewsGuard&) = delete;
  FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
      delete;
  FunctionalizationReapplyViewsGuard operator=(
      FunctionalizationReapplyViewsGuard&&) = delete;

 private:
  bool prev_;
};
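
// Example (sketch): temporarily flip the reapply-views TLS for a scope; the
// destructor restores the previous value even if an exception is thrown:
//
//   {
//     at::functionalization::impl::FunctionalizationReapplyViewsGuard guard(
//         /*reapply_views=*/true);
//     // ... getFunctionalizationReapplyViewsTLS() returns true here ...
//   }  // prior TLS value restored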

} // namespace impl

// Helper function to call an out-of-place composite aten kernel that may use
// mutations / views internally, and functionalize them.
TORCH_API void functionalize_op_helper(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op final {};

template <class Op, bool symint, class ReturnType, class... ParameterTypes>
struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  static ReturnType call(
      typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
    using FuncType = ReturnType(
        typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
    auto op = c10::Dispatcher::singleton()
                  .findSchemaOrThrow(
                      (const char*)Op::name, (const char*)Op::overload_name)
                  .typed<FuncType>();
    return c10::impl::BoxedKernelWrapper<FuncType>::call(
        c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
        op,
        // BoxedKernelWrapper knows to ignore this keyset argument,
        // because functionalize_op_helper doesn't take in a DispatchKeySet
        c10::DispatchKeySet(),
        args...);
  }
};

template <class Op>
using functionalize_aten_op =
    _functionalize_aten_op<Op, false, typename Op::schema>;
template <class Op>
using functionalize_aten_op_symint =
    _functionalize_aten_op<Op, true, typename Op::schema>;
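
// Usage sketch (assuming an operator-trait type such as the ones generated
// in ATen/Operators.h; `some_op` is a placeholder, not a real operator):
//
//   auto out = at::functionalization::functionalize_aten_op<
//       ATEN_OP(some_op)>::call(args...);
//
// This dispatches the op through functionalize_op_helper, so any mutations
// or views performed inside its composite kernel are functionalized.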

} // namespace functionalization
} // namespace at