roi_align_kernel.cpp 968 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #include "../roi_align.h"
  2. #include <ATen/autocast_mode.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. namespace {
  8. at::Tensor roi_align_autocast(
  9. const at::Tensor& input,
  10. const at::Tensor& rois,
  11. double spatial_scale,
  12. int64_t pooled_height,
  13. int64_t pooled_width,
  14. int64_t sampling_ratio,
  15. bool aligned) {
  16. c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  17. return roi_align(
  18. at::autocast::cached_cast(at::kFloat, input),
  19. at::autocast::cached_cast(at::kFloat, rois),
  20. spatial_scale,
  21. pooled_height,
  22. pooled_width,
  23. sampling_ratio,
  24. aligned)
  25. .to(input.scalar_type());
  26. }
  27. } // namespace
  28. TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  29. m.impl(
  30. TORCH_SELECTIVE_NAME("torchvision::roi_align"),
  31. TORCH_FN(roi_align_autocast));
  32. }
  33. } // namespace ops
  34. } // namespace vision