// SoftmaxKernel.h
#pragma once

#include <ATen/native/DispatchStub.h>
#include <cstdint>
  4. namespace at {
  5. class Tensor;
  6. namespace native {
  7. using forward_fn = void (*)(const Tensor&, const Tensor&);
  8. using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
  9. DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
  10. DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
  11. DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
  12. DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
  13. using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
  14. using backward_fn_with_dim =
  15. void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
  16. DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
  17. DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
  18. DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
  19. DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
  20. }
  21. }