// TensorIndexing.h
#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymInt.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alias.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/core/List.h>

#include <utility>

namespace at {
namespace indexing {

const int64_t INDEX_MIN = c10::SymInt::min_representable_int();
const int64_t INDEX_MAX = -(INDEX_MIN + 1);

enum class TensorIndexType { None, Ellipsis, Integer, Boolean, Slice, Tensor };

constexpr c10::nullopt_t None = c10::nullopt;

struct TORCH_API EllipsisIndexType final {
  EllipsisIndexType() = default;
};
TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
 public:
  Slice(
      c10::optional<c10::SymInt> start_index = c10::nullopt,
      c10::optional<c10::SymInt> stop_index = c10::nullopt,
      c10::optional<c10::SymInt> step_index = c10::nullopt) {
    if (!step_index.has_value()) {
      step_ = c10::SymInt(1);
    } else {
      step_ = std::move(step_index).value();
    }

    TORCH_CHECK_VALUE(step_ != 0, "slice step cannot be zero");

    if (!start_index.has_value()) {
      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
    } else {
      start_ = std::move(start_index).value();
    }

    if (!stop_index.has_value()) {
      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
    } else {
      stop_ = std::move(stop_index).value();
    }
  }

  inline c10::SymInt start() const {
    return start_;
  }

  inline c10::SymInt stop() const {
    return stop_;
  }

  inline c10::SymInt step() const {
    return step_;
  }

 private:
  c10::SymInt start_;
  c10::SymInt stop_;
  c10::SymInt step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);

// `at::indexing::TensorIndex` is used for converting C++ tensor indices such
// as `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is one-to-one correspondence between Python and C++ tensor index
// types:
// Python                  | C++
// -----------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
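//
// As a rough usage sketch (illustrative only; `x` is a hypothetical tensor,
// and this assumes the `Tensor::index` / `Tensor::index_put_` overloads that
// accept a braced list of `TensorIndex`):
// ```
// using namespace at::indexing;
// // Python: x[None, 0, True, 1:3:2, torch.tensor([1, 2])]
// Tensor y = x.index({None, 0, true, Slice(1, 3, 2), torch::tensor({1, 2})});
// // Python: x[:, 0] = 1
// x.index_put_({Slice(None, None), 0}, 1);
// ```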
struct TORCH_API TensorIndex final {
  // Case 1: `at::indexing::None`
  TensorIndex(c10::nullopt_t) : type_(TensorIndexType::None) {}

  // Case 2: "..." / `at::indexing::Ellipsis`
  TensorIndex(at::indexing::EllipsisIndexType)
      : type_(TensorIndexType::Ellipsis) {}
  TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
    TORCH_CHECK_VALUE(
        strcmp(str, "...") == 0,
        "Expected \"...\" to represent an ellipsis index, but got \"",
        str,
        "\"");
  }

  // Case 3: Integer value
  TensorIndex(int64_t integer)
      : integer_(integer), type_(TensorIndexType::Integer) {}
  TensorIndex(int integer) : TensorIndex((int64_t)integer) {}

  // Case 4: Boolean value
  template <
      class T,
      class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

  // Case 5: Slice represented in `at::indexing::Slice` form
  TensorIndex(Slice slice)
      : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

  // Case 6: Tensor value
  TensorIndex(Tensor tensor)
      : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}

  inline bool is_none() const {
    return type_ == TensorIndexType::None;
  }

  inline bool is_ellipsis() const {
    return type_ == TensorIndexType::Ellipsis;
  }

  inline bool is_integer() const {
    return type_ == TensorIndexType::Integer;
  }

  inline int64_t integer() const {
    return integer_;
  }

  inline bool is_boolean() const {
    return type_ == TensorIndexType::Boolean;
  }

  inline bool boolean() const {
    return boolean_;
  }

  inline bool is_slice() const {
    return type_ == TensorIndexType::Slice;
  }

  inline const Slice& slice() const {
    return slice_;
  }

  inline bool is_tensor() const {
    return type_ == TensorIndexType::Tensor;
  }

  inline const Tensor& tensor() const {
    return tensor_;
  }

 private:
  int64_t integer_ = 0;
  bool boolean_ = false;
  Slice slice_;
  Tensor tensor_;
  TensorIndexType type_;
};

TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const std::vector<TensorIndex>& tensor_indices);

namespace impl {
static inline Tensor applySlice(
    const Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt stop,
    c10::SymInt step,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    // Skip this optimization if we are tracing, as the trace may be polymorphic
    // over the shape of the `self` tensor, and we still want to record
    // the slice.
    SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
        ? (*self_sizes)[dim]
        : self.sym_size(dim);
    if (!disable_slice_optimization && start == 0 && length == stop &&
        step == 1) {
      return self;
    }
  }
  return self.slice_symint(dim, start, stop, std::move(step));
}

static inline Tensor applySelect(
    const Tensor& self,
    int64_t dim,
    int64_t index,
    int64_t real_dim,
    const at::Device& /*self_device*/,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    TORCH_CHECK_INDEX(
        !(index == 0 && dim == 0 && self_sizes->empty()),
        "invalid index of a 0-dim tensor. ",
        "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");

    auto size = (*self_sizes)[dim];
    TORCH_CHECK_INDEX(
        size >= -index && size > index,
        "index ",
        index,
        " is out of bounds for dimension ",
        real_dim,
        " with size ",
        size);
  }

  // if the index is negative, do not normalize it because that would fix the
  // index on the current tensor size in the tracer. aten::select also works on
  // negative indices
  return self.select(dim, index);
}
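// For illustration (a sketch, not part of the original comments), the helpers
// below implement the shape semantics of indexing with a plain boolean:
// ```
// Tensor x = at::zeros({2, 3});
// x.index({true});   // sizes {1, 2, 3}: new leading dim, indexed by tensor([0])
// x.index({false});  // sizes {0, 2, 3}: new leading dim, indexed by an empty tensor
// ```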
static inline Tensor boolToIndexingTensorCPUOrCUDA(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::empty({1}, {}, self.options().dtype(kLong)).fill_(0.);
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

static inline Tensor boolToIndexingTensorNonNativeDeviceType(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::zeros({1}, {}, self.options().dtype(kLong));
  } else {
    return at::empty({0}, {}, self.options().dtype(kLong));
  }
}

static inline Tensor boolToIndexingTensor(
    const Tensor& self,
    bool value,
    const at::Device& self_device) {
  if (self_device == at::kCPU || self_device == at::kCUDA) {
    return boolToIndexingTensorCPUOrCUDA(self, value);
  } else {
    return boolToIndexingTensorNonNativeDeviceType(self, value);
  }
}

static inline Tensor scalarToTensorNonNativeDeviceType(
    const Scalar& v,
    const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}

static inline void recordTensorIndex(
    const Tensor& tensor,
    std::vector<Tensor>& outIndices,
    int64_t* dim_ptr) {
  // TODO: check scalarType
  outIndices.resize(*dim_ptr + 1);
  outIndices[*dim_ptr] = tensor;
  (*dim_ptr)++;
}

static inline c10::List<c10::optional<Tensor>> typeConvertIndices(
    const Tensor& /*self*/,
    std::vector<Tensor>&& indices) {
  c10::List<c10::optional<Tensor>> converted_inds;
  converted_inds.reserve(indices.size());
  for (const auto& i : indices) {
    converted_inds.push_back(std::move(i));
  }
  return converted_inds;
}

// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
// indexing (i.e. it's called by `applySlicing` which is called by
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
// than one dimension). If we were to merge the Python/C++
// `count_specified_dimensions` function, on the Python side we would have to
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds 100s of nanoseconds
// overhead and is undesirable.
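// For example (an illustrative case, not from the original comments): for
// indices `{0, Slice(), None, "..."}` the count is 2 (the integer and the
// slice), while a kByte / kBool mask tensor contributes one count per mask
// dimension (e.g. a 2-D boolean mask counts as 2).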
static inline int64_t count_specified_dimensions(
    const ArrayRef<TensorIndex>& indices) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  int64_t count = 0;
  for (auto& obj : indices) {
    if (obj.is_tensor()) {
      auto& tensor = obj.tensor();
      if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
        count += tensor.dim();
      } else {
        count++;
      }
    } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
      count++;
    }
  }
  return count;
}
} // namespace impl

// NOTE: Many functions below are only for consumption from Python indexing
// implementation, they include:
//
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
//
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
static inline Tensor scalarToTensor(
    const Scalar& v,
    const TensorOptions& options,
    const at::Device& self_device) {
  if (self_device == at::kCPU) {
    return at::detail::scalar_tensor_static(
        v, options.dtype_opt()->toScalarType(), self_device);
  } else {
    return impl::scalarToTensorNonNativeDeviceType(v, options);
  }
}

// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
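// For example (illustrative sizes, not from the original comments): sizes
// {1, 1, 3, 4} become {3, 4}, while {3, 1, 4} are returned unchanged because
// the leading dimension is already non-1.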
static inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
  size_t first_non1_src = sizes.size();
  for (const auto i : c10::irange(sizes.size())) {
    if (sizes[i] != 1) {
      first_non1_src = i;
      break;
    }
  }

  return sizes.slice(first_non1_src);
}

static inline void copy_to(const Tensor& dst, const Tensor& src) {
  if (dst.sym_sizes().equals(src.sym_sizes())) {
    // A shortcut to avoid generating hard-coded constant sizes during tracing.
    // This is not a perfect solution: when src & dst have different shapes,
    // constants will still appear. Users can workaround that case by
    // dst[index..] = src.reshape(..)
    dst.copy_(src);
    return;
  } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
    dst.fill_(src);
    return;
  }
  auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
  dst.copy_(*b_src);
}

// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
static inline Tensor handleDimInMultiDimIndexing(
    const Tensor& prev_dim_result,
    const Tensor& original_tensor,
    const TensorIndex& index,
    int64_t* dim_ptr,
    int64_t* specified_dims_ptr,
    int64_t real_dim,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& original_tensor_device,
    const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
  if (index.is_integer()) {
    return impl::applySelect(
        prev_dim_result,
        *dim_ptr,
        index.integer(),
        real_dim,
        original_tensor_device,
        prev_dim_result_sizes);
  } else if (index.is_slice()) {
    Tensor result = impl::applySlice(
        prev_dim_result,
        *dim_ptr,
        index.slice().start(),
        index.slice().stop(),
        index.slice().step(),
        /*disable_slice_optimization=*/disable_slice_optimization,
        original_tensor_device,
        prev_dim_result_sizes);
    (*dim_ptr)++;
    return result;
  } else if (index.is_ellipsis()) {
    (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
    return prev_dim_result;
  } else if (index.is_none()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    (*dim_ptr)++;
    return result;
  } else if (index.is_boolean()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    impl::recordTensorIndex(
        impl::boolToIndexingTensor(
            result, index.boolean(), original_tensor_device),
        outIndices,
        dim_ptr);
    return result;
  } else if (index.is_tensor()) {
    Tensor result = prev_dim_result;
    const Tensor& tensor = index.tensor();
    auto scalar_type = tensor.scalar_type();
    if (tensor.dim() == 0 &&
        at::isIntegralType(scalar_type, /*includeBool=*/true)) {
      if (scalar_type != at::kByte && scalar_type != at::kBool) {
        result = impl::applySelect(
            result,
            *dim_ptr,
            tensor.item<int64_t>(),
            real_dim,
            original_tensor_device,
            prev_dim_result_sizes);
      } else {
        result = result.unsqueeze(*dim_ptr);
        if (scalar_type == at::kBool) {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<bool>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        } else {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<uint8_t>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        }
      }
    } else {
      impl::recordTensorIndex(tensor, outIndices, dim_ptr);
    }
    return result;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
  }
}
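// For intuition (an illustrative walk-through, not part of the original
// comments): `applySlicing` below calls `handleDimInMultiDimIndexing` once per
// index. For a 3-D tensor indexed with `{Slice(1, 3), None, 0}`, the slice
// keeps dim 0 and advances `*dim_ptr` to 1, `None` unsqueezes a new dim at
// position 1 and advances `*dim_ptr` to 2, and the integer `0` then selects
// from dim 2 of the intermediate result (i.e. dim 1 of the original tensor).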
namespace impl {
// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
static inline Tensor applySlicing(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const c10::optional<SymIntArrayRef>& self_sizes) {
  int64_t dim = 0;
  int64_t specified_dims = impl::count_specified_dimensions(indices);

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= (int64_t)self_sizes->size(),
        "too many indices for tensor of dimension ",
        (int)self_sizes->size());
  }

  Tensor result = self;
  for (const auto i : c10::irange(indices.size())) {
    auto& obj = indices[i];
    // See NOTE [nested tensor size for indexing]
    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? c10::optional<SymIntArrayRef>(c10::nullopt)
        : c10::optional<SymIntArrayRef>(result.sym_sizes());
    result = handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/obj,
        /*dim=*/&dim,
        /*specified_dims=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        /*disable_slice_optimization=*/disable_slice_optimization,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}
} // namespace impl

static inline Tensor dispatch_index(
    const Tensor& self,
    std::vector<Tensor>&& indices) {
  return self.index(impl::typeConvertIndices(self, std::move(indices)));
}

static inline Tensor dispatch_index_put_(
    Tensor& self,
    std::vector<Tensor>&& indices,
    const Tensor& value) {
  return self.index_put_(
      impl::typeConvertIndices(self, std::move(indices)), value);
}

// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: What "slice optimization" means: when we have a slicing expression
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
// would skip dispatching the actual slice call as an optimization. However,
// here are the cases where we DON'T want this optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
//    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
//    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
//    ```
//    Tensor sliced = impl::applySlicing(
//        self, indices, tensorIndices, disable_slice_optimization,
//        self_device, self_sizes);
//    if (tensorIndices.empty()) {
//      if (sliced.is_same(self)) {
//        // ensure we return a shallow copy for things like x[...]
//        sliced = at::alias(sliced);
//      }
//      return sliced;
//    }
//    ```)
// 2. When we are doing JIT tracing.
//    Reason: JIT tracing needs the `self.slice(...)` call to properly trace
//    the slice operation.
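//
// As a usage sketch (illustrative only; `x` and the literal values are
// hypothetical), the functions below are the C++ counterparts of Python's
// `__getitem__` / `__setitem__`:
// ```
// using namespace at::indexing;
// Tensor x = at::arange(12).reshape({3, 4});
// Tensor row = get_item(x, {0});                   // x[0]
// Tensor sub = get_item(x, {Slice(0, 2), "..."});  // x[0:2, ...]
// set_item(x, {Slice(), 0}, at::zeros({3}));       // x[:, 0] = zeros(3)
// ```
//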
// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
static inline Tensor get_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  // NOTE [nested tensor size for indexing]
  // nested tensor does not have a size (yet), so for now we represent its
  // size as null; this may need to be changed after we reach a better
  // solution for nested tensor size
  c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
      ? c10::optional<SymIntArrayRef>(c10::nullopt)
      : c10::optional<SymIntArrayRef>(self.sym_sizes());

  // handle simple types: integers, slices, none, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_integer()) {
      return impl::applySelect(
          self, 0, index.integer(), 0, self_device, self_sizes);
    } else if (index.is_slice()) {
      return impl::applySlice(
          self,
          0,
          index.slice().start(),
          index.slice().stop(),
          index.slice().step(),
          /*disable_slice_optimization=*/true,
          self_device,
          self_sizes);
    } else if (index.is_none()) {
      return self.unsqueeze(0);
    } else if (index.is_ellipsis()) {
      return at::alias(self);
    } else if (index.is_boolean()) {
      Tensor result = self.unsqueeze(0);
      return dispatch_index(
          result,
          std::vector<Tensor>{impl::boolToIndexingTensor(
              result, index.boolean(), self_device)});
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    if (sliced.is_same(self)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return sliced;
  }

  // indexing by tensors ("advanced" indexing)
  return dispatch_index(sliced, std::move(tensorIndices));
}

// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is
// a Tensor" case
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
static inline void set_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  SymIntArrayRef self_sizes = self.sym_sizes();

  // handle simple types: integers, slices, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_boolean() && !index.boolean()) {
      // do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes.)
      return;
    } else if (index.is_ellipsis()) {
      copy_to(self, value);
      return;
    } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
      copy_to(self.unsqueeze(0), value);
      return;
    } else if (index.is_integer()) {
      copy_to(
          impl::applySelect(
              self, 0, index.integer(), 0, self_device, self_sizes),
          value);
      return;
    } else if (index.is_slice()) {
      copy_to(
          impl::applySlice(
              self,
              0,
              index.slice().start(),
              index.slice().stop(),
              index.slice().step(),
              /*disable_slice_optimization=*/disable_slice_optimization,
              self_device,
              self_sizes),
          value);
      return;
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    copy_to(sliced, value);
    return;
  }

  SymIntArrayRef valueSizes = value.sym_sizes();
  SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  Tensor valuesSliced;
  if (!valueSizes.equals(slicedValueSizes)) {
    valuesSliced = value.view_symint(slicedValueSizes);
  } else {
    valuesSliced = value;
  }
  dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  return;
}

} // namespace indexing
} // namespace at