#pragma once

#include <bitset>
#include <ostream>
#include <utility>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// Valid vmap levels lie in the range [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of BatchDim elements on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple, with `dim` indicating which dimension is
// being vmapped over and `level` identifying which (possibly nested) vmap
// created the dimension. `dim` is a "physical dim": a dimension index on the
// underlying physical tensor that is being vmapped over.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
  int64_t dim() const {
    return dim_;
  }
  int64_t level() const {
    return level_;
  }

 private:
  int64_t dim_;
  int64_t level_;
};
using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = ArrayRef<BatchDim>;
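
// For illustration only: building the BatchDims for a tensor that is being
// vmapped over at levels 1 and 2 (a sketch; `bdims` is a hypothetical local,
// not part of this header):
//   BatchDims bdims;
//   bdims.emplace_back(/*level=*/1, /*dim=*/0);
//   bdims.emplace_back(/*level=*/2, /*dim=*/1);
//   BatchDimsRef ref = bdims;  // cheap, non-owning view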

// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
//   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to the BatchDims that represent which dimensions of
  // this tensor are private.
  BatchDimsRef bdims() const {
    return bdims_;
  }

  // BatchedTensorImpl wraps a Tensor; this returns the wrapped Tensor.
  const Tensor& value() const {
    return value_;
  }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor. For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // then:
  //   bt.actualDim(0) -> 1
  //   bt.actualDim(1) -> 3
  //   bt.actualDim(2) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
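
  // For illustration only, an assumption based on ATen's usual dim-wrapping
  // convention: with wrap_dim=true, negative `dim` values wrap against the
  // number of public dims, so in the example above
  //   bt.actualDim(-1) -> 3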

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;

  // Override a bunch of methods inherited from TensorImpl to return error
  // messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // NOTE: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing
  // `level` order. That is, for i < j, bdims_[i].level() must be less than
  // bdims_[j].level().
  BatchDims bdims_;
};
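
// For illustration only, a sketch of how the pieces above fit together
// (`x` and `bt` are hypothetical locals, not part of this header):
//   Tensor x = at::ones({2, 3, 5, 7});
//   Tensor bt = makeBatched(x, BatchDims{{/*level=*/1, /*dim=*/0},
//                                        {/*level=*/2, /*dim=*/1}});
//   bt.sizes();  // (5, 7) - the two batch dims are private
//   bt.dim();    // 2, not 4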

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(Tensor tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(Tensor tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(std::move(tensor));
}
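
// For illustration only, the typical checked-access pattern (a sketch;
// `tensor` is a hypothetical variable):
//   if (auto* batched = maybeGetBatchedImpl(tensor)) {
//     BatchDimsRef bdims = batched->bdims();
//     const Tensor& unwrapped = batched->value();
//   }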

// Returns a bitset. If bit i is set, then that means dim i is a batch dim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
    BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  for (const auto& bdim : bdims) {
    is_bdim.set(bdim.dim());
  }
  return is_bdim;
}
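
// For illustration only (a sketch; `bdims` is hypothetical): with
// bdims = [(lvl=1, dim=0), (lvl=2, dim=2)],
//   auto is_bdim = createBatchDimBitset(bdims);
//   is_bdim[0];  // true
//   is_bdim[1];  // false
//   is_bdim[2];  // true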

// Creates a bitset for all of the levels present in `bdims`.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> result;
  for (const auto& bdim : bdims) {
    result.set(bdim.level());
  }
  return result;
}
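
// For illustration only (a sketch): with the same hypothetical bdims as above,
// levels 1 and 2 are present, so
//   createVmapLevelsBitset(bdims).count();  // 2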

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
  return out;
}

// Use this to construct a BatchedTensor from a regular Tensor.
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);

// Adds a batch dim to `tensor`, returning a BatchedTensor.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
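
// For illustration only (a sketch; `x` is a hypothetical regular Tensor):
// these two calls are expected to produce equivalent BatchedTensors:
//   Tensor a = makeBatched(x, BatchDims{{/*level=*/0, /*dim=*/0}});
//   Tensor b = addBatchDim(x, /*level=*/0, /*dim=*/0);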

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of
// this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
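
// For illustration only, a sketch of how the level bitsets above could express
// such a check (an assumption; the authoritative definition lives in the NOTE):
//   auto self_levels = createVmapLevelsBitset(self_bdims);    // hypothetical
//   auto other_levels = createVmapLevelsBitset(other_bdims);  // hypothetical
//   bool compatible = (other_levels & ~self_levels).none();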

} // namespace at