// IndexKernel.h
  1. #pragma once
  2. #include <ATen/native/TensorIterator.h>
  3. namespace at {
  4. namespace native {
  5. using masked_fill_kernel_quantized_fn = void(*)(TensorIterator& iter, const Scalar& value, double scale, int zero_point);
  6. using index_put_kernel_quantized_fn = void(*)(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate, double scale, int zero_point);
  7. DECLARE_DISPATCH(masked_fill_kernel_quantized_fn, masked_fill_kernel_quantized_stub);
  8. DECLARE_DISPATCH(index_put_kernel_quantized_fn, index_put_kernel_quantized_stub);
  9. } // native
  10. } // at