mps_kernels.h

#include <ATen/native/mps/OperationUtils.h>

namespace vision {
namespace ops {
namespace mps {

static const char* METAL_VISION = R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;

/*----------Macros----------*/

#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t)       \
  for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n);  \
       i += (tptg.x * n_tgs))

#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint)
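
// For example, MPS_1D_KERNEL_LOOP(i, n, 1) expands to the grid-stride loop
//   for (uint i = (tgid.x * tptg.x) + tid2.x; i < (n); i += (tptg.x * 1))
// Each thread strides over [0, n) in steps of (threads per group * number of
// groups), which keeps the kernels below correct for any launch configuration.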

/*----------Helpers--------*/

template <typename T>
inline T ceil_div(T n, T m) {
  return (n + m - 1) / m;
}

template <typename T>
inline void atomic_add_float(device T* data_ptr, const T val) {
#if __METAL_VERSION__ >= 300
  // atomic_float is supported in Metal 3 (macOS Ventura) onward.
  atomic_fetch_add_explicit((device atomic_float*)data_ptr, val, memory_order_relaxed);
#else
  // Custom atomic addition implementation
  // https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
  // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
  // https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)

  // Create an atomic uint pointer for the atomic transaction.
  device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
  // Create necessary storage.
  uint fetched_uint, assigning_uint;
  T fetched_float, assigning_float;

  // Replace the value in atom_var with 0 and return the previous value in atom_var.
  fetched_uint = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
  // Read out the previous value as float.
  fetched_float = *((thread T*)&fetched_uint);

  // Do the addition and represent the result in uint for the atomic transaction.
  assigning_float = fetched_float + val;
  assigning_uint = *((thread uint*)&assigning_float);

  // atom_var should be 0 now; try to assign the addition result back to atom_var (data_ptr).
  while ((fetched_uint = atomic_exchange_explicit(atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
    // If atom_var was not 0, i.e. fetched_uint != 0, the data has been modified by other threads.
    // Try to assign 0 and get the previously assigned addition result.
    uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
    T fetched_float_again = *((thread T*)&fetched_uint_again);
    // Re-add again
    fetched_float = *((thread T*)&fetched_uint);
    // Previously assigned addition result + addition result from other threads.
    assigning_float = fetched_float_again + fetched_float;
    assigning_uint = *((thread uint*)&assigning_float);
  }
#endif
}
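
// Note on the fallback path above: every nonzero value swapped out of the
// slot is folded into some thread's pending sum before that thread retries,
// and a deposit only sticks when the slot was empty, so no contribution is
// lost. The scheme relies on the all-zero bit pattern decoding as 0.0.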

template <typename T, typename integer_t>
inline T bilinear_interpolate(
    constant T* input,
    integer_t height,
    integer_t width,
    T y,
    T x,
    uint index /* index for debug only*/) {
  // deal with cases where the sampled points fall outside the feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    return 0;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  integer_t y_low = (integer_t)y;
  integer_t x_low = (integer_t)x;
  integer_t y_high;
  integer_t x_high;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // do bilinear interpolation
  T v1 = input[y_low * width + x_low];
  T v2 = input[y_low * width + x_high];
  T v3 = input[y_high * width + x_low];
  T v4 = input[y_high * width + x_high];
  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
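
// Example: y = 0.5, x = 0.5 gives y_low = x_low = 0, y_high = x_high = 1 and
// ly = lx = 0.5, so all four weights equal 0.25 and the result is the mean of
// the four neighboring pixels.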

template <typename T, typename integer_t>
inline void bilinear_interpolate_gradient(
    integer_t height,
    integer_t width,
    T y,
    T x,
    thread T& w1,
    thread T& w2,
    thread T& w3,
    thread T& w4,
    thread integer_t& x_low,
    thread integer_t& x_high,
    thread integer_t& y_low,
    thread integer_t& y_high,
    uint index /* index for debug only*/) {
  // deal with cases where the sampled points fall outside the feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0)
    y = 0;
  if (x <= 0)
    x = 0;

  y_low = (integer_t)y;
  x_low = (integer_t)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}

template <typename T, typename scalar_t>
inline bool IoU(
    constant T & a,
    threadgroup T & b,
    const float threshold) {
  auto xx1 = max(a.x, b.x);
  auto yy1 = max(a.y, b.y);
  auto xx2 = min(a.z, b.z);
  auto yy2 = min(a.w, b.w);
  auto w = max(static_cast<scalar_t>(0), xx2 - xx1);
  auto h = max(static_cast<scalar_t>(0), yy2 - yy1);
  // Upcast to float before multiplications to circumvent precision issues in half.
  auto inter = static_cast<float>(w) * static_cast<float>(h);
  auto area_b = static_cast<float>(b.z - b.x) * static_cast<float>(b.w - b.y);
  auto area_a = static_cast<float>(a.z - a.x) * static_cast<float>(a.w - a.y);
  return (inter / (area_a + area_b - inter)) > threshold;
}
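
// Example: a = (0, 0, 2, 2) and b = (1, 1, 3, 3) give inter = 1 and
// area_a = area_b = 4, so IoU = 1 / (4 + 4 - 1) = 1/7 ~ 0.14; the function
// returns true only for thresholds below that value.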

/*----------Kernels----------*/

// This should be in sync with the one in nms_kernel.mm.
// Since Metal does not support dynamic arrays,
// we need to make it static instead of deriving it from [[threads_per_threadgroup]].
constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;

template<typename T, typename scalar_t>
kernel void nms(constant T * dev_boxes [[buffer(0)]],
                device uint64_t * mask [[buffer(1)]],
                constant int64_t & n_boxes [[buffer(2)]],
                constant float & iou_threshold [[buffer(3)]],
                uint2 tgid [[threadgroup_position_in_grid]],
                uint2 tid2 [[thread_position_in_threadgroup]]) {
  const uint row_start = tgid.y;
  const uint col_start = tgid.x;
  const uint tid = tid2.x;
  const uint row_size =
      min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
  const uint col_size =
      min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock);

  threadgroup T block_boxes[nmsThreadsPerBlock];
  block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid];
  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid < row_size) {
    const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid;
    uint64_t t = 0;
    uint start = 0;

    if (row_start == col_start) {
      start = tid + 1;
    }

    for (uint i = start; i < col_size; i++) {
      if (IoU<T, scalar_t>(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)) {
        t |= static_cast<uint64_t>(1) << i;  // discard 1 keep 0
      }
    }
    const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock);
    mask[cur_box_idx * col_blocks + col_start] = t;
  }
}
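
// mask is a row-major bitmatrix with ceil_div(n_boxes, nmsThreadsPerBlock)
// 64-bit words per box: bit i of mask[box * col_blocks + col] is set when box
// overlaps box (col * nmsThreadsPerBlock + i) above iou_threshold. The host
// side (see nms_kernel.mm) is expected to scan these words over boxes sorted
// by score to build the final keep list.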

#define REGISTER_NMS_OP(DTYPE)                       \
template                                             \
[[host_name("nms_" #DTYPE)]]                         \
kernel void nms<DTYPE ## 4, DTYPE>(                  \
    constant DTYPE ## 4 * dev_boxes [[buffer(0)]],   \
    device uint64_t * mask [[buffer(1)]],            \
    constant int64_t & n_boxes [[buffer(2)]],        \
    constant float & iou_threshold [[buffer(3)]],    \
    uint2 tgid [[threadgroup_position_in_grid]],     \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void roi_align(
    constant T * input [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * output [[buffer(2)]],
    constant int64_t & output_size [[buffer(3)]],
    constant int64_t & channels [[buffer(4)]],
    constant int64_t & height [[buffer(5)]],
    constant int64_t & width [[buffer(6)]],
    constant int64_t & pooled_height [[buffer(7)]],
    constant int64_t & pooled_width [[buffer(8)]],
    constant int64_t & sampling_ratio [[buffer(9)]],
    constant bool & aligned [[buffer(10)]],
    constant float & spatial_scale [[buffer(11)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c = (index / pooled_width / pooled_height) % channels;
    integer_t n = index / pooled_width / pooled_height / channels;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    constant T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic an integral
    integer_t roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    integer_t roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    // When the grid is empty, output zeros.
    const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4

    T output_val = 0.;
    for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        output_val += val;
      }
    }
    output_val /= count;

    output[index] = output_val;
  }
}
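
// Example: with bin_size_h = 1 and roi_bin_grid_h = 2, the y samples fall at
// offsets 0.25 and 0.75 inside the bin (and likewise for x), i.e. at the
// centers of a regular sub-grid whose average approximates the bin integral.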

#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE)      \
template                                             \
[[host_name("roi_align_" #DTYPE)]]                   \
kernel void roi_align<DTYPE, INT_DTYPE>(             \
    constant DTYPE * input [[buffer(0)]],            \
    constant DTYPE * rois [[buffer(1)]],             \
    device DTYPE * output [[buffer(2)]],             \
    constant int64_t & output_size [[buffer(3)]],    \
    constant int64_t & channels [[buffer(4)]],       \
    constant int64_t & height [[buffer(5)]],         \
    constant int64_t & width [[buffer(6)]],          \
    constant int64_t & pooled_height [[buffer(7)]],  \
    constant int64_t & pooled_width [[buffer(8)]],   \
    constant int64_t & sampling_ratio [[buffer(9)]], \
    constant bool & aligned [[buffer(10)]],          \
    constant float & spatial_scale [[buffer(11)]],   \
    uint2 tgid [[threadgroup_position_in_grid]],     \
    uint2 tptg [[threads_per_threadgroup]],          \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void roi_align_backward(
    constant T * grad_output [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * grad_input [[buffer(2)]],
    constant int64_t & output_size [[buffer(3)]],
    constant int64_t & channels [[buffer(4)]],
    constant int64_t & height [[buffer(5)]],
    constant int64_t & width [[buffer(6)]],
    constant int64_t & pooled_height [[buffer(7)]],
    constant int64_t & pooled_width [[buffer(8)]],
    constant int64_t & sampling_ratio [[buffer(9)]],
    constant bool & aligned [[buffer(10)]],
    constant float & spatial_scale [[buffer(11)]],
    constant int64_t & n_stride [[buffer(12)]],
    constant int64_t & c_stride [[buffer(13)]],
    constant int64_t & h_stride [[buffer(14)]],
    constant int64_t & w_stride [[buffer(15)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c = (index / pooled_width / pooled_height) % channels;
    integer_t n = index / pooled_width / pooled_height / channels;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // Force malformed ROIs to be 1x1
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We need to index the gradient using the tensor strides to access the
    // correct values.
    const integer_t output_offset = n * n_stride + c * c_stride;
    constant T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    // We use roi_bin_grid to sample the grid and mimic an integral
    integer_t roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    integer_t roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // We do average (integral) pooling inside a bin
    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
    const integer_t input_offset = (roi_batch_ind * channels + c) * height * width;

    for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
    {
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        integer_t x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast<T>(g1));
          atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2));
          atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3));
          atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4));
        } // if
      } // ix
    } // iy
  } // MPS_1D_KERNEL_LOOP
}
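
// Note: several bins can touch the same input pixel, so the gradient is
// scattered with atomic_add_float; the accumulation order is unspecified,
// which makes this backward pass non-deterministic at the bit level.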

#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template                                                 \
[[host_name("roi_align_backward_" #DTYPE)]]              \
kernel void roi_align_backward<DTYPE, INT_DTYPE>(        \
    constant DTYPE * grad_output [[buffer(0)]],          \
    constant DTYPE * rois [[buffer(1)]],                 \
    device DTYPE * grad_input [[buffer(2)]],             \
    constant int64_t & output_size [[buffer(3)]],        \
    constant int64_t & channels [[buffer(4)]],           \
    constant int64_t & height [[buffer(5)]],             \
    constant int64_t & width [[buffer(6)]],              \
    constant int64_t & pooled_height [[buffer(7)]],      \
    constant int64_t & pooled_width [[buffer(8)]],       \
    constant int64_t & sampling_ratio [[buffer(9)]],     \
    constant bool & aligned [[buffer(10)]],              \
    constant float & spatial_scale [[buffer(11)]],       \
    constant int64_t & n_stride [[buffer(12)]],          \
    constant int64_t & c_stride [[buffer(13)]],          \
    constant int64_t & h_stride [[buffer(14)]],          \
    constant int64_t & w_stride [[buffer(15)]],          \
    uint2 tgid [[threadgroup_position_in_grid]],         \
    uint2 tptg [[threads_per_threadgroup]],              \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void roi_pool(
    constant T * input [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * output [[buffer(2)]],
    device int64_t * argmax [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant float & spatial_scale [[buffer(10)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c = (index / pooled_width / pooled_height) % channels;
    integer_t n = index / pooled_width / pooled_height / channels;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];
    integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
    integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
    integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
    integer_t roi_end_h = round(offset_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast<integer_t>(1));
    integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast<integer_t>(1));
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
    integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
    integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
    hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
    wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
    wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Define an empty pooling region to be zero
    T maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    integer_t maxidx = -1;
    constant T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;
    for (integer_t h = hstart; h < hend; ++h) {
      for (integer_t w = wstart; w < wend; ++w) {
        integer_t input_index = h * width + w;
        if (offset_input[input_index] > maxval) {
          maxval = offset_input[input_index];
          maxidx = input_index;
        }
      }
    }
    output[index] = maxval;
    argmax[index] = maxidx;
  }
}
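
// argmax stores the flattened (h * width + w) offset of the max inside the
// (roi_batch_ind, c) plane, or -1 for an empty bin; roi_pool_backward below
// uses it to route the incoming gradient to exactly that element.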

#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE)      \
template                                            \
[[host_name("roi_pool_" #DTYPE)]]                   \
kernel void roi_pool<DTYPE, INT_DTYPE>(             \
    constant DTYPE * input [[buffer(0)]],           \
    constant DTYPE * rois [[buffer(1)]],            \
    device DTYPE * output [[buffer(2)]],            \
    device int64_t * argmax_data [[buffer(3)]],     \
    constant int64_t & output_size [[buffer(4)]],   \
    constant int64_t & channels [[buffer(5)]],      \
    constant int64_t & height [[buffer(6)]],        \
    constant int64_t & width [[buffer(7)]],         \
    constant int64_t & pooled_height [[buffer(8)]], \
    constant int64_t & pooled_width [[buffer(9)]],  \
    constant float & spatial_scale [[buffer(10)]],  \
    uint2 tgid [[threadgroup_position_in_grid]],    \
    uint2 tptg [[threads_per_threadgroup]],         \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void roi_pool_backward(
    constant T * grad_output [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    constant int64_t * argmax_data [[buffer(2)]],
    device T * grad_input [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant float & spatial_scale [[buffer(10)]],
    constant int64_t & n_stride [[buffer(11)]],
    constant int64_t & c_stride [[buffer(12)]],
    constant int64_t & h_stride [[buffer(13)]],
    constant int64_t & w_stride [[buffer(14)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c = (index / pooled_width / pooled_height) % channels;
    integer_t n = index / pooled_width / pooled_height / channels;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];

    const integer_t output_offset = n * n_stride + c * c_stride;
    constant integer_t * argmax_data_offset =
        argmax_data + (n * channels + c) * pooled_height * pooled_width;
    const integer_t argmax = argmax_data_offset[ph * pooled_width + pw];
    const integer_t offset = (roi_batch_ind * channels + c) * height * width;

    if (argmax != -1) {
      atomic_add_float(grad_input + offset + argmax,
                       static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
    }
  } // MPS_1D_KERNEL_LOOP
}

#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template                                                \
[[host_name("roi_pool_backward_" #DTYPE)]]              \
kernel void roi_pool_backward<DTYPE, INT_DTYPE>(        \
    constant DTYPE * grad_output [[buffer(0)]],         \
    constant DTYPE * rois [[buffer(1)]],                \
    constant int64_t * argmax_data [[buffer(2)]],       \
    device DTYPE * grad_input [[buffer(3)]],            \
    constant int64_t & output_size [[buffer(4)]],       \
    constant int64_t & channels [[buffer(5)]],          \
    constant int64_t & height [[buffer(6)]],            \
    constant int64_t & width [[buffer(7)]],             \
    constant int64_t & pooled_height [[buffer(8)]],     \
    constant int64_t & pooled_width [[buffer(9)]],      \
    constant float & spatial_scale [[buffer(10)]],      \
    constant int64_t & n_stride [[buffer(11)]],         \
    constant int64_t & c_stride [[buffer(12)]],         \
    constant int64_t & h_stride [[buffer(13)]],         \
    constant int64_t & w_stride [[buffer(14)]],         \
    uint2 tgid [[threadgroup_position_in_grid]],        \
    uint2 tptg [[threads_per_threadgroup]],             \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void ps_roi_align(
    constant T * input [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * output [[buffer(2)]],
    device int64_t * channel_mapping [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant int64_t & sampling_ratio [[buffer(10)]],
    constant int64_t & channels_out [[buffer(11)]],
    constant float & spatial_scale [[buffer(12)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c_out, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c_out = (index / pooled_width / pooled_height) % channels_out;
    integer_t n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input
    integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;

    // [start, end) interval for spatial sampling
    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Do not use floor/ceil; this implementation detail is critical
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;

    // We use roi_bin_grid to sample the grid and mimic an integral
    integer_t roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height);
    integer_t roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    const T count = roi_bin_grid_h * roi_bin_grid_w;

    constant T* offset_input =
        input + (roi_batch_ind * channels + c_in) * height * width;

    T out_sum = 0;
    for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = hstart +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);
      for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = wstart +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);
        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        out_sum += val;
      }
    }

    out_sum /= count;
    output[index] = out_sum;
    channel_mapping[index] = c_in;
  }
}
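
// For position-sensitive pooling the input is expected to provide
// channels = channels_out * pooled_height * pooled_width feature maps: each
// output channel c_out reads a different input channel c_in per spatial bin,
// and channel_mapping records that choice for the backward pass.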

#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE)    \
template                                              \
[[host_name("ps_roi_align_" #DTYPE)]]                 \
kernel void ps_roi_align<DTYPE, INT_DTYPE>(           \
    constant DTYPE * input [[buffer(0)]],             \
    constant DTYPE * rois [[buffer(1)]],              \
    device DTYPE * output [[buffer(2)]],              \
    device int64_t * channel_mapping [[buffer(3)]],   \
    constant int64_t & output_size [[buffer(4)]],     \
    constant int64_t & channels [[buffer(5)]],        \
    constant int64_t & height [[buffer(6)]],          \
    constant int64_t & width [[buffer(7)]],           \
    constant int64_t & pooled_height [[buffer(8)]],   \
    constant int64_t & pooled_width [[buffer(9)]],    \
    constant int64_t & sampling_ratio [[buffer(10)]], \
    constant int64_t & channels_out [[buffer(11)]],   \
    constant float & spatial_scale [[buffer(12)]],    \
    uint2 tgid [[threadgroup_position_in_grid]],      \
    uint2 tptg [[threads_per_threadgroup]],           \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void ps_roi_align_backward(
    constant T * grad_output [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    constant int64_t * channel_mapping [[buffer(2)]],
    device T * grad_input [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant int64_t & sampling_ratio [[buffer(10)]],
    constant int64_t & channels_out [[buffer(11)]],
    constant float & spatial_scale [[buffer(12)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, *, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t n = index / pooled_width / pooled_height / channels_out;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);

    // Force too small ROIs to be 1x1
    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);

    integer_t c_in = channel_mapping[index];

    // Do not use floor/ceil; this implementation detail is critical
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;

    const T grad_output_this_bin = grad_output[index];

    // We use roi_bin_grid to sample the grid and mimic an integral
    integer_t roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    integer_t roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    const T count = roi_bin_grid_h * roi_bin_grid_w;
    const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;

    for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = hstart +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);
      for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = wstart +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        integer_t x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast<T>(g1));
          atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast<T>(g2));
          atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast<T>(g3));
          atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast<T>(g4));
        } // if
      } // ix
    } // iy
  }
}

#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template                                                    \
[[host_name("ps_roi_align_backward_" #DTYPE)]]              \
kernel void ps_roi_align_backward<DTYPE, INT_DTYPE>(        \
    constant DTYPE * grad_output [[buffer(0)]],             \
    constant DTYPE * rois [[buffer(1)]],                    \
    constant int64_t * channel_mapping [[buffer(2)]],       \
    device DTYPE * grad_input [[buffer(3)]],                \
    constant int64_t & output_size [[buffer(4)]],           \
    constant int64_t & channels [[buffer(5)]],              \
    constant int64_t & height [[buffer(6)]],                \
    constant int64_t & width [[buffer(7)]],                 \
    constant int64_t & pooled_height [[buffer(8)]],         \
    constant int64_t & pooled_width [[buffer(9)]],          \
    constant int64_t & sampling_ratio [[buffer(10)]],       \
    constant int64_t & channels_out [[buffer(11)]],         \
    constant float & spatial_scale [[buffer(12)]],          \
    uint2 tgid [[threadgroup_position_in_grid]],            \
    uint2 tptg [[threads_per_threadgroup]],                 \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void ps_roi_pool(
    constant T * input [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    device T * output [[buffer(2)]],
    device int64_t * channel_mapping [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant int64_t & channels_out [[buffer(10)]],
    constant float & spatial_scale [[buffer(11)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, c_out, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out;
    integer_t n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input
    integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;

    // [start, end) interval for spatial sampling
    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];
    integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
    integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
    integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
    integer_t roi_end_h = round(offset_rois[4] * spatial_scale);

    // Force too small ROIs to be 1x1
    integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
    integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
    integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
    integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
    hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
    wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
    wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    constant T* offset_input =
        input + (roi_batch_ind * channels + c_in) * height * width;
    T out_sum = 0;
    for (integer_t h = hstart; h < hend; ++h) {
      for (integer_t w = wstart; w < wend; ++w) {
        integer_t input_index = h * width + w;
        out_sum += offset_input[input_index];
      }
    }

    T bin_area = (hend - hstart) * (wend - wstart);
    output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
    channel_mapping[index] = c_in;
  }
}

#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE)   \
template                                            \
[[host_name("ps_roi_pool_" #DTYPE)]]                \
kernel void ps_roi_pool<DTYPE, INT_DTYPE>(          \
    constant DTYPE * input [[buffer(0)]],           \
    constant DTYPE * rois [[buffer(1)]],            \
    device DTYPE * output [[buffer(2)]],            \
    device int64_t * channel_mapping [[buffer(3)]], \
    constant int64_t & output_size [[buffer(4)]],   \
    constant int64_t & channels [[buffer(5)]],      \
    constant int64_t & height [[buffer(6)]],        \
    constant int64_t & width [[buffer(7)]],         \
    constant int64_t & pooled_height [[buffer(8)]], \
    constant int64_t & pooled_width [[buffer(9)]],  \
    constant int64_t & channels_out [[buffer(10)]], \
    constant float & spatial_scale [[buffer(11)]],  \
    uint2 tgid [[threadgroup_position_in_grid]],    \
    uint2 tptg [[threads_per_threadgroup]],         \
    uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
kernel void ps_roi_pool_backward(
    constant T * grad_output [[buffer(0)]],
    constant T * rois [[buffer(1)]],
    constant int64_t * channel_mapping [[buffer(2)]],
    device T * grad_input [[buffer(3)]],
    constant int64_t & output_size [[buffer(4)]],
    constant int64_t & channels [[buffer(5)]],
    constant int64_t & height [[buffer(6)]],
    constant int64_t & width [[buffer(7)]],
    constant int64_t & pooled_height [[buffer(8)]],
    constant int64_t & pooled_width [[buffer(9)]],
    constant int64_t & channels_out [[buffer(10)]],
    constant float & spatial_scale [[buffer(11)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 tptg [[threads_per_threadgroup]],
    uint2 tid2 [[thread_position_in_threadgroup]]) {
  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
    // (n, *, ph, pw) is an element in the pooled output
    integer_t pw = index % pooled_width;
    integer_t ph = (index / pooled_width) % pooled_height;
    integer_t n = index / pooled_width / pooled_height / channels_out;

    constant T* offset_rois = rois + n * 5;
    integer_t roi_batch_ind = offset_rois[0];
    integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
    integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
    integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
    integer_t roi_end_h = round(offset_rois[4] * spatial_scale);

    // Force too small ROIs to be 1x1
    integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
    integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
    integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
    integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
    hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
    wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
    wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    integer_t c_in = channel_mapping[index];
    T bin_area = (hend - hstart) * (wend - wstart);
    T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;

    const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
    for (integer_t h = hstart; h < hend; ++h) {
      for (integer_t w = wstart; w < wend; ++w) {
        integer_t grad_input_index = h * width + w;
        atomic_add_float(grad_input + offset + grad_input_index, diff_val);
      }
    }
  } // MPS_1D_KERNEL_LOOP
}

#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template                                                   \
[[host_name("ps_roi_pool_backward_" #DTYPE)]]              \
kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>(        \
    constant DTYPE * grad_output [[buffer(0)]],            \
    constant DTYPE * rois [[buffer(1)]],                   \
    constant int64_t * channel_mapping [[buffer(2)]],      \
    device DTYPE * grad_input [[buffer(3)]],               \
    constant int64_t & output_size [[buffer(4)]],          \
    constant int64_t & channels [[buffer(5)]],             \
    constant int64_t & height [[buffer(6)]],               \
    constant int64_t & width [[buffer(7)]],                \
    constant int64_t & pooled_height [[buffer(8)]],        \
    constant int64_t & pooled_width [[buffer(9)]],         \
    constant int64_t & channels_out [[buffer(10)]],        \
    constant float & spatial_scale [[buffer(11)]],         \
    uint2 tgid [[threadgroup_position_in_grid]],           \
    uint2 tptg [[threads_per_threadgroup]],                \
    uint2 tid2 [[thread_position_in_threadgroup]]);

REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_ROI_POOL_OP(float, int64_t);
REGISTER_ROI_POOL_OP(half, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_POOL_OP(float, int64_t);
REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);

)VISION_METAL";

static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> visionLibrary = nil;
  if (visionLibrary) {
    return visionLibrary;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
                                       options:options
                                         error:&error];
  TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
  return visionLibrary;
}

static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLComputePipelineState> pso = psoCache[kernel];
  if (pso) {
    return pso;
  }

  NSError* error = nil;
  id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
  id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
  pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}
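
// A minimal usage sketch (not part of the original file; the encoder, the
// buffers, and the dispatch sizes below are assumptions) of how a caller
// might dispatch one of the kernels above:
//
//   id<MTLComputePipelineState> pso = visionPipelineState(device, "roi_align_float");
//   [computeEncoder setComputePipelineState:pso];
//   [computeEncoder setBuffer:inputBuffer offset:0 atIndex:0];
//   // ... bind the remaining buffers following the kernel's signature ...
//   [computeEncoder dispatchThreadgroups:MTLSizeMake(numGroups, 1, 1)
//                  threadsPerThreadgroup:MTLSizeMake(threadsPerGroup, 1, 1)];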

} // namespace mps
} // namespace ops
} // namespace vision