roi_align_kernel.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #include "../roi_align.h"
  2. #include <torch/autograd.h>
  3. #include <torch/types.h>
  4. namespace vision {
  5. namespace ops {
  6. namespace {
  7. class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
  8. public:
  9. static torch::autograd::variable_list forward(
  10. torch::autograd::AutogradContext* ctx,
  11. const torch::autograd::Variable& input,
  12. const torch::autograd::Variable& rois,
  13. double spatial_scale,
  14. c10::SymInt pooled_height,
  15. c10::SymInt pooled_width,
  16. int64_t sampling_ratio,
  17. bool aligned) {
  18. ctx->saved_data["spatial_scale"] = spatial_scale;
  19. ctx->saved_data["pooled_height"] = pooled_height;
  20. ctx->saved_data["pooled_width"] = pooled_width;
  21. ctx->saved_data["sampling_ratio"] = sampling_ratio;
  22. ctx->saved_data["aligned"] = aligned;
  23. ctx->saved_data["input_shape"] = input.sym_sizes();
  24. ctx->save_for_backward({rois});
  25. at::AutoDispatchBelowADInplaceOrView g;
  26. auto result = roi_align_symint(
  27. input,
  28. rois,
  29. spatial_scale,
  30. pooled_height,
  31. pooled_width,
  32. sampling_ratio,
  33. aligned);
  34. return {result};
  35. }
  36. static torch::autograd::variable_list backward(
  37. torch::autograd::AutogradContext* ctx,
  38. const torch::autograd::variable_list& grad_output) {
  39. // Use data saved in forward
  40. auto saved = ctx->get_saved_variables();
  41. auto rois = saved[0];
  42. auto input_shape = ctx->saved_data["input_shape"].toList();
  43. auto grad_in = detail::_roi_align_backward_symint(
  44. grad_output[0],
  45. rois,
  46. ctx->saved_data["spatial_scale"].toDouble(),
  47. ctx->saved_data["pooled_height"].toSymInt(),
  48. ctx->saved_data["pooled_width"].toSymInt(),
  49. input_shape[0].get().toSymInt(),
  50. input_shape[1].get().toSymInt(),
  51. input_shape[2].get().toSymInt(),
  52. input_shape[3].get().toSymInt(),
  53. ctx->saved_data["sampling_ratio"].toInt(),
  54. ctx->saved_data["aligned"].toBool());
  55. return {
  56. grad_in,
  57. torch::autograd::Variable(),
  58. torch::autograd::Variable(),
  59. torch::autograd::Variable(),
  60. torch::autograd::Variable(),
  61. torch::autograd::Variable(),
  62. torch::autograd::Variable()};
  63. }
  64. };
  65. // TODO: There should be an easier way to do this
  66. class ROIAlignBackwardFunction
  67. : public torch::autograd::Function<ROIAlignBackwardFunction> {
  68. public:
  69. static torch::autograd::variable_list forward(
  70. torch::autograd::AutogradContext* ctx,
  71. const torch::autograd::Variable& grad,
  72. const torch::autograd::Variable& rois,
  73. double spatial_scale,
  74. c10::SymInt pooled_height,
  75. c10::SymInt pooled_width,
  76. c10::SymInt batch_size,
  77. c10::SymInt channels,
  78. c10::SymInt height,
  79. c10::SymInt width,
  80. int64_t sampling_ratio,
  81. bool aligned) {
  82. at::AutoDispatchBelowADInplaceOrView g;
  83. auto result = detail::_roi_align_backward_symint(
  84. grad,
  85. rois,
  86. spatial_scale,
  87. pooled_height,
  88. pooled_width,
  89. batch_size,
  90. channels,
  91. height,
  92. width,
  93. sampling_ratio,
  94. aligned);
  95. return {result};
  96. }
  97. static torch::autograd::variable_list backward(
  98. torch::autograd::AutogradContext* ctx,
  99. const torch::autograd::variable_list& grad_output) {
  100. TORCH_CHECK(0, "double backwards on roi_align not supported");
  101. }
  102. };
  103. at::Tensor roi_align_autograd(
  104. const at::Tensor& input,
  105. const at::Tensor& rois,
  106. double spatial_scale,
  107. c10::SymInt pooled_height,
  108. c10::SymInt pooled_width,
  109. int64_t sampling_ratio,
  110. bool aligned) {
  111. return ROIAlignFunction::apply(
  112. input,
  113. rois,
  114. spatial_scale,
  115. pooled_height,
  116. pooled_width,
  117. sampling_ratio,
  118. aligned)[0];
  119. }
  120. at::Tensor roi_align_backward_autograd(
  121. const at::Tensor& grad,
  122. const at::Tensor& rois,
  123. double spatial_scale,
  124. c10::SymInt pooled_height,
  125. c10::SymInt pooled_width,
  126. c10::SymInt batch_size,
  127. c10::SymInt channels,
  128. c10::SymInt height,
  129. c10::SymInt width,
  130. int64_t sampling_ratio,
  131. bool aligned) {
  132. return ROIAlignBackwardFunction::apply(
  133. grad,
  134. rois,
  135. spatial_scale,
  136. pooled_height,
  137. pooled_width,
  138. batch_size,
  139. channels,
  140. height,
  141. width,
  142. sampling_ratio,
  143. aligned)[0];
  144. }
  145. } // namespace
  146. TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  147. m.impl(
  148. TORCH_SELECTIVE_NAME("torchvision::roi_align"),
  149. TORCH_FN(roi_align_autograd));
  150. m.impl(
  151. TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
  152. TORCH_FN(roi_align_backward_autograd));
  153. }
  154. } // namespace ops
  155. } // namespace vision