// nms.cpp (726 B)
  1. #include "nms.h"
  2. #include <ATen/core/dispatch/Dispatcher.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. at::Tensor nms(
  8. const at::Tensor& dets,
  9. const at::Tensor& scores,
  10. double iou_threshold) {
  11. C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms.nms");
  12. static auto op = c10::Dispatcher::singleton()
  13. .findSchemaOrThrow("torchvision::nms", "")
  14. .typed<decltype(nms)>();
  15. return op.call(dets, scores, iou_threshold);
  16. }
  17. TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  18. m.def(TORCH_SELECTIVE_SCHEMA(
  19. "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
  20. }
  21. } // namespace ops
  22. } // namespace vision