roi_align_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"

namespace vision {
namespace ops {

namespace {
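
// Bilinear interpolation of a single (y, x) sampling point on one
// height x width channel plane. Points more than one pixel outside the
// feature map return 0; points within the border are clamped to the
// nearest valid pixel before the four neighbors are blended.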
template <typename T>
__device__ T bilinear_interpolate(
    const T* input,
    int height,
    int width,
    T y,
    T x,
    int index /* index for debug only */) {
  // Handle cases where the sampling point falls outside the feature map.
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    return 0;
  }
  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  int y_low = (int)y;
  int x_low = (int)x;
  int y_high;
  int x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // do bilinear interpolation
  T v1 = input[y_low * width + x_low];
  T v2 = input[y_low * width + x_high];
  T v3 = input[y_high * width + x_low];
  T v4 = input[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  return val;
}
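
// Forward kernel: each thread computes one element of the pooled output,
// indexed as (n, c, ph, pw). The value is the average of bilinearly
// interpolated samples taken on a roi_bin_grid_h x roi_bin_grid_w grid
// inside the corresponding ROI bin.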
template <typename T>
__global__ void roi_align_forward_kernel_impl(
    int nthreads,
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    bool aligned,
    const T* rois,
    T* output) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
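    // Each ROI row is (batch_index, x1, y1, x2, y2) in input coordinates.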
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical.
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and approximate the integral.
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin.
    // When the grid is empty, output zeros.
    const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

    T output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;

    output[index] = output_val;
  }
}
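
// Backward helper: computes the four bilinear interpolation weights
// (w1..w4) and the integer coordinates of the four neighboring pixels
// for a sampling point (y, x), mirroring bilinear_interpolate above.
// Out-of-bounds points get zero weights and coordinates set to -1.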
template <typename T>
__device__ void bilinear_interpolate_gradient(
    int height,
    int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high,
    int index /* index for debug only */) {
  // Handle cases where the sampling point falls outside the feature map.
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
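
// Backward kernel: each thread handles one pooled output element and
// scatters its upstream gradient to the four input pixels of every
// sampling point via atomic adds, weighted by the bilinear coefficients
// and divided by the number of samples in the bin.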
template <typename T>
__global__ void roi_align_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    bool aligned,
    T* grad_input,
    const T* rois,
    int n_stride,
    int c_stride,
    int h_stride,
    int w_stride,
    const int memory_span) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical.
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We need to index the gradient using the tensor strides to access the
    // correct values.
    const int output_offset = n * n_stride + c * c_stride;
    const T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    // We use roi_bin_grid to sample the grid and approximate the integral.
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin.
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
    const int input_offset = (roi_batch_ind * channels + c) * height * width;

    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          at::native::fastAtomicAdd(
              grad_input,
              input_offset + y_low * width + x_low,
              memory_span,
              static_cast<T>(g1),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              input_offset + y_low * width + x_high,
              memory_span,
              static_cast<T>(g2),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              input_offset + y_high * width + x_low,
              memory_span,
              static_cast<T>(g3),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              input_offset + y_high * width + x_high,
              memory_span,
              static_cast<T>(g4),
              true);
        } // if
      } // ix
    } // iy
  } // CUDA_1D_KERNEL_LOOP
}
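
// Host-side entry point for the forward pass: validates the inputs,
// allocates the (num_rois, channels, pooled_height, pooled_width) output,
// and launches roi_align_forward_kernel_impl on the current CUDA stream.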
at::Tensor roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  at::Tensor output = at::zeros(
      {num_rois, channels, pooled_height, pooled_width}, input.options());

  auto output_size = num_rois * pooled_height * pooled_width * channels;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
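
  // Launch configuration: 512 threads per block, grid capped at 4096 blocks;
  // CUDA_1D_KERNEL_LOOP strides over any remaining output elements.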
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  if (output.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return output;
  }

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_align_forward_kernel", [&] {
        roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            aligned,
            rois_.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}
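
// Host-side entry point for the backward pass: allocates a zero-initialized
// grad_input of the original input shape and launches
// roi_align_backward_kernel_impl, passing the grad strides so that
// non-contiguous gradients are read correctly.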
at::Tensor roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  at::Tensor grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  int n_stride = grad.stride(0);
  int c_stride = grad.stride(1);
  int h_stride = grad.stride(2);
  int w_stride = grad.stride(3);
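
  // The backward pass accumulates into grad_input with atomic adds, so the
  // order of floating-point additions (and thus the exact result) is
  // nondeterministic; flag the op via alertNotDeterministic below.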
  at::globalContext().alertNotDeterministic("roi_align_backward_kernel");

  auto rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "roi_align_backward_kernel", [&] {
        roi_align_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            aligned,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>(),
            n_stride,
            c_stride,
            h_stride,
            w_stride,
            grad_input.numel());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}

} // namespace
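
// Register the CUDA implementations with the dispatcher so that
// torchvision::roi_align and torchvision::_roi_align_backward resolve to
// these kernels for CUDA tensors.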
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
      TORCH_FN(roi_align_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
      TORCH_FN(roi_align_backward_kernel));
}

} // namespace ops
} // namespace vision