qnms_kernel.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #include <ATen/ATen.h>
  2. #include <torch/library.h>
  3. namespace vision {
  4. namespace ops {
  5. namespace {
  6. template <typename scalar_t>
  7. at::Tensor qnms_kernel_impl(
  8. const at::Tensor& dets,
  9. const at::Tensor& scores,
  10. double iou_threshold) {
  11. TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
  12. TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
  13. TORCH_CHECK(
  14. dets.scalar_type() == scores.scalar_type(),
  15. "dets should have the same type as scores");
  16. if (dets.numel() == 0)
  17. return at::empty({0}, dets.options().dtype(at::kLong));
  18. const auto ndets = dets.size(0);
  19. auto x1_t = dets.select(1, 0).contiguous();
  20. auto y1_t = dets.select(1, 1).contiguous();
  21. auto x2_t = dets.select(1, 2).contiguous();
  22. auto y2_t = dets.select(1, 3).contiguous();
  23. auto order_t = std::get<1>(
  24. scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
  25. at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
  26. at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
  27. at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat));
  28. auto suppressed = suppressed_t.data_ptr<uint8_t>();
  29. auto keep = keep_t.data_ptr<int64_t>();
  30. auto order = order_t.data_ptr<int64_t>();
  31. auto x1 = x1_t.data_ptr<scalar_t>();
  32. auto y1 = y1_t.data_ptr<scalar_t>();
  33. auto x2 = x2_t.data_ptr<scalar_t>();
  34. auto y2 = y2_t.data_ptr<scalar_t>();
  35. auto areas = areas_t.data_ptr<float>();
  36. for (int64_t i = 0; i < ndets; i++) {
  37. // Note 1: To get the exact area we'd need to multiply by scale**2, but this
  38. // would get canceled out in the computation of ovr below. So we leave that
  39. // out.
  40. // Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow, although
  41. // integral promotion rules will likely prevent it (see
  42. // https://stackoverflow.com/questions/32959564/subtraction-of-two-unsigned-gives-signed
  43. // for more details).
  44. areas[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_);
  45. }
  46. int64_t num_to_keep = 0;
  47. for (int64_t _i = 0; _i < ndets; _i++) {
  48. auto i = order[_i];
  49. if (suppressed[i] == 1)
  50. continue;
  51. keep[num_to_keep++] = i;
  52. // We explicitly cast coordinates to float so that the code can be
  53. // vectorized.
  54. float ix1val = x1[i].val_;
  55. float iy1val = y1[i].val_;
  56. float ix2val = x2[i].val_;
  57. float iy2val = y2[i].val_;
  58. float iarea = areas[i];
  59. for (int64_t _j = _i + 1; _j < ndets; _j++) {
  60. auto j = order[_j];
  61. if (suppressed[j] == 1)
  62. continue;
  63. float xx1 = std::max(ix1val, (float)x1[j].val_);
  64. float yy1 = std::max(iy1val, (float)y1[j].val_);
  65. float xx2 = std::min(ix2val, (float)x2[j].val_);
  66. float yy2 = std::min(iy2val, (float)y2[j].val_);
  67. auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below)
  68. auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below)
  69. auto inter = w * h;
  70. auto ovr = inter / (iarea + areas[j] - inter);
  71. if (ovr > iou_threshold)
  72. suppressed[j] = 1;
  73. }
  74. }
  75. return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
  76. }
  77. at::Tensor qnms_kernel(
  78. const at::Tensor& dets,
  79. const at::Tensor& scores,
  80. double iou_threshold) {
  81. TORCH_CHECK(
  82. dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
  83. TORCH_CHECK(
  84. dets.size(1) == 4,
  85. "boxes should have 4 elements in dimension 1, got ",
  86. dets.size(1));
  87. TORCH_CHECK(
  88. scores.dim() == 1,
  89. "scores should be a 1d tensor, got ",
  90. scores.dim(),
  91. "D");
  92. TORCH_CHECK(
  93. dets.size(0) == scores.size(0),
  94. "boxes and scores should have same number of elements in ",
  95. "dimension 0, got ",
  96. dets.size(0),
  97. " and ",
  98. scores.size(0));
  99. auto result = at::empty({0});
  100. AT_DISPATCH_QINT_TYPES(dets.scalar_type(), "qnms_kernel", [&] {
  101. result = qnms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
  102. });
  103. return result;
  104. }
  105. } // namespace
  106. TORCH_LIBRARY_IMPL(torchvision, QuantizedCPU, m) {
  107. m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(qnms_kernel));
  108. }
  109. } // namespace ops
  110. } // namespace vision