SegmentReduce.h 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <ATen/native/ReductionType.h>
  4. #include <c10/core/Scalar.h>
  5. #include <c10/util/Optional.h>
  6. namespace at {
  7. class Tensor;
  8. namespace native {
  9. using segment_reduce_lengths_fn = Tensor (*)(
  10. ReductionType,
  11. const Tensor&,
  12. const Tensor&,
  13. int64_t,
  14. const c10::optional<Scalar>&);
  15. DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
  16. using segment_reduce_offsets_fn = Tensor (*)(
  17. ReductionType,
  18. const Tensor&,
  19. const Tensor&,
  20. int64_t,
  21. const c10::optional<Scalar>&);
  22. DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
  23. using segment_reduce_lengths_backward_fn = Tensor (*)(
  24. const Tensor&,
  25. const Tensor&,
  26. const Tensor&,
  27. ReductionType,
  28. const Tensor&,
  29. int64_t,
  30. const c10::optional<Scalar>&);
  31. DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
  32. using segment_reduce_offsets_backward_fn = Tensor (*)(
  33. const Tensor&,
  34. const Tensor&,
  35. const Tensor&,
  36. ReductionType,
  37. const Tensor&,
  38. int64_t,
  39. const c10::optional<Scalar>&);
  40. DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
  41. } // namespace native
  42. } // namespace at