TensorAdvancedIndexingUtils.h

#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>

#include <sstream>

namespace at {
namespace native {
namespace {
static std::string shapes_as_str(TensorList tensors) {
  std::ostringstream os;
  bool first = true;
  for (auto& tensor : tensors) {
    if (tensor.defined()) {
      if (!first) {
        os << ", ";
      }
      os << tensor.sizes();
      first = false;
    }
  }
  return os.str();
}
} // anonymous namespace

// Returns (true, mask) when the indexing can be lowered to a masked_fill_:
// the only defined index is a bool/byte mask on self's device and value is a
// single-element CPU tensor. Otherwise returns (false, Tensor()).
static std::tuple<bool, Tensor> canDispatchToMaskedFill(
    const Tensor& self,
    const torch::List<c10::optional<at::Tensor>>& indices,
    const Tensor& value) {
  if (!(value.numel() == 1 && value.device().is_cpu())) {
    return std::make_tuple(false, Tensor());
  }
  int64_t num_ind = 0;
  Tensor mask;
  auto self_device = self.device();
  for (const c10::optional<Tensor>& i : indices) {
    if (!i.has_value() || !(*i).defined()) {
      num_ind++;
    } else {
      const Tensor& index = *i;
      if ((index.scalar_type() != kByte && index.scalar_type() != kBool) ||
          index.device() != self_device || mask.defined()) {
        return std::make_tuple(false, Tensor());
      } else {
        mask = index;
        for (const auto j : c10::irange(index.dim())) {
          int64_t srcIdx = num_ind + j;
          TORCH_CHECK_INDEX(
              index.size(j) == self.size(srcIdx),
              "The shape of the mask ", index.sizes(), " at index ", j,
              " does not match the shape of the indexed tensor ",
              self.sizes(), " at index ", srcIdx);
        }
        num_ind += mask.ndimension();
      }
    }
  }
  for (const auto i : c10::irange(num_ind, self.ndimension())) {
    (void)i; // Suppress unused variable warning
    mask = mask.unsqueeze(-1);
  }
  return std::make_tuple(true, mask);
}

static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
  checkIndexTensorTypes(orig, /*allow_int*/ true);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = expand_outplace(indices);
  } catch (std::exception& e) {
    TORCH_CHECK_INDEX(
        false,
        "shape mismatch: indexing tensors could not be broadcast together"
        " with shapes ",
        shapes_as_str(indices));
  }
  // add missing null Tensors so that it matches self.dim()
  while (indices.size() < (size_t)self.dim()) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!hasContiguousSubspace(indices)) {
    std::tie(self, indices) = transposeToFront(self, indices);
  }
  // Ensure indices are on the same device as self
  for (auto& indice : indices) {
    if (indice.defined() && indice.device() != self.device()) {
      indice = indice.to(self.device());
    }
  }
  // Promote int32 index tensors to int64
  for (auto& indice : indices) {
    if (indice.defined() && indice.dtype() == at::kInt) {
      indice = indice.to(at::kLong);
    }
  }
  return AdvancedIndex(self, indices);
}
} // namespace native
} // namespace at
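
For reference, below is a minimal user-level sketch (not part of this header, assuming libtorch's public C++ API) of the case canDispatchToMaskedFill targets: a single bool/byte mask on the same device as self combined with a one-element CPU value, which index_put_ can lower to masked_fill_. It is illustrative only and does not call the internal ATen helpers above.

// Illustrative sketch, assuming libtorch: the user-visible pattern that the
// masked_fill fast path above is designed to recognize.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor self = torch::zeros({4, 4});
  torch::Tensor mask = torch::rand({4, 4}) > 0.5;  // kBool mask on the same device as self
  torch::Tensor value = torch::tensor(1.0);        // single-element CPU value

  // index_put_ with one boolean mask and a 1-element CPU value is the shape
  // of call the fast path accepts ...
  self.index_put_({mask}, value);

  // ... and it matches the masked_fill_ it can be lowered to.
  torch::Tensor ref = torch::zeros({4, 4}).masked_fill_(mask, 1.0);
  std::cout << self.equal(ref) << std::endl;  // prints 1
  return 0;
}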