PointwiseOps.h 786 B

12345678910111213141516171819202122232425262728
  1. // Ternary and higher-order pointwise operations
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. namespace c10 {
  5. class Scalar;
  6. }
  7. namespace at {
  8. struct TensorIterator;
  9. struct TensorIteratorBase;
  10. namespace native {
  11. using pointwise_fn = void (*)(TensorIterator&, const Scalar& scalar);
  12. using structured_pointwise_fn = void (*)(TensorIteratorBase&, const Scalar& scalar);
  13. using pointwise_fn_double = void (*)(TensorIterator&, const Scalar&, double);
  14. DECLARE_DISPATCH(structured_pointwise_fn, addcmul_stub);
  15. DECLARE_DISPATCH(structured_pointwise_fn, addcdiv_stub);
  16. DECLARE_DISPATCH(pointwise_fn_double, smooth_l1_backward_stub);
  17. DECLARE_DISPATCH(pointwise_fn_double, huber_backward_stub);
  18. DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
  19. } // namespace native
  20. } // namespace at