roi_pool_kernel.cpp 961 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #include "../roi_pool.h"
  2. #include <ATen/autocast_mode.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. namespace {
  8. std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
  9. const at::Tensor& input,
  10. const at::Tensor& rois,
  11. double spatial_scale,
  12. int64_t pooled_height,
  13. int64_t pooled_width) {
  14. c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  15. auto result = roi_pool(
  16. at::autocast::cached_cast(at::kFloat, input),
  17. at::autocast::cached_cast(at::kFloat, rois),
  18. spatial_scale,
  19. pooled_height,
  20. pooled_width);
  21. return std::make_tuple(
  22. std::get<0>(result).to(input.scalar_type()),
  23. std::get<1>(result).to(input.scalar_type()));
  24. }
  25. } // namespace
  26. TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  27. m.impl(
  28. TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
  29. TORCH_FN(roi_pool_autocast));
  30. }
  31. } // namespace ops
  32. } // namespace vision