#pragma once #include #include #include #include namespace at { namespace native { namespace { // checks whether index.dtype == int64 // and self.dtype == src.dtype if src is a Tensor static void scatter_gather_dtype_check( const std::string& method_name, const Tensor& self, const Tensor& index, const c10::optional& src_opt = c10::nullopt ) { if (index.numel() != 0) { TORCH_CHECK( index.scalar_type() == at::ScalarType::Long, method_name, "(): Expected dtype int64 for index" ); } if (src_opt.has_value()) { const auto& src = src_opt.value(); TORCH_CHECK( self.scalar_type() == src.scalar_type(), method_name, "(): Expected self.dtype to be equal to src.dtype" ); } } // Used for `gather`-like methods // Note: self means the input tensor here // Test: // 1. index.size(d) <= self.size(d) for all d != dim // 2. index.dim() == self.dim() static C10_UNUSED void gather_shape_check(const Tensor& self, int64_t dim, const Tensor& index ) { auto self_dims = ensure_nonempty_dim(self.dim()); TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()), "Index tensor must have the same number of dimensions as input tensor" ); for (const auto i : c10::irange(self_dims)) { if (i != dim) { TORCH_CHECK( ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), "Size does not match at dimension ", i, " expected index ", index.sizes(), " to be smaller than self ", self.sizes(), " apart from dimension ", dim ); } } } // Used for `scatter` and `scatter_add` // Tests: // 1. index.size(d) <= self.size(d) for all d != dim // 2. index.size(d) <= src.size(d) for all d if src is a Tensor // 3. index.dim() == self.dim() == src.dim() static C10_UNUSED void scatter_shape_check( const Tensor& self, int64_t dim, const Tensor& index, const c10::optional& src_opt = c10::nullopt ) { if (index.numel() == 0) return; TORCH_CHECK( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), "Index tensor must have the same number of dimensions as self tensor" ); bool is_wrong_shape = false; int64_t self_dims = ensure_nonempty_dim(self.dim()); // Check: index.size(d) <= self.size(d) for all d != dim for (const auto d : c10::irange(self_dims)) { int64_t index_d_size = ensure_nonempty_size(index, d); if (d == dim) continue; if (index_d_size > ensure_nonempty_size(self, d)) { is_wrong_shape = true; break; } } // Check: index.size(d) <= src.size(d) for all d if src is Tensor if (!is_wrong_shape && src_opt.has_value()) { const auto& src = src_opt.value(); for (const auto d : c10::irange(self_dims)) { int64_t index_d_size = ensure_nonempty_size(index, d); if (index_d_size > ensure_nonempty_size(src, d)) { is_wrong_shape = true; break; } } } if (src_opt.has_value()) { const auto& src = src_opt.value(); TORCH_CHECK( ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()), "Index tensor must have the same number of dimensions as src tensor" ); TORCH_CHECK(!is_wrong_shape, "Expected index ", index.sizes(), " to be smaller than self ", self.sizes(), " apart from dimension ", dim, " and to be smaller size than src ", src.sizes() ); } else { TORCH_CHECK(!is_wrong_shape, "Expected index ", index.sizes(), " to be smaller than self ", self.sizes(), " apart from dimension ", dim ); } } } // anonymous namespace }} // namespace at::native