nms_kernel.mm

#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;

at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
  using namespace at::native::mps;
  TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
  TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");
  TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
  TORCH_CHECK(
      dets.size(0) == scores.size(0),
      "boxes and scores should have same number of elements in ",
      "dimension 0, got ",
      dets.size(0),
      " and ",
      scores.size(0));

  if (dets.numel() == 0) {
    return at::empty({0}, dets.options().dtype(at::kLong));
  }

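  // Sort boxes by score in descending order so the greedy pass below always
  // keeps the higher-scoring box of any overlapping pair.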
  auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /*descending=*/true));
  auto dets_sorted = dets.index_select(0, order_t).contiguous();
  int64_t dets_num = dets.size(0);
  float iou_threshold_f = static_cast<float>(iou_threshold);

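  // One bit per box pair: box i owns `col_blocks` 64-bit words of `mask`, and
  // bit k of word (i, j) is set by the kernel when IoU(box i, box j*64+k)
  // exceeds the threshold. nmsThreadsPerBlock == 64 == bits per mask word.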
  const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
  at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
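      // Launch one threadgroup per (row-block, col-block) pair, i.e. each
      // threadgroup compares one tile of 64 boxes against another.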
      MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);
      const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
      id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});

      [computeEncoder setComputePipelineState:visionPSO];
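      // `storage_offset()` counts elements; Metal expects byte offsets.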
      [computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
      [computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
      [computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
      [computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];

      // A threadgroup is the Metal equivalent of a CUDA block.
      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > nmsThreadsPerBlock) {
        tgSize = nmsThreadsPerBlock;
      }
      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);

      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
      getMPSProfiler().endProfileKernel(visionPSO);
    }
  });

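  // Greedy sequential pass over the bitmask on the CPU: visit boxes in score
  // order, keep a box unless an earlier kept box has already suppressed it,
  // then OR its row into `remv` to suppress everything it overlaps.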
  int64_t num_to_keep = 0;
  at::Tensor mask_cpu = mask.to(at::kCPU);
  unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();

  std::vector<unsigned long long> remv(col_blocks);
  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);

  at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_out = keep.data_ptr<int64_t>();

  for (int64_t i = 0; i < dets_num; i++) {
    int64_t nblock = i / nmsThreadsPerBlock;
    int64_t inblock = i % nmsThreadsPerBlock;
    if (!(remv[nblock] & (1ULL << inblock))) {
      keep_out[num_to_keep++] = i;
      unsigned long long* p = mask_host + i * col_blocks;
      for (int64_t j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }

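  // `keep` holds positions in the sorted order; index into `order_t` to map
  // them back to indices into the original, unsorted `dets`.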
  return order_t.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())});
}

} // namespace

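// Register this kernel as the MPS-backend implementation of torchvision::nms.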
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}

} // namespace ops
} // namespace vision