#pragma once
// Indexing tensors by tensors

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/radix_sort.h>

namespace at {
struct TensorIterator;
}

namespace at { namespace native {

using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);

using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, const Tensor& src, const ReductionType& reduce);
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, const Scalar& value, const ReductionType& reduce);
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, const Tensor& src, const ReductionType& reduce);

DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
DECLARE_DISPATCH(gather_fn, gather_stub);
DECLARE_DISPATCH(scatter_fn, scatter_stub);
DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);

TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<Tensor>>& indices);

// fast paths for GNN usage
static inline bool can_use_expanded_index_path(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    bool is_scatter_like) {
  if (!self.device().is_cpu()) {
    return false;
  }

  const auto st = self.scalar_type();
  if (!(c10::isFloatingType(st)) || st == ScalarType::Half) {
    return false;
  }

  if (!is_radix_sort_available()) {
    return false;
  }

  // skip when having empty tensor
  if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
    return false;
  }

  // skip when having scalar tensor
  if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) {
    return false;
  }

  if (is_scatter_like) {
    // using `spmm` for scatter would require sorting on index,
    // this is only perf beneficial when the inner dimension, aka, `channels`
    // is big enough.
    constexpr int64_t threshold = 16;
    if (index.numel() / index.size(0) < threshold) {
      return false;
    }
  }

  // usually the expanded index has stride on the first dimension to be 1,
  // and strides on other dims to be 0 or 1, e.g.
  //   shape [108365, 16]; strides [1, 0]
  //   shape [13264, 1, 7]; strides [1, 1, 0]
  auto index_strides = index.strides().vec();
  bool is_index_expanded = index_strides[0] == 1;
  for (const auto dim : c10::irange(1, index_strides.size())) {
    if (index_strides[dim] > 1) {
      is_index_expanded = false;
    }
  }

  // index is expanded
  return dim == 0 && is_index_expanded && src.is_contiguous() && self.is_contiguous();
}

using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);

DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);

}} // namespace at::native
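
// Illustrative sketch (kept as a comment; not part of this header's API):
// a CPU scatter_add kernel could consult can_use_expanded_index_path() to
// choose between the expanded-index fast path and the generic stub. The
// wrapper function below is hypothetical; only the predicate and the stubs
// it calls are declared in this header.
//
//   void scatter_add_cpu(const Tensor& self, int64_t dim,
//                        const Tensor& index, const Tensor& src) {
//     if (can_use_expanded_index_path(
//             self, dim, index, src, /*is_scatter_like=*/true)) {
//       // index is effectively a 1-D index expanded over the trailing dims
//       // (stride 1 on dim 0, strides 0/1 elsewhere): take the sort-based
//       // fast path tuned for GNN-style workloads.
//       scatter_add_expanded_index_stub(kCPU, self, index, src);
//     } else {
//       scatter_add_stub(kCPU, self, dim, index, src);
//     }
//   }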