TensorAdvancedIndexing.h

#pragma once

// Indexing tensors by tensors

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/radix_sort.h>

namespace at {
struct TensorIterator;
}

namespace at { namespace native {

using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                  const Tensor& src, const ReductionType& reduce);
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                         const Scalar& value, const ReductionType& reduce);
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
                                      const Tensor& src, const ReductionType& reduce);

DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
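
// Illustrative sketch (not part of this header): a backend kernel file
// typically implements a function matching one of the signatures above and
// registers it against the corresponding stub with REGISTER_DISPATCH; the
// operator implementation then invokes the stub with a device type.
// The kernel name below is hypothetical.
//
//   // in a CPU kernel .cpp file
//   static void gather_cpu_kernel(const Tensor& result, const Tensor& self,
//                                 int64_t dim, const Tensor& index) {
//     // ... actual gather loop ...
//   }
//   REGISTER_DISPATCH(gather_stub, &gather_cpu_kernel);
//
//   // at the call site
//   gather_stub(result.device().type(), result, self, dim, index);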

TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<at::Tensor>>& indices);

// fast paths for GNN usage
static inline bool can_use_expanded_index_path(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    bool is_scatter_like) {
  if (!self.device().is_cpu()) {
    return false;
  }

  const auto st = self.scalar_type();
  if (!(c10::isFloatingType(st)) || st == ScalarType::Half) {
    return false;
  }

  if (!is_radix_sort_available()) {
    return false;
  }

  // skip if any of the tensors is empty
  if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
    return false;
  }

  // skip if any of the tensors is a scalar (0-dim)
  if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) {
    return false;
  }

  if (is_scatter_like) {
    // using `spmm` for scatter requires sorting on the index, which is only
    // beneficial for performance when the inner dimension, a.k.a. `channels`,
    // is big enough.
    constexpr int64_t threshold = 16;
    if (index.numel() / index.size(0) < threshold) {
      return false;
    }
  }

  // usually the expanded index has stride 1 on the first dimension
  // and strides of 0 or 1 on the other dims, e.g.
  //   shape [108365, 16];  strides [1, 0]
  //   shape [13264, 1, 7]; strides [1, 1, 0]
  auto index_strides = index.strides().vec();
  bool is_index_expanded = index_strides[0] == 1;
  for (const auto dim : c10::irange(1, index_strides.size())) {
    if (index_strides[dim] > 1) { is_index_expanded = false; }
  }

  // take the fast path only when the index is expanded, the operation is
  // along dim 0, and both `self` and `src` are contiguous
  return dim == 0 && is_index_expanded && src.is_contiguous() && self.is_contiguous();
}
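
// For intuition (illustrative only, not part of the original header): the
// "expanded index" layout accepted above is what Tensor::expand produces when
// a 1-D index of length N is broadcast along a channel dimension of size C:
//
//   auto expanded = index.view({-1, 1}).expand({N, C});
//   // expanded.sizes()   == [N, C]
//   // expanded.strides() == [1, 0]   -> stride 1 on dim 0, stride 0 elsewhere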

using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);

DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
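
// Illustrative usage (an assumption, not taken from this header): a typical
// GNN-style aggregation that can satisfy can_use_expanded_index_path() on CPU,
// written with ATen calls. Variable names are hypothetical; note the inner
// dimension C must be reasonably wide (>= 16) to pass the scatter threshold.
//
//   // out: [num_nodes, C] float, contiguous; messages: [num_edges, C]
//   // edge_index: 1-D int64 tensor of length num_edges
//   auto out = at::zeros({num_nodes, C}, at::kFloat);
//   out.scatter_add_(
//       0,
//       edge_index.view({-1, 1}).expand({-1, C}),  // expanded index, strides [1, 0]
//       messages);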

}} // namespace at::native