roi_align.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #include "roi_align.h"
  2. #include <ATen/core/dispatch/Dispatcher.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. at::Tensor roi_align(
  8. const at::Tensor& input, // Input feature map.
  9. const at::Tensor& rois, // List of ROIs to pool over.
  10. double spatial_scale, // The scale of the image features. ROIs will be
  11. // scaled to this.
  12. int64_t pooled_height, // The height of the pooled feature map.
  13. int64_t pooled_width, // The width of the pooled feature
  14. int64_t sampling_ratio, // The number of points to sample in each bin
  15. bool aligned) // The flag for pixel shift
  16. // along each axis.
  17. {
  18. C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
  19. static auto op = c10::Dispatcher::singleton()
  20. .findSchemaOrThrow("torchvision::roi_align", "")
  21. .typed<decltype(roi_align)>();
  22. return op.call(
  23. input,
  24. rois,
  25. spatial_scale,
  26. pooled_height,
  27. pooled_width,
  28. sampling_ratio,
  29. aligned);
  30. }
  31. at::Tensor roi_align_symint(
  32. const at::Tensor& input, // Input feature map.
  33. const at::Tensor& rois, // List of ROIs to pool over.
  34. double spatial_scale, // The scale of the image features. ROIs will be
  35. // scaled to this.
  36. c10::SymInt pooled_height, // The height of the pooled feature map.
  37. c10::SymInt pooled_width, // The width of the pooled feature
  38. int64_t sampling_ratio, // The number of points to sample in each bin
  39. bool aligned) // The flag for pixel shift
  40. // along each axis.
  41. {
  42. C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
  43. static auto op = c10::Dispatcher::singleton()
  44. .findSchemaOrThrow("torchvision::roi_align", "")
  45. .typed<decltype(roi_align_symint)>();
  46. return op.call(
  47. input,
  48. rois,
  49. spatial_scale,
  50. pooled_height,
  51. pooled_width,
  52. sampling_ratio,
  53. aligned);
  54. }
  55. namespace detail {
  56. at::Tensor _roi_align_backward(
  57. const at::Tensor& grad,
  58. const at::Tensor& rois,
  59. double spatial_scale,
  60. int64_t pooled_height,
  61. int64_t pooled_width,
  62. int64_t batch_size,
  63. int64_t channels,
  64. int64_t height,
  65. int64_t width,
  66. int64_t sampling_ratio,
  67. bool aligned) {
  68. static auto op =
  69. c10::Dispatcher::singleton()
  70. .findSchemaOrThrow("torchvision::_roi_align_backward", "")
  71. .typed<decltype(_roi_align_backward)>();
  72. return op.call(
  73. grad,
  74. rois,
  75. spatial_scale,
  76. pooled_height,
  77. pooled_width,
  78. batch_size,
  79. channels,
  80. height,
  81. width,
  82. sampling_ratio,
  83. aligned);
  84. }
  85. at::Tensor _roi_align_backward_symint(
  86. const at::Tensor& grad,
  87. const at::Tensor& rois,
  88. double spatial_scale,
  89. c10::SymInt pooled_height,
  90. c10::SymInt pooled_width,
  91. c10::SymInt batch_size,
  92. c10::SymInt channels,
  93. c10::SymInt height,
  94. c10::SymInt width,
  95. int64_t sampling_ratio,
  96. bool aligned) {
  97. static auto op =
  98. c10::Dispatcher::singleton()
  99. .findSchemaOrThrow("torchvision::_roi_align_backward", "")
  100. .typed<decltype(_roi_align_backward_symint)>();
  101. return op.call(
  102. grad,
  103. rois,
  104. spatial_scale,
  105. pooled_height,
  106. pooled_width,
  107. batch_size,
  108. channels,
  109. height,
  110. width,
  111. sampling_ratio,
  112. aligned);
  113. }
  114. } // namespace detail
  115. TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  116. m.def(TORCH_SELECTIVE_SCHEMA(
  117. "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
  118. m.def(TORCH_SELECTIVE_SCHEMA(
  119. "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
  120. }
  121. } // namespace ops
  122. } // namespace vision