// roi_pool_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <torch/library.h>
#include <ATen/native/cuda/KernelUtils.cuh>

#include "cuda_helpers.h"

namespace vision {
namespace ops {

namespace {
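
// Forward pass: each iteration of CUDA_1D_KERNEL_LOOP handles one element
// (n, c, ph, pw) of the pooled output. It scales the n-th ROI into
// feature-map coordinates, divides it into a pooled_height x pooled_width
// grid of bins, and writes the maximum input value found in bin (ph, pw) of
// channel c, together with the flat index of that maximum (argmax_data) for
// use in the backward pass.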
template <typename T>
__global__ void roi_pool_forward_kernel_impl(
    int nthreads,
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const T* rois,
    T* output,
    int* argmax_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    int roi_start_w = round(offset_rois[1] * spatial_scale);
    int roi_start_h = round(offset_rois[2] * spatial_scale);
    int roi_end_w = round(offset_rois[3] * spatial_scale);
    int roi_end_h = round(offset_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Define an empty pooling region to be zero
    T maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    int maxidx = -1;
    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * width + w;
        if (offset_input[input_index] > maxval) {
          maxval = offset_input[input_index];
          maxidx = input_index;
        }
      }
    }
    output[index] = maxval;
    argmax_data[index] = maxidx;
  }
}
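
// Illustrative example (assumed values, not taken from this file): with
// spatial_scale = 1/16, a 7x7 pooled output, and an ROI box
// (x1, y1, x2, y2) = (32, 32, 95, 95), the box maps to feature coordinates
// roi_start_w = round(32 / 16) = 2 and roi_end_w = round(95 / 16) = 6, so
// roi_width = 5 and bin_size_w = 5 / 7 ~= 0.714. For pw = 3 the kernel
// computes wstart = floor(3 * 5 / 7) = 2 and wend = ceil(4 * 5 / 7) = 3;
// after adding roi_start_w, bin (ph, pw) = (3, 3) reduces to the single
// feature-map element at (h, w) = (4, 4).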

template <typename T>
__global__ void roi_pool_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const int* argmax_data,
    int num_rois,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    T* grad_input,
    const T* rois,
    int n_stride,
    int c_stride,
    int h_stride,
    int w_stride,
    const int memory_span) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    const int output_offset = n * n_stride + c * c_stride;
    const int* argmax_data_offset =
        argmax_data + (n * channels + c) * pooled_height * pooled_width;
    const int argmax = argmax_data_offset[ph * pooled_width + pw];
    const int offset = (roi_batch_ind * channels + c) * height * width;

    if (argmax != -1) {
      at::native::fastAtomicAdd(
          grad_input,
          offset + argmax,
          memory_span,
          static_cast<T>(
              grad_output[output_offset + ph * h_stride + pw * w_stride]),
          true);
    }
  }
}
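
// The backward kernel above scatters each output gradient to the single
// input element that produced the forward max (recorded in argmax_data).
// fastAtomicAdd keeps the accumulation correct when several pooled bins
// select the same input element, including for half precision; because the
// order of atomic additions varies between runs, the host wrapper calls
// alertNotDeterministic below.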

std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(
      rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_pool_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  at::Tensor output = at::zeros(
      {num_rois, channels, pooled_height, pooled_width}, input.options());
  at::Tensor argmax = at::zeros(
      {num_rois, channels, pooled_height, pooled_width},
      input.options().dtype(at::kInt));

  auto output_size = num_rois * pooled_height * pooled_width * channels;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  if (output.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(output, argmax);
  }

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_pool_forward_kernel", [&] {
        roi_pool_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            rois_.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            argmax.data_ptr<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(output, argmax);
}
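
// Launch configuration note: both host wrappers use 512-thread blocks and
// cap the grid at 4096 blocks. CUDA_1D_KERNEL_LOOP (from cuda_helpers.h)
// strides over the full element range, so any elements beyond
// grid.x * block.x are still processed by the same launch.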

at::Tensor roi_pool_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(argmax.is_cuda(), "argmax must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      argmax_t{argmax, "argmax", 3};

  at::CheckedFrom c = "roi_pool_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  auto num_rois = rois.size(0);

  at::Tensor grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  int n_stride = grad.stride(0);
  int c_stride = grad.stride(1);
  int h_stride = grad.stride(2);
  int w_stride = grad.stride(3);

  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");

  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "roi_pool_backward_kernel", [&] {
        roi_pool_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad.data_ptr<scalar_t>(),
            argmax_.data_ptr<int>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>(),
            n_stride,
            c_stride,
            h_stride,
            w_stride,
            grad_input.numel());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}
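
// Note that grad is not forced to be contiguous: its strides
// (n_stride, c_stride, h_stride, w_stride) are handed to the kernel so the
// incoming gradient can be read in whatever memory layout autograd provides,
// while grad_input is freshly allocated by at::zeros and hence contiguous.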

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
      TORCH_FN(roi_pool_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
      TORCH_FN(roi_pool_backward_kernel));
}

} // namespace ops
} // namespace vision
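
// Minimal sketch of how the op registered above can be reached from C++
// through the PyTorch dispatcher. This is illustrative only; the tensor
// names, shapes, and parameter values below are assumptions, not part of
// this file.
//
//   #include <ATen/core/dispatch/Dispatcher.h>
//
//   // Resolve the "torchvision::roi_pool" schema and call it with CUDA
//   // tensors. rois has shape (K, 5) as (batch_idx, x1, y1, x2, y2), the
//   // same layout the kernels above decode.
//   auto op = c10::Dispatcher::singleton()
//                 .findSchemaOrThrow("torchvision::roi_pool", "")
//                 .typed<std::tuple<at::Tensor, at::Tensor>(
//                     const at::Tensor&, // input: (N, C, H, W) features
//                     const at::Tensor&, // rois:  (K, 5) boxes
//                     double,            // spatial_scale
//                     int64_t,           // pooled_height
//                     int64_t)>();       // pooled_width
//   auto result = op.call(
//       input, rois, /*spatial_scale=*/0.0625,
//       /*pooled_height=*/7, /*pooled_width=*/7);
//   at::Tensor pooled = std::get<0>(result); // (K, C, 7, 7)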