// IndexKernel.h
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <c10/util/ArrayRef.h>
  // Forward declarations of the ATen types named by the kernel signatures in
  // this header. The signatures only take these types by reference, so the
  // full definitions are not required here. Keep each class-key (class vs.
  // struct) exactly as written: MSVC's name mangling distinguishes the two.
  namespace at {
  class Tensor;  // NOTE(review): not referenced by the declarations in this header; presumably kept for includers.
  class TensorBase;
  struct TensorIterator;
  struct TensorIteratorBase;
  }  // namespace at
  // Forward declaration of c10::Scalar. NOTE(review): the signatures in
  // at::native name an unqualified `Scalar`, presumably resolving to this type
  // via a c10 using-directive in the included headers — confirm.
  namespace c10 { class Scalar; }  // namespace c10
  13. namespace at { namespace native {
  14. using index_fn = void(*)(TensorIteratorBase &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
  15. using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
  16. using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
  17. using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
  18. using put_fn = void(*)(TensorIterator & iter, const TensorBase& self, const bool accumulate);
  19. using take_fn = void(*)(TensorIterator & iter, const TensorBase& input);
  20. using flip_fn = void(*)(TensorIterator &, const bool);
  21. using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
  22. using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
  23. using masked_scatter_fn = void(*)(TensorIterator &, const TensorBase &);
  24. DECLARE_DISPATCH(index_fn, index_stub);
  25. DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
  26. DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
  27. DECLARE_DISPATCH(index_put_fn, index_put_stub);
  28. DECLARE_DISPATCH(put_fn, put_stub);
  29. DECLARE_DISPATCH(take_fn, take_stub);
  30. DECLARE_DISPATCH(flip_fn, flip_stub);
  31. DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
  32. DECLARE_DISPATCH(masked_select_fn, masked_select_serial_stub);
  33. DECLARE_DISPATCH(masked_select_fn, masked_select_stub);
  34. DECLARE_DISPATCH(masked_scatter_fn, masked_scatter_stub);
  35. }} // namespace at::native