ps_roi_align.h 731 B

1234567891011121314151617181920212223242526272829303132333435
  1. #pragma once
  2. #include <ATen/ATen.h>
  3. #include "../macros.h"
  4. namespace vision {
  5. namespace ops {
  6. VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
  7. const at::Tensor& input,
  8. const at::Tensor& rois,
  9. double spatial_scale,
  10. int64_t pooled_height,
  11. int64_t pooled_width,
  12. int64_t sampling_ratio);
  13. namespace detail {
  14. at::Tensor _ps_roi_align_backward(
  15. const at::Tensor& grad,
  16. const at::Tensor& rois,
  17. const at::Tensor& channel_mapping,
  18. double spatial_scale,
  19. int64_t pooled_height,
  20. int64_t pooled_width,
  21. int64_t sampling_ratio,
  22. int64_t batch_size,
  23. int64_t channels,
  24. int64_t height,
  25. int64_t width);
  26. } // namespace detail
  27. } // namespace ops
  28. } // namespace vision