// GridSampler.h
#pragma once
#include <array>
#include <cstdint>
  4. namespace at {
  5. class TensorBase;
  6. }
  7. namespace at {
  8. namespace native {
  9. void launch_grid_sampler_2d_forward_kernel(
  10. const TensorBase &output, const TensorBase &input, const TensorBase &grid,
  11. int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
  12. void launch_grid_sampler_3d_forward_kernel(
  13. const TensorBase &output, const TensorBase &input, const TensorBase &grid,
  14. int64_t interpolation_mode, int64_t padding_mode, bool align_corners);
  15. void launch_grid_sampler_2d_backward_kernel(
  16. const TensorBase &grad_input, const TensorBase &grad_grid,
  17. const TensorBase &grad_output, const TensorBase &input,
  18. const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
  19. bool align_corners, std::array<bool, 2> output_mask);
  20. void launch_grid_sampler_3d_backward_kernel(
  21. const TensorBase &grad_input, const TensorBase &grad_grid,
  22. const TensorBase &grad_output, const TensorBase &input,
  23. const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
  24. bool align_corners, std::array<bool, 2> output_mask);
  25. }} // namespace at::native