// roi_pool_kernel.cpp — CPU forward/backward kernels for torchvision ROI pooling.
  1. #include <float.h>
  2. #include <ATen/ATen.h>
  3. #include <torch/library.h>
  4. namespace vision {
  5. namespace ops {
  6. namespace {
  7. template <class T>
  8. inline void add(T* address, const T& val) {
  9. *address += val;
  10. }
  11. template <typename T>
  12. void roi_pool_forward_kernel_impl(
  13. const T* input,
  14. const T spatial_scale,
  15. int channels,
  16. int height,
  17. int width,
  18. int pooled_height,
  19. int pooled_width,
  20. const T* rois,
  21. int num_rois,
  22. T* output,
  23. int* argmax_data) {
  24. for (int n = 0; n < num_rois; ++n) {
  25. const T* offset_rois = rois + n * 5;
  26. int roi_batch_ind = offset_rois[0];
  27. int roi_start_w = round(offset_rois[1] * spatial_scale);
  28. int roi_start_h = round(offset_rois[2] * spatial_scale);
  29. int roi_end_w = round(offset_rois[3] * spatial_scale);
  30. int roi_end_h = round(offset_rois[4] * spatial_scale);
  31. // Force malformed ROIs to be 1x1
  32. int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
  33. int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
  34. T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
  35. T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
  36. for (int ph = 0; ph < pooled_height; ++ph) {
  37. for (int pw = 0; pw < pooled_width; ++pw) {
  38. int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
  39. int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
  40. int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
  41. int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
  42. // Add roi offsets and clip to input boundaries
  43. hstart = std::min(std::max(hstart + roi_start_h, 0), height);
  44. hend = std::min(std::max(hend + roi_start_h, 0), height);
  45. wstart = std::min(std::max(wstart + roi_start_w, 0), width);
  46. wend = std::min(std::max(wend + roi_start_w, 0), width);
  47. bool is_empty = (hend <= hstart) || (wend <= wstart);
  48. for (int c = 0; c < channels; ++c) {
  49. // Define an empty pooling region to be zero
  50. T maxval = is_empty ? 0 : -FLT_MAX;
  51. // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
  52. int maxidx = -1;
  53. const T* input_offset =
  54. input + (roi_batch_ind * channels + c) * height * width;
  55. for (int h = hstart; h < hend; ++h) {
  56. for (int w = wstart; w < wend; ++w) {
  57. int input_index = h * width + w;
  58. if (input_offset[input_index] > maxval) {
  59. maxval = input_offset[input_index];
  60. maxidx = input_index;
  61. }
  62. }
  63. }
  64. int index =
  65. ((n * channels + c) * pooled_height + ph) * pooled_width + pw;
  66. output[index] = maxval;
  67. argmax_data[index] = maxidx;
  68. } // channels
  69. } // pooled_width
  70. } // pooled_height
  71. } // num_rois
  72. }
  73. template <typename T>
  74. void roi_pool_backward_kernel_impl(
  75. const T* grad_output,
  76. const int* argmax_data,
  77. int num_rois,
  78. int channels,
  79. int height,
  80. int width,
  81. int pooled_height,
  82. int pooled_width,
  83. T* grad_input,
  84. const T* rois,
  85. int n_stride,
  86. int c_stride,
  87. int h_stride,
  88. int w_stride) {
  89. for (int n = 0; n < num_rois; ++n) {
  90. const T* offset_rois = rois + n * 5;
  91. int roi_batch_ind = offset_rois[0];
  92. for (int c = 0; c < channels; ++c) {
  93. T* grad_input_offset =
  94. grad_input + ((roi_batch_ind * channels + c) * height * width);
  95. const int* argmax_data_offset =
  96. argmax_data + (n * channels + c) * pooled_height * pooled_width;
  97. for (int ph = 0; ph < pooled_height; ++ph) {
  98. for (int pw = 0; pw < pooled_width; ++pw) {
  99. int output_offset = n * n_stride + c * c_stride;
  100. int argmax = argmax_data_offset[ph * pooled_width + pw];
  101. if (argmax != -1) {
  102. add(grad_input_offset + argmax,
  103. static_cast<T>(
  104. grad_output
  105. [output_offset + ph * h_stride + pw * w_stride]));
  106. }
  107. } // pooled_width
  108. } // pooled_height
  109. } // channels
  110. } // num_rois
  111. }
  112. std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(
  113. const at::Tensor& input,
  114. const at::Tensor& rois,
  115. double spatial_scale,
  116. int64_t pooled_height,
  117. int64_t pooled_width) {
  118. TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
  119. TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
  120. at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
  121. at::CheckedFrom c = "roi_pool_forward_kernel";
  122. at::checkAllSameType(c, {input_t, rois_t});
  123. int num_rois = rois.size(0);
  124. int channels = input.size(1);
  125. int height = input.size(2);
  126. int width = input.size(3);
  127. at::Tensor output = at::zeros(
  128. {num_rois, channels, pooled_height, pooled_width}, input.options());
  129. at::Tensor argmax = at::zeros(
  130. {num_rois, channels, pooled_height, pooled_width},
  131. input.options().dtype(at::kInt));
  132. if (output.numel() == 0) {
  133. return std::make_tuple(output, argmax);
  134. }
  135. auto input_ = input.contiguous(), rois_ = rois.contiguous();
  136. AT_DISPATCH_FLOATING_TYPES_AND_HALF(
  137. input.scalar_type(), "roi_pool_forward_kernel", [&] {
  138. roi_pool_forward_kernel_impl<scalar_t>(
  139. input_.data_ptr<scalar_t>(),
  140. spatial_scale,
  141. channels,
  142. height,
  143. width,
  144. pooled_height,
  145. pooled_width,
  146. rois_.data_ptr<scalar_t>(),
  147. num_rois,
  148. output.data_ptr<scalar_t>(),
  149. argmax.data_ptr<int>());
  150. });
  151. return std::make_tuple(output, argmax);
  152. }
  153. at::Tensor roi_pool_backward_kernel(
  154. const at::Tensor& grad,
  155. const at::Tensor& rois,
  156. const at::Tensor& argmax,
  157. double spatial_scale,
  158. int64_t pooled_height,
  159. int64_t pooled_width,
  160. int64_t batch_size,
  161. int64_t channels,
  162. int64_t height,
  163. int64_t width) {
  164. // Check if input tensors are CPU tensors
  165. TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
  166. TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
  167. TORCH_CHECK(argmax.device().is_cpu(), "argmax must be a CPU tensor");
  168. TORCH_CHECK(
  169. rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
  170. at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
  171. at::CheckedFrom c = "roi_pool_backward_kernel";
  172. at::checkAllSameType(c, {grad_t, rois_t});
  173. auto num_rois = rois.size(0);
  174. at::Tensor grad_input =
  175. at::zeros({batch_size, channels, height, width}, grad.options());
  176. // handle possibly empty gradients
  177. if (grad.numel() == 0) {
  178. return grad_input;
  179. }
  180. // get stride values to ensure indexing into gradients is correct.
  181. int n_stride = grad.stride(0);
  182. int c_stride = grad.stride(1);
  183. int h_stride = grad.stride(2);
  184. int w_stride = grad.stride(3);
  185. auto rois_ = rois.contiguous();
  186. AT_DISPATCH_FLOATING_TYPES_AND_HALF(
  187. grad.scalar_type(), "roi_pool_backward_kernel", [&] {
  188. roi_pool_backward_kernel_impl<scalar_t>(
  189. grad.data_ptr<scalar_t>(),
  190. argmax.data_ptr<int>(),
  191. num_rois,
  192. channels,
  193. height,
  194. width,
  195. pooled_height,
  196. pooled_width,
  197. grad_input.data_ptr<scalar_t>(),
  198. rois_.data_ptr<scalar_t>(),
  199. n_stride,
  200. c_stride,
  201. h_stride,
  202. w_stride);
  203. });
  204. return grad_input;
  205. }
  206. } // namespace
// Register the CPU implementations of the roi_pool forward/backward ops with
// the torchvision dispatch key table.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
TORCH_FN(roi_pool_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
TORCH_FN(roi_pool_backward_kernel));
}
  215. } // namespace ops
  216. } // namespace vision