BatchedTensorImpl.h

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <bitset>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>
namespace at {
namespace functorch {

using Tensor = at::Tensor;
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;
// A BatchedTensorImpl holds an underlying Tensor and a single batch dim.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimension is treated as being "private"; it is not user-visible.
// For example, in the following Tensor,
//   bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
// dimension 0 is the batch dimension.
//
// bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 1 in the underlying ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);

  // Returns the batch dimension of this tensor
  int64_t bdim() const { return bdim_; }
  // Returns the vmap level of this tensor
  int64_t level() const { return level_; }
  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 2
  // bt.actualDim(2) -> 3
  // bt.actualDim(3) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;

  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
  // accomplishes this is a hack where it is able to modify the levels of
  // BatchedTensor to match the level of the current vmap transform.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in batching rules for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_)
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
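
  // Example (a hedged sketch, not part of the upstream header): after a
  // batching rule for an in-place view op moves the batch dim, it must refresh
  // the cached metadata. `batched` and `new_bdim` are hypothetical names used
  // only for illustration.
  //
  //   BatchedTensorImpl* batched = ...;  // e.g. obtained via maybeGetBatchedImpl
  //   batched->unsafe_set_bdim(new_bdim);
  //   batched->refreshTensorMetadata();  // MUST follow unsafe_set_bdim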

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;
  int64_t level_;
  int64_t bdim_;
};

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(std::move(tensor));
}
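
// Example (a hedged sketch, not part of the upstream header): the typical way
// to unwrap a possibly-batched tensor. `process_batched` and `process_plain`
// are hypothetical helpers used only for illustration.
//
//   void handle(const Tensor& t) {
//     if (BatchedTensorImpl* batched = maybeGetBatchedImpl(t)) {
//       // `t` is a BatchedTensor: operate on the underlying value and bdim.
//       process_batched(batched->value(), batched->bdim(), batched->level());
//     } else {
//       // `t` is a regular Tensor.
//       process_plain(t);
//     }
//   }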

// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  is_bdim.set(dim);
  return is_bdim;
}

// Creates a bitset for the given level
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  std::bitset<kVmapNumLevels> result;
  result.set(level);
  return result;
}
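
// Example (illustrative sketch only, not part of the upstream header): batching
// rules commonly use the bitset to skip the batch dim while looping over the
// physical dims of value(). `batched` is a hypothetical BatchedTensorImpl*.
//
//   auto is_bdim = createBatchDimBitset(batched->bdim());
//   for (int64_t d = 0; d < batched->value().dim(); d++) {
//     if (is_bdim[d]) continue;  // dim d is the (private) batch dim
//     // ... handle the remaining, user-visible dims ...
//   }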

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
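
// Example (a hedged sketch, not part of the upstream header): wrapping a plain
// tensor so that dim 0 becomes a private batch dim at vmap level 1, matching
// the BatchedTensorImpl example above.
//
//   Tensor t = at::ones({2, 3, 5, 7});
//   Tensor bt = makeBatched(t, /*dim=*/0, /*level=*/1);
//   // bt.sizes() is now (3, 5, 7); the leading dim of size 2 is hidden.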

// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
// TODO: should probably contain more (or all?) backend keys
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return key_set & to_propagate;
}
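
// Example (a hedged sketch, not part of the upstream header): when building a
// BatchedTensor, the propagated keys are typically combined with the wrapper's
// own dispatch key before constructing the BatchedTensorImpl.
//
//   DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor)
//       | DispatchKeySet(DispatchKey::FuncTorchBatched);
//   // key_set can then be passed to the BatchedTensorImpl constructor.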

} // namespace functorch
} // namespace at