12345678910111213141516171819202122232425262728293031323334353637383940 |
- #include "../roi_pool.h"
- #include <ATen/autocast_mode.h>
- #include <torch/library.h>
- #include <torch/types.h>
- namespace vision {
- namespace ops {
- namespace {
- std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
- const at::Tensor& input,
- const at::Tensor& rois,
- double spatial_scale,
- int64_t pooled_height,
- int64_t pooled_width) {
- c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
- auto result = roi_pool(
- at::autocast::cached_cast(at::kFloat, input),
- at::autocast::cached_cast(at::kFloat, rois),
- spatial_scale,
- pooled_height,
- pooled_width);
- return std::make_tuple(
- std::get<0>(result).to(input.scalar_type()),
- std::get<1>(result).to(input.scalar_type()));
- }
- } // namespace
- TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
- m.impl(
- TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
- TORCH_FN(roi_pool_autocast));
- }
- } // namespace ops
- } // namespace vision
|