WeightNormKernel.h 565 B

1234567891011121314151617181920
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <cstdint>
  4. namespace at {
  5. class TensorBase;
  6. }
  7. namespace at { namespace native {
  8. using weight_norm_fn = void(*)(
  9. TensorBase&, TensorBase&, const TensorBase&, const TensorBase&, int64_t);
  10. using weight_norm_backward_fn = void(*)(
  11. TensorBase&, TensorBase&, const TensorBase&, const TensorBase&,
  12. const TensorBase&, const TensorBase&, int64_t);
  13. DECLARE_DISPATCH(weight_norm_fn, weight_norm_stub);
  14. DECLARE_DISPATCH(weight_norm_backward_fn, weight_norm_backward_stub);
  15. }} // namespace at::native