// roi_align_kernel.mm
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
  8. at::Tensor roi_align_forward_kernel(const at::Tensor& input,
  9. const at::Tensor& rois,
  10. double spatial_scale,
  11. int64_t pooled_height,
  12. int64_t pooled_width,
  13. int64_t sampling_ratio,
  14. bool aligned) {
  15. using namespace at::native::mps;
  16. TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
  17. TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  18. TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
  19. at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
  20. at::CheckedFrom c = "roi_align_forward_kernel";
  21. at::checkAllSameGPU(c, {input_t, rois_t});
  22. at::checkAllSameType(c, {input_t, rois_t});
  23. int64_t num_rois = rois.size(0);
  24. int64_t channels = input.size(1);
  25. int64_t height = input.size(2);
  26. int64_t width = input.size(3);
  27. float spatial_scale_f = static_cast<float>(spatial_scale);
  28. at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
  29. int64_t output_size = num_rois * pooled_height * pooled_width * channels;
  30. if (output.numel() == 0) {
  31. return output;
  32. }
  33. auto input_ = input.contiguous();
  34. auto rois_ = rois.contiguous();
  35. id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
  36. id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  37. id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
  38. id<MTLDevice> device = MPSDevice::getInstance()->device();
  39. MPSStream* mpsStream = getCurrentMPSStream();
  40. dispatch_sync(mpsStream->queue(), ^() {
  41. @autoreleasepool {
  42. id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
  43. MTLSize threadgroupsPerGrid = MTLSizeMake(
  44. std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
  45. 1,
  46. 1);
  47. const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
  48. id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
  49. // this function call is a no-op if MPS Profiler is not enabled
  50. getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
  51. [computeEncoder setComputePipelineState:visionPSO];
  52. // [N, C, H, W]
  53. [computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
  54. [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
  55. [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
  56. [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
  57. [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
  58. [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
  59. [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
  60. [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
  61. [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
  62. [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
  63. [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
  64. [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
  65. // A threadGroup is equivalent to a cuda's block.
  66. NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
  67. if (tgSize > threadsPerBlock) {
  68. tgSize = threadsPerBlock;
  69. }
  70. MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
  71. [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
  72. getMPSProfiler().endProfileKernel(visionPSO);
  73. }
  74. });
  75. return output;
  76. }
  77. at::Tensor roi_align_backward_kernel(const at::Tensor& grad,
  78. const at::Tensor& rois,
  79. double spatial_scale,
  80. int64_t pooled_height,
  81. int64_t pooled_width,
  82. int64_t batch_size,
  83. int64_t channels,
  84. int64_t height,
  85. int64_t width,
  86. int64_t sampling_ratio,
  87. bool aligned) {
  88. using namespace at::native::mps;
  89. TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
  90. TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
  91. TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs.");
  92. at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
  93. at::CheckedFrom c = "roi_align_backward_kernel";
  94. at::checkAllSameGPU(c, {grad_t, rois_t});
  95. at::checkAllSameType(c, {grad_t, rois_t});
  96. float spatial_scale_f = static_cast<float>(spatial_scale);
  97. at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
  98. if (grad.numel() == 0) {
  99. return grad_input;
  100. }
  101. int64_t n_stride = grad.stride(0);
  102. int64_t c_stride = grad.stride(1);
  103. int64_t h_stride = grad.stride(2);
  104. int64_t w_stride = grad.stride(3);
  105. int64_t output_size = grad.numel();
  106. at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
  107. auto rois_ = rois.contiguous();
  108. id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
  109. id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
  110. id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
  111. id<MTLDevice> device = MPSDevice::getInstance()->device();
  112. MPSStream* mpsStream = getCurrentMPSStream();
  113. dispatch_sync(mpsStream->queue(), ^() {
  114. @autoreleasepool {
  115. id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
  116. MTLSize threadgroupsPerGrid = MTLSizeMake(
  117. std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
  118. 1,
  119. 1);
  120. const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
  121. id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
  122. // this function call is a no-op if MPS Profiler is not enabled
  123. getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
  124. [computeEncoder setComputePipelineState:visionPSO];
  125. // [N, C, H, W]
  126. [computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
  127. [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
  128. [computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
  129. [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
  130. [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
  131. [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
  132. [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
  133. [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
  134. [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
  135. [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
  136. [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
  137. [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
  138. [computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
  139. [computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
  140. [computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
  141. [computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
  142. // A threadGroup is equivalent to a cuda's block.
  143. NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
  144. if (tgSize > threadsPerBlock) {
  145. tgSize = threadsPerBlock;
  146. }
  147. MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
  148. [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
  149. getMPSProfiler().endProfileKernel(visionPSO);
  150. }
  151. });
  152. return grad_input;
  153. }
} // namespace
  155. TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  156. m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel));
  157. m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
  158. }
} // namespace ops
} // namespace vision