nms_kernel.cpp 670 B

123456789101112131415161718192021222324252627282930
  1. #include "../nms.h"
  2. #include <ATen/autocast_mode.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. namespace {
  8. at::Tensor nms_autocast(
  9. const at::Tensor& dets,
  10. const at::Tensor& scores,
  11. double iou_threshold) {
  12. c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  13. return nms(
  14. at::autocast::cached_cast(at::kFloat, dets),
  15. at::autocast::cached_cast(at::kFloat, scores),
  16. iou_threshold);
  17. }
  18. } // namespace
  19. TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  20. m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
  21. }
  22. } // namespace ops
  23. } // namespace vision