// roi_align.h
  1. #pragma once
  2. #include <ATen/ATen.h>
  3. #include "../macros.h"
  4. namespace vision {
  5. namespace ops {
  6. VISION_API at::Tensor roi_align(
  7. const at::Tensor& input,
  8. const at::Tensor& rois,
  9. double spatial_scale,
  10. int64_t pooled_height,
  11. int64_t pooled_width,
  12. int64_t sampling_ratio,
  13. bool aligned);
  14. VISION_API at::Tensor roi_align_symint(
  15. const at::Tensor& input,
  16. const at::Tensor& rois,
  17. double spatial_scale,
  18. c10::SymInt pooled_height,
  19. c10::SymInt pooled_width,
  20. int64_t sampling_ratio,
  21. bool aligned);
  22. namespace detail {
  23. at::Tensor _roi_align_backward(
  24. const at::Tensor& grad,
  25. const at::Tensor& rois,
  26. double spatial_scale,
  27. int64_t pooled_height,
  28. int64_t pooled_width,
  29. int64_t batch_size,
  30. int64_t channels,
  31. int64_t height,
  32. int64_t width,
  33. int64_t sampling_ratio,
  34. bool aligned);
  35. at::Tensor _roi_align_backward_symint(
  36. const at::Tensor& grad,
  37. const at::Tensor& rois,
  38. double spatial_scale,
  39. c10::SymInt pooled_height,
  40. c10::SymInt pooled_width,
  41. c10::SymInt batch_size,
  42. c10::SymInt channels,
  43. c10::SymInt height,
  44. c10::SymInt width,
  45. int64_t sampling_ratio,
  46. bool aligned);
  47. } // namespace detail
  48. } // namespace ops
  49. } // namespace vision