// deform_conv2d.h

#pragma once

#include <cstdint>
#include <tuple>

#include <ATen/ATen.h>

#include "../macros.h"
  4. namespace vision {
  5. namespace ops {
  6. VISION_API at::Tensor deform_conv2d(
  7. const at::Tensor& input,
  8. const at::Tensor& weight,
  9. const at::Tensor& offset,
  10. const at::Tensor& mask,
  11. const at::Tensor& bias,
  12. int64_t stride_h,
  13. int64_t stride_w,
  14. int64_t pad_h,
  15. int64_t pad_w,
  16. int64_t dilation_h,
  17. int64_t dilation_w,
  18. int64_t groups,
  19. int64_t offset_groups,
  20. bool use_mask);
  21. namespace detail {
  22. std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
  23. _deform_conv2d_backward(
  24. const at::Tensor& grad,
  25. const at::Tensor& input,
  26. const at::Tensor& weight,
  27. const at::Tensor& offset,
  28. const at::Tensor& mask,
  29. const at::Tensor& bias,
  30. int64_t stride_h,
  31. int64_t stride_w,
  32. int64_t pad_h,
  33. int64_t pad_w,
  34. int64_t dilation_h,
  35. int64_t dilation_w,
  36. int64_t groups,
  37. int64_t offset_groups,
  38. bool use_mask);
  39. } // namespace detail
  40. } // namespace ops
  41. } // namespace vision