LegacyBatchedTensorImpl.h

#pragma once

#include <bitset>
#include <ostream>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension
// is being vmap'ed over and the `level` being an identifier for which vmap
// said dimension was created inside. The `dim` corresponds to a "physical
// dim" - it is a dimension index on the underlying physical tensor that is
// being vmapped over.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
  int64_t dim() const {
    return dim_;
  }
  int64_t level() const {
    return level_;
  }

 private:
  int64_t dim_;
  int64_t level_;
};

using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = ArrayRef<BatchDim>;

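// Illustrative only, not part of the original header: a minimal sketch of
// building a BatchDims list by hand. `exampleMakeBatchDims` is a hypothetical
// name. Note that levels must be stored in increasing order (see
// NOTE: [BatchedTensorImpl levels invariant] below).
inline BatchDims exampleMakeBatchDims() {
  BatchDims bdims;
  bdims.emplace_back(/*level=*/0, /*dim=*/0); // outermost vmap owns physical dim 0
  bdims.emplace_back(/*level=*/1, /*dim=*/1); // nested vmap owns physical dim 1
  return bdims;
}
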
// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
//   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private.
  BatchDimsRef bdims() const {
    return bdims_;
  }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const {
    return value_;
  }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor. For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 3
  // bt.actualDim(2) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;

  // Override a bunch of methods inherited from TensorImpl to return error
  // messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // Note: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing
  // `level` order. That is, for i < j, bdims_[i].level must be less than
  // bdims_[j].level.
  BatchDims bdims_;
};

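// Illustrative only, not part of the original header: a hypothetical
// re-statement of the levels invariant above as a standalone predicate (the
// real check lives in checkInvariants(), defined elsewhere).
inline bool exampleLevelsAreIncreasing(BatchDimsRef bdims) {
  for (size_t i = 1; i < bdims.size(); ++i) {
    // For i < j, bdims[i].level() must be strictly less than bdims[j].level().
    if (bdims[i - 1].level() >= bdims[i].level()) {
      return false;
    }
  }
  return true;
}
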
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(std::move(tensor));
}

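// Hypothetical usage sketch, illustrative only: the safe access pattern is to
// call maybeGetBatchedImpl and branch on the result, rather than calling
// unsafeGetBatchedImpl directly. `examplePublicDimToPhysical` is not a real
// API.
inline int64_t examplePublicDimToPhysical(const Tensor& tensor, int64_t dim) {
  if (auto* batched = maybeGetBatchedImpl(tensor)) {
    // Batched: translate the public dim to its physical index.
    return batched->actualDim(dim);
  }
  // Not batched: public and physical dims coincide.
  return dim;
}
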
// Returns a bitset. If bit i is set, then that means dim i is a batchdim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
    BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  for (const auto& bdim : bdims) {
    is_bdim.set(bdim.dim());
  }
  return is_bdim;
}

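// Illustrative only: a minimal sketch of the mapping actualDim performs (the
// real implementation lives elsewhere and also handles dim wrapping). Public
// dims are the physical dims left over after skipping batch dims, so for
// bdims with dims {0, 2} on a rank-4 tensor, public dim 0 -> physical dim 1
// and public dim 1 -> physical dim 3, matching the actualDim example above.
inline int64_t exampleActualDim(
    int64_t public_dim,
    BatchDimsRef bdims,
    int64_t physical_ndim) {
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t public_dims_seen = 0;
  for (int64_t phys = 0; phys < physical_ndim; ++phys) {
    if (is_bdim[static_cast<size_t>(phys)]) {
      continue; // private batch dim: not user-visible
    }
    if (public_dims_seen++ == public_dim) {
      return phys; // the public_dim-th public dim lives at physical index phys
    }
  }
  TORCH_CHECK(false, "public dim ", public_dim, " is out of range");
}
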
// Creates a bitset for all of the levels present in `bdims`
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> result;
  for (const auto& bdim : bdims) {
    result.set(bdim.level());
  }
  return result;
}

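// Illustrative only: a hypothetical membership test built on the levels
// bitset. For bdims with levels {1, 2}, only bits 1 and 2 are set, so
// exampleHasLevel(bdims, 1) is true and exampleHasLevel(bdims, 0) is false.
inline bool exampleHasLevel(BatchDimsRef bdims, int64_t level) {
  return createVmapLevelsBitset(bdims)[static_cast<size_t>(level)];
}
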
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
  return out;
}

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);

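// Hypothetical end-to-end sketch, illustrative only, mirroring the example in
// the comments above. `exampleMakeBatchedRoundTrip` is not a real API, and
// calling it assumes linking against the translation unit that defines
// makeBatched.
inline void exampleMakeBatchedRoundTrip(const Tensor& physical) {
  // Suppose `physical` is ones(2, 3, 5, 7). Mark physical dims 0 and 1 as
  // private batch dims owned by vmap levels 1 and 2 (increasing level order).
  Tensor bt = makeBatched(
      physical,
      {BatchDim(/*level=*/1, /*dim=*/0), BatchDim(/*level=*/2, /*dim=*/1)});
  // Only the non-batch dims are user-visible: bt.sizes() == (5, 7).
  TORCH_INTERNAL_ASSERT(isBatchedTensor(bt));
}
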
} // namespace at