TensorCompare.h 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  // Forward declaration: lets this header name c10::Scalar in dispatch
  // signatures without including the header that defines it.
  namespace c10 { class Scalar; } // namespace c10
  // Forward declarations of the ATen types that appear (by reference) in the
  // dispatch signatures of this header; their definitions are not needed here.
  // The class/struct tags must match the real definitions.
  namespace at {
  struct TensorIteratorBase;
  struct TensorIterator;
  class Tensor;
  } // namespace at
  11. namespace at { namespace native {
  12. using reduce_minmax_fn =
  13. void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
  14. using structured_reduce_minmax_fn =
  15. void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
  16. DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
  17. DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
  18. using where_fn = void (*)(TensorIterator &);
  19. DECLARE_DISPATCH(where_fn, where_kernel);
  20. using is_infinity_op_fn = void (*)(TensorIteratorBase &);
  21. DECLARE_DISPATCH(is_infinity_op_fn, isposinf_stub);
  22. DECLARE_DISPATCH(is_infinity_op_fn, isneginf_stub);
  23. using mode_fn = void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
  24. DECLARE_DISPATCH(mode_fn, mode_stub);
  25. using clamp_tensor_fn = void (*)(TensorIteratorBase &);
  26. DECLARE_DISPATCH(clamp_tensor_fn, clamp_stub);
  27. namespace detail {
  28. enum class ClampLimits {Min, Max, MinMax};
  29. }
  30. DECLARE_DISPATCH(void (*)(TensorIteratorBase &, const c10::Scalar&, const c10::Scalar&), clamp_scalar_stub);
  31. DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_min_scalar_stub);
  32. DECLARE_DISPATCH(void (*)(TensorIteratorBase &, c10::Scalar), clamp_max_scalar_stub);
  33. using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
  34. DECLARE_DISPATCH(isin_default_fn, isin_default_stub);
  35. }} // namespace at::native