// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/Tensor.h>
#include <ATen/functorch/Interpreter.h>

namespace at {
namespace functorch {

// NOTE: [functorch's TensorWrapper]
//
// Taking better suggestions for a name. TensorWrapper is the wrapper Tensor
// subclass for functorch's grad-based transforms (grad, vjp, jvp). It is
// analogous to how vmap uses BatchedTensor as its wrapper Tensor subclass.
//
// If you're familiar with the Tensor-Variable merge, TensorWrapper is effectively
// another Variable.
//
// Consider grad(grad(torch.sin))(x). This wraps `x` as TensorWrapper(TensorWrapper(x)).
// The reason is so that each TensorWrapper can hold its own AutogradMeta and
// participate in a **separate** autograd graph.
//
// There are alternative designs we could have chosen (e.g. each grad transform
// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels)
// without much modification. Since a TensorWrapper looks like a regular Tensor,
// the VariableType kernel can pull out the AutogradMeta struct from where it
// expects and extend the autograd graph.
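//
// For illustration only (a hedged sketch, not code from this header): the
// nested wrapping above corresponds roughly to
//
//   Tensor inner = makeTensorWrapper(x, /*level=*/1);     // inner grad's graph
//   Tensor outer = makeTensorWrapper(inner, /*level=*/2); // outer grad's graph
//
// where each wrapper carries its own AutogradMeta, so backpropagating through
// the outer graph does not disturb the inner one.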

struct TORCH_API TensorWrapper : public c10::TensorImpl {
  explicit TensorWrapper(
      c10::DispatchKeySet key_set,
      Tensor value,
      int64_t level,
      std::shared_ptr<bool> is_alive,
      bool is_immutable = false, // if true, this came from an operation that aliases an immutable tensor
      bool use_value_sizes_strides = true);

  // Override a bunch of methods inherited from TensorImpl to return error messages.
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

  void refreshMetadata();

  const Tensor& value() const {
    return value_;
  }
  optional<int64_t> level() const {
    if (is_alive()) {
      return level_;
    }
    return {};
  }
  bool is_immutable() const {
    return is_immutable_;
  }
  bool is_alive() const;

  // Overrides necessary for autograd
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;

 private:
  const char* tensorimpl_type_name() const override;
  Tensor value_;
  int64_t level_;
  bool is_immutable_;

  // TensorWrapper receives a boolean flag indicating whether the Grad Interpreter
  // that created it is still alive.
  // If the Grad Interpreter is no longer alive then the wrapper attempts to
  // behave like a regular Tensor.
  //
  // When we exit the level, this wrapper may be marked as "not alive".
  // Wrappers that are not alive:
  // 1) May still have autograd metadata on them
  // 2) Forward dispatches to the underlying value()
  std::shared_ptr<bool> is_alive_;
};
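
// A rough sketch of the "not alive" behavior described above (illustrative
// only; assumes a wrapper whose creating grad interpreter has already exited):
//
//   TensorWrapper* wrapper = maybeGetTensorWrapper(wrapped);
//   wrapper->is_alive();  // false once the creating interpreter is gone
//   wrapper->level();     // nullopt in that case (see level() above)
//   wrapper->value();     // still valid; dispatch forwards to this tensor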

// There are two variants of makeTensorWrapper: one that accepts a level
// and one that accepts an Interpreter.
//
// The one that accepts a level tries to automatically get the life handle from the
// interpreter on the DynamicLayerStack.
// It needs to be used with caution: if the interpreter is not on the
// DynamicLayerStack, then we won't be able to find the life handle.
//
// In practice this isn't a problem: when we're constructing TensorWrapper in
// Python, the corresponding interpreter is on the stack.
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream& ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);
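
// A minimal usage sketch (assumes the relevant grad interpreter is already on
// the DynamicLayerStack, per the note above; `x` is an ordinary at::Tensor):
//
//   Tensor wrapped = makeTensorWrapper(x, /*level=*/1);
//   if (TensorWrapper* w = maybeGetTensorWrapper(wrapped)) {
//     // w->level() holds the level while the interpreter is alive.
//     dumpTensorCout(wrapped);  // debug print of the wrapper for inspection
//   }
//   // maybeGetTensorWrapper presumably returns nullptr for plain tensors,
//   // so the branch above is skipped for an unwrapped `x`.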

} // namespace functorch
} // namespace at