// ps_roi_pool.h (673 B)

#pragma once

#include <cstdint>
#include <tuple>

#include <ATen/ATen.h>

#include "../macros.h"
  4. namespace vision {
  5. namespace ops {
  6. VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
  7. const at::Tensor& input,
  8. const at::Tensor& rois,
  9. double spatial_scale,
  10. int64_t pooled_height,
  11. int64_t pooled_width);
  12. namespace detail {
  13. at::Tensor _ps_roi_pool_backward(
  14. const at::Tensor& grad,
  15. const at::Tensor& rois,
  16. const at::Tensor& channel_mapping,
  17. double spatial_scale,
  18. int64_t pooled_height,
  19. int64_t pooled_width,
  20. int64_t batch_size,
  21. int64_t channels,
  22. int64_t height,
  23. int64_t width);
  24. } // namespace detail
  25. } // namespace ops
  26. } // namespace vision