roi_pool_kernel.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #include "../roi_pool.h"
  2. #include <torch/autograd.h>
  3. #include <torch/types.h>
  4. namespace vision {
  5. namespace ops {
  6. namespace {
  7. class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
  8. public:
  9. static torch::autograd::variable_list forward(
  10. torch::autograd::AutogradContext* ctx,
  11. const torch::autograd::Variable& input,
  12. const torch::autograd::Variable& rois,
  13. double spatial_scale,
  14. int64_t pooled_height,
  15. int64_t pooled_width) {
  16. ctx->saved_data["spatial_scale"] = spatial_scale;
  17. ctx->saved_data["pooled_height"] = pooled_height;
  18. ctx->saved_data["pooled_width"] = pooled_width;
  19. ctx->saved_data["input_shape"] = input.sizes();
  20. at::AutoDispatchBelowADInplaceOrView g;
  21. auto result =
  22. roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
  23. auto output = std::get<0>(result);
  24. auto argmax = std::get<1>(result);
  25. ctx->save_for_backward({rois, argmax});
  26. ctx->mark_non_differentiable({argmax});
  27. return {output, argmax};
  28. }
  29. static torch::autograd::variable_list backward(
  30. torch::autograd::AutogradContext* ctx,
  31. const torch::autograd::variable_list& grad_output) {
  32. // Use data saved in forward
  33. auto saved = ctx->get_saved_variables();
  34. auto rois = saved[0];
  35. auto argmax = saved[1];
  36. auto input_shape = ctx->saved_data["input_shape"].toIntList();
  37. auto grad_in = detail::_roi_pool_backward(
  38. grad_output[0],
  39. rois,
  40. argmax,
  41. ctx->saved_data["spatial_scale"].toDouble(),
  42. ctx->saved_data["pooled_height"].toInt(),
  43. ctx->saved_data["pooled_width"].toInt(),
  44. input_shape[0],
  45. input_shape[1],
  46. input_shape[2],
  47. input_shape[3]);
  48. return {
  49. grad_in,
  50. torch::autograd::Variable(),
  51. torch::autograd::Variable(),
  52. torch::autograd::Variable(),
  53. torch::autograd::Variable()};
  54. }
  55. };
  56. // TODO: There should be an easier way to do this
  57. class ROIPoolBackwardFunction
  58. : public torch::autograd::Function<ROIPoolBackwardFunction> {
  59. public:
  60. static torch::autograd::variable_list forward(
  61. torch::autograd::AutogradContext* ctx,
  62. const torch::autograd::Variable& grad,
  63. const torch::autograd::Variable& rois,
  64. const torch::autograd::Variable& argmax,
  65. double spatial_scale,
  66. int64_t pooled_height,
  67. int64_t pooled_width,
  68. int64_t batch_size,
  69. int64_t channels,
  70. int64_t height,
  71. int64_t width) {
  72. at::AutoDispatchBelowADInplaceOrView g;
  73. auto grad_in = detail::_roi_pool_backward(
  74. grad,
  75. rois,
  76. argmax,
  77. spatial_scale,
  78. pooled_height,
  79. pooled_width,
  80. batch_size,
  81. channels,
  82. height,
  83. width);
  84. return {grad_in};
  85. }
  86. static torch::autograd::variable_list backward(
  87. torch::autograd::AutogradContext* ctx,
  88. const torch::autograd::variable_list& grad_output) {
  89. TORCH_CHECK(0, "double backwards on roi_pool not supported");
  90. }
  91. };
  92. std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
  93. const at::Tensor& input,
  94. const at::Tensor& rois,
  95. double spatial_scale,
  96. int64_t pooled_height,
  97. int64_t pooled_width) {
  98. auto result = ROIPoolFunction::apply(
  99. input, rois, spatial_scale, pooled_height, pooled_width);
  100. return std::make_tuple(result[0], result[1]);
  101. }
  102. at::Tensor roi_pool_backward_autograd(
  103. const at::Tensor& grad,
  104. const at::Tensor& rois,
  105. const at::Tensor& argmax,
  106. double spatial_scale,
  107. int64_t pooled_height,
  108. int64_t pooled_width,
  109. int64_t batch_size,
  110. int64_t channels,
  111. int64_t height,
  112. int64_t width) {
  113. return ROIPoolBackwardFunction::apply(
  114. grad,
  115. rois,
  116. argmax,
  117. spatial_scale,
  118. pooled_height,
  119. pooled_width,
  120. batch_size,
  121. channels,
  122. height,
  123. width)[0];
  124. }
  125. } // namespace
  126. TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  127. m.impl(
  128. TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
  129. TORCH_FN(roi_pool_autograd));
  130. m.impl(
  131. TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
  132. TORCH_FN(roi_pool_backward_autograd));
  133. }
  134. } // namespace ops
  135. } // namespace vision