123456789101112131415161718192021222324252627282930313233343536373839404142 |
- #pragma once
- #include <ATen/native/DispatchStub.h>
- #include <cstdint>
- namespace at {
- class Tensor;
- namespace native {
- using forward_fn = void (*)(
- const Tensor& /* X */,
- const Tensor& /* gamma */,
- const Tensor& /* beta */,
- int64_t /* N */,
- int64_t /* C */,
- int64_t /* HxW */,
- int64_t /* group */,
- double /* eps */,
- Tensor& /* Y */,
- Tensor& /* mean */,
- Tensor& /* rstd */);
- using backward_fn = void (*)(
- const Tensor& /* dY */,
- const Tensor& /* X */,
- const Tensor& /* mean */,
- const Tensor& /* rstd */,
- const Tensor& /* gamma */,
- int64_t /* N */,
- int64_t /* C */,
- int64_t /* HxW */,
- int64_t /* group */,
- Tensor& /* dX */,
- Tensor& /* dgamma */,
- Tensor& /* dbeta */);
- DECLARE_DISPATCH(forward_fn, GroupNormKernel);
- DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel);
- } // namespace native
- } // namespace at
|