nms_kernel.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. #include <ATen/ATen.h>
  2. #include <torch/library.h>
  3. namespace vision {
  4. namespace ops {
  5. namespace {
  6. template <typename scalar_t>
  7. at::Tensor nms_kernel_impl(
  8. const at::Tensor& dets,
  9. const at::Tensor& scores,
  10. double iou_threshold) {
  11. TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
  12. TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
  13. TORCH_CHECK(
  14. dets.scalar_type() == scores.scalar_type(),
  15. "dets should have the same type as scores");
  16. if (dets.numel() == 0)
  17. return at::empty({0}, dets.options().dtype(at::kLong));
  18. auto x1_t = dets.select(1, 0).contiguous();
  19. auto y1_t = dets.select(1, 1).contiguous();
  20. auto x2_t = dets.select(1, 2).contiguous();
  21. auto y2_t = dets.select(1, 3).contiguous();
  22. at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
  23. auto order_t = std::get<1>(
  24. scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
  25. auto ndets = dets.size(0);
  26. at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
  27. at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
  28. auto suppressed = suppressed_t.data_ptr<uint8_t>();
  29. auto keep = keep_t.data_ptr<int64_t>();
  30. auto order = order_t.data_ptr<int64_t>();
  31. auto x1 = x1_t.data_ptr<scalar_t>();
  32. auto y1 = y1_t.data_ptr<scalar_t>();
  33. auto x2 = x2_t.data_ptr<scalar_t>();
  34. auto y2 = y2_t.data_ptr<scalar_t>();
  35. auto areas = areas_t.data_ptr<scalar_t>();
  36. int64_t num_to_keep = 0;
  37. for (int64_t _i = 0; _i < ndets; _i++) {
  38. auto i = order[_i];
  39. if (suppressed[i] == 1)
  40. continue;
  41. keep[num_to_keep++] = i;
  42. auto ix1 = x1[i];
  43. auto iy1 = y1[i];
  44. auto ix2 = x2[i];
  45. auto iy2 = y2[i];
  46. auto iarea = areas[i];
  47. for (int64_t _j = _i + 1; _j < ndets; _j++) {
  48. auto j = order[_j];
  49. if (suppressed[j] == 1)
  50. continue;
  51. auto xx1 = std::max(ix1, x1[j]);
  52. auto yy1 = std::max(iy1, y1[j]);
  53. auto xx2 = std::min(ix2, x2[j]);
  54. auto yy2 = std::min(iy2, y2[j]);
  55. auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
  56. auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
  57. auto inter = w * h;
  58. auto ovr = inter / (iarea + areas[j] - inter);
  59. if (ovr > iou_threshold)
  60. suppressed[j] = 1;
  61. }
  62. }
  63. return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
  64. }
  65. at::Tensor nms_kernel(
  66. const at::Tensor& dets,
  67. const at::Tensor& scores,
  68. double iou_threshold) {
  69. TORCH_CHECK(
  70. dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  71. TORCH_CHECK(
  72. dets.size(1) == 4,
  73. "boxes should have 4 elements in dimension 1, got ",
  74. dets.size(1));
  75. TORCH_CHECK(
  76. scores.dim() == 1,
  77. "scores should be a 1d tensor, got ",
  78. scores.dim(),
  79. "D");
  80. TORCH_CHECK(
  81. dets.size(0) == scores.size(0),
  82. "boxes and scores should have same number of elements in ",
  83. "dimension 0, got ",
  84. dets.size(0),
  85. " and ",
  86. scores.size(0));
  87. auto result = at::empty({0}, dets.options());
  88. AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
  89. result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
  90. });
  91. return result;
  92. }
  93. } // namespace
  // Registers nms_kernel as the implementation of the "torchvision::nms"
  // operator for the CPU dispatch key of the torchvision library.
  94. TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  95. m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
  96. }
  97. } // namespace ops
  98. } // namespace vision