12345678910111213141516171819202122232425262728 |
- #pragma once
- #include <ATen/native/DispatchStub.h>
- #include <cstdint>
- namespace at {
- class Tensor;
- namespace native {
- using forward_fn = void (*)(const Tensor&, const Tensor&);
- using backward_fn = void(*)(const Tensor &, const Tensor &, const Tensor&);
- DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel);
- DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel);
- DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel);
- DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel);
- using forward_fn_with_dim = void(*)(const Tensor &, const Tensor &, const int64_t);
- using backward_fn_with_dim =
- void (*)(const Tensor&, const Tensor&, const Tensor&, const int64_t);
- DECLARE_DISPATCH(forward_fn_with_dim, softmax_kernel);
- DECLARE_DISPATCH(forward_fn_with_dim, log_softmax_kernel);
- DECLARE_DISPATCH(backward_fn_with_dim, softmax_backward_kernel);
- DECLARE_DISPATCH(backward_fn_with_dim, log_softmax_backward_kernel);
- }
- }
|