ScatterGatherChecks.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #pragma once
  2. #include <vector>
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/native/ReduceOpsUtils.h>
  5. #include <c10/util/irange.h>
  6. namespace at { namespace native {
  7. namespace {
  8. // checks whether index.dtype == int64
  9. // and self.dtype == src.dtype if src is a Tensor
  10. static void scatter_gather_dtype_check(
  11. const std::string& method_name,
  12. const Tensor& self,
  13. const Tensor& index,
  14. const c10::optional<Tensor>& src_opt = c10::nullopt
  15. ) {
  16. if (index.numel() != 0) {
  17. TORCH_CHECK(
  18. index.scalar_type() == at::ScalarType::Long,
  19. method_name, "(): Expected dtype int64 for index"
  20. );
  21. }
  22. if (src_opt.has_value()) {
  23. const auto& src = src_opt.value();
  24. TORCH_CHECK(
  25. self.scalar_type() == src.scalar_type(),
  26. method_name, "(): Expected self.dtype to be equal to src.dtype"
  27. );
  28. }
  29. }
  30. // Used for `gather`-like methods
  31. // Note: self means the input tensor here
  32. // Test:
  33. // 1. index.size(d) <= self.size(d) for all d != dim
  34. // 2. index.dim() == self.dim()
  35. static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim,
  36. const Tensor& index
  37. ) {
  38. auto self_dims = ensure_nonempty_dim(self.dim());
  39. TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
  40. "Index tensor must have the same number of dimensions as input tensor"
  41. );
  42. for (const auto i : c10::irange(self_dims)) {
  43. if (i != dim) {
  44. TORCH_CHECK(
  45. ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
  46. "Size does not match at dimension ", i,
  47. " expected index ", index.sizes(),
  48. " to be smaller than self ", self.sizes(),
  49. " apart from dimension ", dim
  50. );
  51. }
  52. }
  53. }
  54. // Used for `scatter` and `scatter_add`
  55. // Tests:
  56. // 1. index.size(d) <= self.size(d) for all d != dim
  57. // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
  58. // 3. index.dim() == self.dim() == src.dim()
  59. static C10_UNUSED void scatter_shape_check(
  60. const Tensor& self, int64_t dim, const Tensor& index,
  61. const c10::optional<Tensor>& src_opt = c10::nullopt
  62. ) {
  63. if (index.numel() == 0) return;
  64. TORCH_CHECK(
  65. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
  66. "Index tensor must have the same number of dimensions as self tensor"
  67. );
  68. bool is_wrong_shape = false;
  69. int64_t self_dims = ensure_nonempty_dim(self.dim());
  70. // Check: index.size(d) <= self.size(d) for all d != dim
  71. for (const auto d : c10::irange(self_dims)) {
  72. int64_t index_d_size = ensure_nonempty_size(index, d);
  73. if (d == dim) continue;
  74. if (index_d_size > ensure_nonempty_size(self, d)) {
  75. is_wrong_shape = true;
  76. break;
  77. }
  78. }
  79. // Check: index.size(d) <= src.size(d) for all d if src is Tensor
  80. if (!is_wrong_shape && src_opt.has_value()) {
  81. const auto& src = src_opt.value();
  82. for (const auto d : c10::irange(self_dims)) {
  83. int64_t index_d_size = ensure_nonempty_size(index, d);
  84. if (index_d_size > ensure_nonempty_size(src, d)) {
  85. is_wrong_shape = true;
  86. break;
  87. }
  88. }
  89. }
  90. if (src_opt.has_value()) {
  91. const auto& src = src_opt.value();
  92. TORCH_CHECK(
  93. ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
  94. "Index tensor must have the same number of dimensions as src tensor"
  95. );
  96. TORCH_CHECK(!is_wrong_shape,
  97. "Expected index ", index.sizes(),
  98. " to be smaller than self ", self.sizes(),
  99. " apart from dimension ", dim,
  100. " and to be smaller size than src ", src.sizes()
  101. );
  102. }
  103. else {
  104. TORCH_CHECK(!is_wrong_shape,
  105. "Expected index ", index.sizes(),
  106. " to be smaller than self ", self.sizes(),
  107. " apart from dimension ", dim
  108. );
  109. }
  110. }
  111. } // anonymous namespace
  112. }} // namespace at::native