// roi_pool_kernel.mm

#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
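
// Host-side MPS (Metal Performance Shaders) launchers for torchvision's
// roi_pool operator. Each launcher validates its inputs, allocates the
// result tensors, and encodes the matching Metal compute kernel onto the
// current MPS stream.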

namespace vision {
namespace ops {

namespace {
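
// Forward pass: max-pools each region of interest from the input feature
// map into a fixed pooled_height x pooled_width grid. Returns the pooled
// output and the argmax indices (positions of the selected maxima) that
// the backward pass uses to route gradients. rois is a [K, 5] tensor of
// (batch_index, x1, y1, x2, y2) boxes; coordinates are multiplied by
// spatial_scale to map them onto the feature map.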
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(const at::Tensor& input,
                                                           const at::Tensor& rois,
                                                           double spatial_scale,
                                                           int64_t pooled_height,
                                                           int64_t pooled_width) {
  using namespace at::native::mps;
  TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
  TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_pool_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  int64_t num_rois = rois.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
  at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong));

  int64_t output_size = num_rois * pooled_height * pooled_width * channels;

  if (output.numel() == 0) {
    return std::make_tuple(output, argmax);
  }

  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
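
  // Encode on the MPS stream's serial queue so this kernel is ordered
  // correctly with respect to other work on the stream.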
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
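      // Launch configuration: one thread per output element, up to 512
      // threads per threadgroup and at most 4096 threadgroups. If
      // output_size exceeds the grid, the device kernel is assumed to
      // grid-stride over the remaining elements (the Metal kernel itself
      // lives in mps_kernels.h and is not shown here).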
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
      [computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];

      // A threadgroup is equivalent to a CUDA block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });

  return std::make_tuple(output, argmax);
}
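
// Backward pass: routes each pooled output gradient back to the input
// position recorded in argmax during the forward pass, producing a
// gradient of shape [batch_size, channels, height, width].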
at::Tensor roi_pool_backward_kernel(const at::Tensor& grad,
                                    const at::Tensor& rois,
                                    const at::Tensor& argmax,
                                    double spatial_scale,
                                    int64_t pooled_height,
                                    int64_t pooled_width,
                                    int64_t batch_size,
                                    int64_t channels,
                                    int64_t height,
                                    int64_t width) {
  using namespace at::native::mps;
  TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
  TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
  TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs.");
  TORCH_CHECK(argmax.is_mps(), "argmax must be an MPS tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};

  at::CheckedFrom c = "roi_pool_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  float spatial_scale_f = static_cast<float>(spatial_scale);

  at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  if (grad.numel() == 0) {
    return grad_input;
  }

  int64_t n_stride = grad.stride(0);
  int64_t c_stride = grad.stride(1);
  int64_t h_stride = grad.stride(2);
  int64_t w_stride = grad.stride(3);
  int64_t output_size = grad.numel();
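
  // Gradients from overlapping ROIs are accumulated into grad_input
  // (typically via atomic adds in the device kernel), so the result is not
  // bitwise deterministic; alert users who requested deterministic
  // algorithms.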
  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");

  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
  id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();

  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      MTLSize threadgroupsPerGrid = MTLSizeMake(
          std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
          1,
          1);

      const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});

      [computeEncoder setComputePipelineState:visionPSO];
      // [N, C, H, W]
      [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
      [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
      [computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
      [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];

      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
      [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
      [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
      [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
      [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
      [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
      [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
      [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
      [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
      [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];

      // A threadgroup is equivalent to a CUDA block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > threadsPerBlock) {
        tgSize = threadsPerBlock;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });

  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
  m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}
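
// These kernels back the Python-level torchvision.ops.roi_pool on MPS
// devices. A minimal usage sketch (assuming a torchvision build with MPS
// support; each rois row is [batch_index, x1, y1, x2, y2]):
//
//   input = torch.rand(1, 3, 32, 32, device="mps")
//   rois = torch.tensor([[0.0, 0.0, 0.0, 15.0, 15.0]], device="mps")
//   out = torchvision.ops.roi_pool(input, rois, output_size=(7, 7), spatial_scale=1.0)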

} // namespace ops
} // namespace vision