group_norm.h 907 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <cstdint>
  4. namespace at {
  5. class Tensor;
  6. namespace native {
  7. using forward_fn = void (*)(
  8. const Tensor& /* X */,
  9. const Tensor& /* gamma */,
  10. const Tensor& /* beta */,
  11. int64_t /* N */,
  12. int64_t /* C */,
  13. int64_t /* HxW */,
  14. int64_t /* group */,
  15. double /* eps */,
  16. Tensor& /* Y */,
  17. Tensor& /* mean */,
  18. Tensor& /* rstd */);
  19. using backward_fn = void (*)(
  20. const Tensor& /* dY */,
  21. const Tensor& /* X */,
  22. const Tensor& /* mean */,
  23. const Tensor& /* rstd */,
  24. const Tensor& /* gamma */,
  25. int64_t /* N */,
  26. int64_t /* C */,
  27. int64_t /* HxW */,
  28. int64_t /* group */,
  29. Tensor& /* dX */,
  30. Tensor& /* dgamma */,
  31. Tensor& /* dbeta */);
  32. DECLARE_DISPATCH(forward_fn, GroupNormKernel);
  33. DECLARE_DISPATCH(backward_fn, GroupNormBackwardKernel);
  34. } // namespace native
  35. } // namespace at