// roi_align_kernel.cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#include "./roi_align_common.h"

namespace vision {
namespace ops {

namespace {
template <typename T>
void roi_align_forward_kernel_impl(
    int n_rois,
    const T* input,
    const T& spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    bool aligned,
    const T* rois,
    T* output) {
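  // Layout notes (derived from the wrappers below): `input` is a contiguous
  // [N, channels, height, width] feature map, `rois` is a contiguous [K, 5]
  // tensor whose rows are (batch_index, x1, y1, x2, y2) in input-image
  // coordinates, and `output` is [K, channels, pooled_height, pooled_width].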
  // (n, c, ph, pw) is an element in the pooled output
  // can be parallelized using omp
  // #pragma omp parallel for num_threads(32)
  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
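    // Illustrative arithmetic for the `aligned` offset (a sketch, not from
    // the source): with spatial_scale = 0.25 and aligned = true, an ROI
    // coordinate of 4.0 maps to 4.0 * 0.25 - 0.5 = 0.5, i.e. the half-pixel
    // shift lines the continuous ROI coordinates up with the pixel centers
    // of the feature map.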
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and approximate the integral
    // over each bin
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    // When the grid is empty, output zeros.
    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
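    // Worked example (illustrative): with roi_height = 14 and
    // pooled_height = 7, bin_size_h = 2 and, for sampling_ratio <= 0,
    // roi_bin_grid_h = ceil(14 / 7) = 2; with the same numbers along the
    // width, count = 2 * 2 = 4 samples are averaged per output bin.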
    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<detail::PreCalc<T>> pre_calc(
        roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
    detail::pre_calc_for_bilinear_interpolate(
        height,
        width,
        pooled_height,
        pooled_width,
        roi_start_h,
        roi_start_w,
        bin_size_h,
        bin_size_w,
        roi_bin_grid_h,
        roi_bin_grid_w,
        pre_calc);
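    // As used below, each PreCalc entry holds the four flattened neighbor
    // positions (pos1..pos4) and bilinear weights (w1..w4) for one sampling
    // point, so the per-channel inner loop reduces to four multiply-adds per
    // sample and the interpolation setup is amortized across all channels.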
    for (int c = 0; c < channels; c++) {
      int index_n_c = index_n + c * pooled_width * pooled_height;
      const T* offset_input =
          input + (roi_batch_ind * channels + c) * height * width;
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          int index = index_n_c + ph * pooled_width + pw;

          T output_val = 0.;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              detail::PreCalc<T> pc = pre_calc[pre_calc_index];
              output_val += pc.w1 * offset_input[pc.pos1] +
                  pc.w2 * offset_input[pc.pos2] +
                  pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];

              pre_calc_index += 1;
            }
          }
          output_val /= count; // Average pooling

          output[index] = output_val;
        } // for pw
      } // for ph
    } // for c
  } // for n
}
template <typename T>
void bilinear_interpolate_gradient(
    int height,
    int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high,
    int index /* index for debug only */) {
  // Deal with cases where the sampled point falls outside the feature map
  // boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
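  // Illustrative numbers (not from the source): for (y, x) = (1.3, 2.6),
  // y_low = 1, y_high = 2, x_low = 2, x_high = 3, ly = 0.3, lx = 0.6,
  // hy = 0.7, hx = 0.4, giving w1 = 0.28, w2 = 0.42, w3 = 0.12, w4 = 0.18;
  // the four weights always sum to 1.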
}

template <class T>
inline void add(T* address, const T& val) {
  *address += val;
}
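// Note: `add` is a plain (non-atomic) accumulation; this is safe only because
// the backward loop below runs single-threaded. A parallel variant would need
// an atomic add, as the comment inside that loop also points out.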
template <typename T>
void roi_align_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const T& spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    bool aligned,
    T* grad_input,
    const T* rois,
    int n_stride,
    int c_stride,
    int h_stride,
    int w_stride) {
  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
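    // nthreads is grad.numel() in the wrapper below, i.e.
    // K * channels * pooled_height * pooled_width, so this is a row-major
    // unflattening of the linear index back into (n, c, ph, pw). For example,
    // with channels = 2 and a 7x7 pooled grid, index = 101 gives
    // pw = 3, ph = 0, c = 0, n = 1.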
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    int output_offset = n * n_stride + c * c_stride;
    const T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    // We use roi_bin_grid to sample the grid and approximate the integral
    // over each bin
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);
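        // Sample points sit at the centers of a uniform sub-grid inside the
        // bin: with bin_size_h = 2 and roi_bin_grid_h = 2 the per-bin y
        // offsets are (0 + 0.5) * 2 / 2 = 0.5 and (1 + 0.5) * 2 / 2 = 1.5,
        // matching the "0.5, 1.5" example in the comment above.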
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
        } // if
      } // ix
    } // iy
  } // for
}
at::Tensor roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
  TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_forward_kernel";
  at::checkAllSameType(c, {input_t, rois_t});

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  at::Tensor output = at::zeros(
      {num_rois, channels, pooled_height, pooled_width}, input.options());

  if (output.numel() == 0)
    return output;

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_align_forward_kernel", [&] {
        roi_align_forward_kernel_impl<scalar_t>(
            num_rois,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            aligned,
            rois_.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>());
      });
  return output;
}
at::Tensor roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
  TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_align_backward_kernel";
  at::checkAllSameType(c, {grad_t, rois_t});

  at::Tensor grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    return grad_input;
  }

  // get stride values to ensure indexing into gradients is correct.
  int n_stride = grad.stride(0);
  int c_stride = grad.stride(1);
  int h_stride = grad.stride(2);
  int w_stride = grad.stride(3);
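  // The explicit strides let the kernel read `grad` correctly even when it is
  // not contiguous, while `grad_input` is freshly allocated above (and is
  // therefore contiguous).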
  auto rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "roi_align_backward_kernel", [&] {
        roi_align_backward_kernel_impl<scalar_t>(
            grad.numel(),
            grad.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            aligned,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>(),
            n_stride,
            c_stride,
            h_stride,
            w_stride);
      });
  return grad_input;
}
} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
      TORCH_FN(roi_align_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"),
      TORCH_FN(roi_align_backward_kernel));
}
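// Usage sketch (an assumption, not part of this file): once these CPU kernels
// are registered, the forward op can be invoked through the dispatcher, e.g.
// from Python as
//   torch.ops.torchvision.roi_align(input, rois, spatial_scale,
//                                   pooled_height, pooled_width,
//                                   sampling_ratio, aligned)
// assuming the operator schema declared elsewhere in torchvision uses this
// argument order; the backward op is wired up by the autograd layer rather
// than called directly.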
} // namespace ops
} // namespace vision