DepthwiseConvKernel.h 471 B

123456789101112131415161718192021
  1. #pragma once
  2. #include <ATen/native/DispatchStub.h>
  3. #include <c10/util/ArrayRef.h>
  4. /*
  5. Depthwise 3x3 Winograd convolution operator
  6. */
  7. namespace at {
  8. class Tensor;
  9. namespace native {
  10. using convolution_depthwise3x3_winograd_fn =
  11. Tensor (*)(const Tensor &, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, int64_t);
  12. DECLARE_DISPATCH(convolution_depthwise3x3_winograd_fn, convolution_depthwise3x3_winograd_stub);
  13. } // namespace native
  14. } // namespace at