12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725 |
- #pragma once
- #include <ATen/ExpandUtils.h>
- #include <ATen/ScalarOps.h>
- #include <ATen/core/Tensor.h>
- #include <ATen/core/TensorBody.h>
- #include <c10/core/SymInt.h>
- #include <c10/util/Optional.h>
- #include <c10/util/irange.h>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #include <ATen/NativeFunctions.h>
- #else
- #include <ATen/ops/alias.h>
- #include <ATen/ops/empty.h>
- #include <ATen/ops/scalar_tensor.h>
- #include <ATen/ops/zeros.h>
- #endif
- #include <ATen/core/List.h>
- #include <utility>
- namespace at {
- namespace indexing {
- const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
- const int64_t INDEX_MAX = -(INDEX_MIN + 1);
- enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };
- constexpr c10::nullopt_t None = c10::nullopt;
- struct TORCH_API EllipsisIndexType final {
- EllipsisIndexType() = default;
- };
- TORCH_API extern const EllipsisIndexType Ellipsis;
- struct TORCH_API Slice final {
- public:
- Slice(
- c10::optional<c10::SymInt> start_index = c10::nullopt,
- c10::optional<c10::SymInt> stop_index = c10::nullopt,
- c10::optional<c10::SymInt> step_index = c10::nullopt) {
- if (!step_index.has_value()) {
- step_ = c10::SymInt(1);
- } else {
- step_ = std::move(step_index).value();
- }
- TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");
- if (!start_index.has_value()) {
- start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
- } else {
- start_ = std::move(start_index).value();
- }
- if (!stop_index.has_value()) {
- stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
- } else {
- stop_ = std::move(stop_index).value();
- }
- }
- inline c10::SymInt start() const {
- return start_;
- }
- inline c10::SymInt stop() const {
- return stop_;
- }
- inline c10::SymInt step() const {
- return step_;
- }
- private:
- c10::SymInt start_;
- c10::SymInt stop_;
- c10::SymInt step_;
- };
- TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
- // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
- // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
- // into its equivalent `std::vector<TensorIndex>`, so that further tensor
- // indexing operations can be performed using the supplied indices.
- //
- // There is one-to-one correspondence between Python and C++ tensor index types:
- // Python | C++
- // -----------------------------------------------------
- // `None` | `at::indexing::None`
- // `Ellipsis` | `at::indexing::Ellipsis`
- // `...` | `"..."`
- // `123` | `123`
- // `True` / `False` | `true` / `false`
- // `:` | `Slice()` / `Slice(None, None)`
- // `::` | `Slice()` / `Slice(None, None, None)`
- // `1:` | `Slice(1, None)`
- // `1::` | `Slice(1, None, None)`
- // `:3` | `Slice(None, 3)`
- // `:3:` | `Slice(None, 3, None)`
- // `::2` | `Slice(None, None, 2)`
- // `1:3` | `Slice(1, 3)`
- // `1::2` | `Slice(1, None, 2)`
- // `:3:2` | `Slice(None, 3, 2)`
- // `1:3:2` | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])` | `torch::tensor({1, 2})`
- struct TORCH_API TensorIndex final {
- // Case 1: `at::indexing::None`
- TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}
- // Case 2: "..." / `at::indexing::Ellipsis`
- TensorIndex(at::indexing::EllipsisIndexType)
- : type_(TensorIndexType::Ellipsis) {}
- TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
- TORCH_CHECK_VALUE(
- strcmp(str, "...") == 0,
- "Expected \"...\" to represent an ellipsis index, but got \"",
- str,
- "\"");
- }
- // Case 3: Integer value
- TensorIndex(int64_t integer)
- : integer_(integer), type_(TensorIndexType::Integer) {}
- TensorIndex(int integer) : TensorIndex((int64_t)integer) {}
- // Case 4: Boolean value
- template <
- class T,
- class = typename std::enable_if<std::is_same<bool, T>::value>::type>
- TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
- // Case 5: Slice represented in `at::indexing::Slice` form
- TensorIndex(Slice slice)
- : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
- // Case 6: Tensor value
- TensorIndex(Tensor tensor)
- : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
- inline bool is_none() const {
- return type_ == TensorIndexType::None;
- }
- inline bool is_ellipsis() const {
- return type_ == TensorIndexType::Ellipsis;
- }
- inline bool is_integer() const {
- return type_ == TensorIndexType::Integer;
- }
- inline int64_t integer() const {
- return integer_;
- }
- inline bool is_boolean() const {
- return type_ == TensorIndexType::Boolean;
- }
- inline bool boolean() const {
- return boolean_;
- }
- inline bool is_slice() const {
- return type_ == TensorIndexType::Slice;
- }
- inline const Slice& slice() const {
- return slice_;
- }
- inline bool is_tensor() const {
- return type_ == TensorIndexType::Tensor;
- }
- inline const Tensor& tensor() const {
- return tensor_;
- }
- private:
- int64_t integer_ = 0;
- bool boolean_ = false;
- Slice slice_;
- Tensor tensor_;
- TensorIndexType type_;
- };
- TORCH_API std::ostream& operator<<(
- std::ostream& stream,
- const TensorIndex& tensor_index);
- TORCH_API std::ostream& operator<<(
- std::ostream& stream,
- const std::vector<TensorIndex>& tensor_indices);
- namespace impl {
- static inline Tensor applySlice(
- const Tensor& self,
- int64_t dim,
- c10::SymInt start,
- c10::SymInt stop,
- c10::SymInt step,
- bool disable_slice_optimization,
- const at::Device& self_device,
- const c10::optional<SymIntArrayRef>& self_sizes) {
- // TODO: implement negative step
- TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
- // See NOTE [nested tensor size for indexing]
- if (self_sizes.has_value()) {
- // Skip this optimization if we are tracing, as the trace may be polymorphic
- // over the shape of the `self` tensor, and we still want to record
- // the slice.
- SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
- ? (*self_sizes)[dim]
- : self.sym_size(dim);
- if (!disable_slice_optimization && start == 0 && length == stop &&
- step == 1) {
- return self;
- }
- }
- return self.slice_symint(dim, start, stop, std::move(step));
- }
- static inline Tensor applySelect(
- const Tensor& self,
- int64_t dim,
- int64_t index,
- int64_t real_dim,
- const at::Device& /*self_device*/,
- const c10::optional<SymIntArrayRef>& self_sizes) {
- // See NOTE [nested tensor size for indexing]
- if (self_sizes.has_value()) {
- TORCH_CHECK_INDEX(
- !(index == 0 && dim == 0 && self_sizes->empty()),
- "invalid index of a 0-dim tensor. ",
- "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
- auto size = (*self_sizes)[dim];
- TORCH_CHECK_INDEX(
- size >= -index && size > index,
- "index ",
- index,
- " is out of bounds for dimension ",
- real_dim,
- " with size ",
- size);
- }
- // if the index is negative, do not normalize it because that would fix the
- // index on the current tensor size in the tracer. aten::select also works on
- // negative indices
- return self.select(dim, index);
- }
- static inline Tensor boolToIndexingTensorCPUOrCUDA(
- const Tensor& self,
- bool value) {
- // booleans add a dimension of size 1. true indexes this dimension as if 0:,
- // false as empty.
- if (value) {
- return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
- } else {
- return at::empty({0}, {}, self.options().dtype(kLong));
- }
- }
- static inline Tensor boolToIndexingTensorNonNativeDeviceType(
- const Tensor& self,
- bool value) {
- // booleans add a dimension of size 1. true indexes this dimension as if 0:,
- // false as empty.
- if (value) {
- return at::zeros({1}, {}, self.options().dtype(kLong));
- } else {
- return at::empty({0}, {}, self.options().dtype(kLong));
- }
- }
- static inline Tensor boolToIndexingTensor(
- const Tensor& self,
- bool value,
- const at::Device& self_device) {
- if (self_device == at::kCPU || self_device == at::kCUDA) {
- return boolToIndexingTensorCPUOrCUDA(self, value);
- } else {
- return boolToIndexingTensorNonNativeDeviceType(self, value);
- }
- }
- static inline Tensor scalarToTensorNonNativeDeviceType(
- const Scalar& v,
- const TensorOptions& options) {
- return at::scalar_tensor(v, options);
- }
- static inline void recordTensorIndex(
- const Tensor& tensor,
- std::vector<Tensor>& outIndices,
- int64_t* dim_ptr) {
- // TODO: check scalarType
- outIndices.resize(*dim_ptr + 1);
- outIndices[*dim_ptr] = tensor;
- (*dim_ptr)++;
- };
- static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
- const Tensor& /*self*/,
- std::vector<Tensor>&& indices) {
- c10::List<c10::optional<Tensor>> converted_inds;
- converted_inds.reserve(indices.size());
- for (const auto& i : indices) {
- converted_inds.push_back(std::move(i));
- }
- return converted_inds;
- }
- // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
- // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
- // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
- // indexing (i.e. it's called by `applySlicing` which is called by
- // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
- // than one dimension). If we were to merge the Python/C++
- // `count_specified_dimensions` function, on the Python side we would have to
- // construct a `std::vector` container to be consumed by the C++
- // `count_specified_dimensions` function, which adds 100s of nanoseconds
- // overhead and is undesirable.
- static inline int64_t count_specified_dimensions(
- const ArrayRef<TensorIndex>& indices) {
- // Count the number of indexed dimensions (everything but ellipsis and None)
- int64_t count = 0;
- for (auto& obj : indices) {
- if (obj.is_tensor()) {
- auto& tensor = obj.tensor();
- if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
- count += tensor.dim();
- } else {
- count++;
- }
- } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
- count++;
- }
- }
- return count;
- }
- } // namespace impl
- // NOTE: Many functions below are only for consumption from Python indexing
- // implementation, they include:
- //
- // - `Tensor scalarToTensor(...)`
// - `SymIntArrayRef slicePrefix1sSize(...)`
- // - `void copy_to(...)`
- // - `Tensor handleDimInMultiDimIndexing(...)`
- // - `Tensor dispatch_index(...)`
- // - `Tensor dispatch_index_put_(...)`
- // - `Tensor get_item(...)`
- // - `void set_item(...)`
- //
- // The rest of the functions are in `at::indexing::impl` namespace, signifying
- // that they shouldn't be used from Python indexing implementation.
- static inline Tensor scalarToTensor(
- const Scalar& v,
- const TensorOptions& options,
- const at::Device& self_device) {
- if (self_device == at::kCPU) {
- return at::detail::scalar_tensor_static(
- v, options.dtype_opt()->toScalarType(), self_device);
- } else {
- return impl::scalarToTensorNonNativeDeviceType(v, options);
- }
- }
- // To match numpy semantics:
- // As a special case for backwards compatibility,
- // strip away unit dimensions from the left of 'src'
- static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
- size_t first_non1_src = sizes.size();
- for (const auto i : c10::irange(sizes.size())) {
- if (sizes[i] != 1) {
- first_non1_src = i;
- break;
- }
- }
- return sizes.slice(first_non1_src);
- }
- static inline void copy_to(const Tensor& dst, const Tensor& src) {
- if (dst.sym_sizes().equals(src.sym_sizes())) {
- // A shortcut to avoid generating hard-coded constant sizes during tracing.
- // This is not a perfect solution: when src & dst have different shapes,
- // constants will still appear. Users can workaround that case by
- // dst[index..] = src.reshape(..)
- dst.copy_(src);
- return;
- } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
- dst.fill_(src);
- return;
- }
- auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
- c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
- dst.copy_(*b_src);
- }
- // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
- // indexing functions from Python ]
- static inline Tensor handleDimInMultiDimIndexing(
- const Tensor& prev_dim_result,
- const Tensor& original_tensor,
- const TensorIndex& index,
- int64_t* dim_ptr,
- int64_t* specified_dims_ptr,
- int64_t real_dim,
- std::vector<Tensor>& outIndices,
- bool disable_slice_optimization,
- const at::Device& original_tensor_device,
- const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
- if (index.is_integer()) {
- return impl::applySelect(
- prev_dim_result,
- *dim_ptr,
- index.integer(),
- real_dim,
- original_tensor_device,
- prev_dim_result_sizes);
- } else if (index.is_slice()) {
- Tensor result = impl::applySlice(
- prev_dim_result,
- *dim_ptr,
- index.slice().start(),
- index.slice().stop(),
- index.slice().step(),
- /*disable_slice_optimization=*/disable_slice_optimization,
- original_tensor_device,
- prev_dim_result_sizes);
- (*dim_ptr)++;
- return result;
- } else if (index.is_ellipsis()) {
- (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
- return prev_dim_result;
- } else if (index.is_none()) {
- Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
- (*dim_ptr)++;
- return result;
- } else if (index.is_boolean()) {
- Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
- impl::recordTensorIndex(
- impl::boolToIndexingTensor(
- result, index.boolean(), original_tensor_device),
- outIndices,
- dim_ptr);
- return result;
- } else if (index.is_tensor()) {
- Tensor result = prev_dim_result;
- const Tensor& tensor = index.tensor();
- auto scalar_type = tensor.scalar_type();
- if (tensor.dim() == 0 &&
- at::isIntegralType(scalar_type, /*includeBool=*/true)) {
- if (scalar_type != at::kByte && scalar_type != at::kBool) {
- result = impl::applySelect(
- result,
- *dim_ptr,
- tensor.item<int64_t>(),
- real_dim,
- original_tensor_device,
- prev_dim_result_sizes);
- } else {
- result = result.unsqueeze(*dim_ptr);
- if (scalar_type == at::kBool) {
- impl::recordTensorIndex(
- impl::boolToIndexingTensor(
- result, tensor.item<bool>() != 0, original_tensor_device),
- outIndices,
- dim_ptr);
- } else {
- impl::recordTensorIndex(
- impl::boolToIndexingTensor(
- result, tensor.item<uint8_t>() != 0, original_tensor_device),
- outIndices,
- dim_ptr);
- }
- }
- } else {
- impl::recordTensorIndex(tensor, outIndices, dim_ptr);
- }
- return result;
- } else {
- TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
- }
- }
- namespace impl {
- // This mirrors `applySlicing` in
- // torch/csrc/autograd/python_variable_indexing.cpp
- static inline Tensor applySlicing(
- const Tensor& self,
- const ArrayRef<TensorIndex>& indices,
- std::vector<Tensor>& outIndices,
- bool disable_slice_optimization,
- const at::Device& self_device,
- const c10::optional<SymIntArrayRef>& self_sizes) {
- int64_t dim = 0;
- int64_t specified_dims = impl::count_specified_dimensions(indices);
- // See NOTE [nested tensor size for indexing]
- if (self_sizes.has_value()) {
- TORCH_CHECK_INDEX(
- specified_dims <= (int64_t)self_sizes->size(),
- "too many indices for tensor of dimension ",
- (int)self_sizes->size());
- }
- Tensor result = self;
- for (const auto i : c10::irange(indices.size())) {
- auto& obj = indices[i];
- // See NOTE [nested tensor size for indexing]
- c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
- ? c10::optional<SymIntArrayRef>(c10::nullopt)
- : c10::optional<SymIntArrayRef>(result.sym_sizes());
- result = handleDimInMultiDimIndexing(
- /*prev_dim_result=*/result,
- /*original_tensor=*/self,
- /*index=*/obj,
- /*dim=*/&dim,
- /*specified_dims=*/&specified_dims,
- /*real_dim=*/i,
- /*outIndices=*/outIndices,
- /*disable_slice_optimization=*/disable_slice_optimization,
- /*original_tensor_device=*/self_device,
- /*prev_dim_result_sizes=*/result_sizes);
- }
- return result;
- }
- } // namespace impl
- static inline Tensor dispatch_index(
- const Tensor& self,
- std::vector<Tensor>&& indices) {
- return self.index(impl::typeConvertIndices(self, std::move(indices)));
- }
- static inline Tensor dispatch_index_put_(
- Tensor& self,
- std::vector<Tensor>&& indices,
- const Tensor& value) {
- return self.index_put_(
- impl::typeConvertIndices(self, std::move(indices)), value);
- }
// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: "Slice optimization" means the following. Suppose we have a slicing
// expression like `x[0:5, 0]`, and the sliced tensor has size 5 in
// dimension 0. Then we skip dispatching the actual slice call, since it would
// be a no-op. However, here are the cases where we DON'T want this
// optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
//    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
//    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
//    ```
//    Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
//        disable_slice_optimization, self_device, self_sizes);
//    if (tensorIndices.empty()) {
//      if (sliced.is_same(self)) {
//        // ensure we return a shallow copy for things like x[...]
//        sliced = at::alias(sliced);
//      }
//      return sliced;
//    }
//    ```)
// 2. When we are doing JIT tracing.
//    Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
//    slice operation.
- // This mirrors `THPVariable_getitem` in
- // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
- // `disable_slice_optimization` when calling C++ tensor indexing functions from
- // Python ]
- static inline Tensor get_item(
- const Tensor& self,
- const ArrayRef<TensorIndex>& indices,
- bool disable_slice_optimization = false) {
- at::Device self_device = self.device();
- // NOTE [nested tensor size for indexing]
- // nested tensor does not have a size (yet) so for now we represent its size
- // as null may need to be changed after we reach a better solution for nested
- // tensor size
- c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
- ? c10::optional<SymIntArrayRef>(c10::nullopt)
- : c10::optional<SymIntArrayRef>(self.sym_sizes());
- // handle simple types: integers, slices, none, ellipsis, bool
- if (indices.size() == 1) {
- const TensorIndex& index = indices[0];
- if (index.is_integer()) {
- return impl::applySelect(
- self, 0, index.integer(), 0, self_device, self_sizes);
- } else if (index.is_slice()) {
- return impl::applySlice(
- self,
- 0,
- index.slice().start(),
- index.slice().stop(),
- index.slice().step(),
- /*disable_slice_optimization=*/true,
- self_device,
- self_sizes);
- } else if (index.is_none()) {
- return self.unsqueeze(0);
- } else if (index.is_ellipsis()) {
- return at::alias(self);
- } else if (index.is_boolean()) {
- Tensor result = self.unsqueeze(0);
- return dispatch_index(
- result,
- std::vector<Tensor>{impl::boolToIndexingTensor(
- result, index.boolean(), self_device)});
- }
- }
- std::vector<Tensor> tensorIndices;
- Tensor sliced = impl::applySlicing(
- self,
- indices,
- tensorIndices,
- disable_slice_optimization,
- self_device,
- self_sizes);
- if (tensorIndices.empty()) {
- if (sliced.is_same(self)) {
- // ensure we return a shallow copy for things like x[...]
- sliced = at::alias(sliced);
- }
- return sliced;
- }
- // indexing by tensors ("advanced" indexing)
- return dispatch_index(sliced, std::move(tensorIndices));
- }
- // This mirrors `THPVariable_setitem` in
- // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
- // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
- // tensor indexing functions from Python ]
- static inline void set_item(
- const Tensor& self,
- const ArrayRef<TensorIndex>& indices,
- const Tensor& value,
- bool disable_slice_optimization = false) {
- at::Device self_device = self.device();
- SymIntArrayRef self_sizes = self.sym_sizes();
- // handle simple types: integers, slices, ellipsis, bool
- if (indices.size() == 1) {
- const TensorIndex& index = indices[0];
- if (index.is_boolean() && !index.boolean()) {
- // do nothing for false (technically we should check the size, but we
- // don't have real 0-sized shapes.
- return;
- } else if (index.is_ellipsis()) {
- copy_to(self, value);
- return;
- } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
- copy_to(self.unsqueeze(0), value);
- return;
- } else if (index.is_integer()) {
- copy_to(
- impl::applySelect(
- self, 0, index.integer(), 0, self_device, self_sizes),
- value);
- return;
- } else if (index.is_slice()) {
- copy_to(
- impl::applySlice(
- self,
- 0,
- index.slice().start(),
- index.slice().stop(),
- index.slice().step(),
- /*disable_slice_optimization=*/disable_slice_optimization,
- self_device,
- self_sizes),
- value);
- return;
- }
- }
- std::vector<Tensor> tensorIndices;
- Tensor sliced = impl::applySlicing(
- self,
- indices,
- tensorIndices,
- disable_slice_optimization,
- self_device,
- self_sizes);
- if (tensorIndices.empty()) {
- copy_to(sliced, value);
- return;
- }
- SymIntArrayRef valueSizes = value.sym_sizes();
- SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
- Tensor valuesSliced;
- if (!valueSizes.equals(slicedValueSizes)) {
- valuesSliced = value.view_symint(slicedValueSizes);
- } else {
- valuesSliced = value;
- }
- dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
- return;
- }
- } // namespace indexing
- } // namespace at
|