UnfoldBackward.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/NonEmptyUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#endif

namespace at { namespace native {

using unfold_backward_fn = void (*)(
  Tensor& grad_in,
  const Tensor& grad,
  int64_t dim,
  int64_t size,
  int64_t step
);

DECLARE_DISPATCH(unfold_backward_fn, unfold_backward_stub);
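
// Usage sketch (illustrative only; `grad`, `input_sizes`, `dim`, `size` and
// `step` are hypothetical caller-side names, not defined in this header): a
// backward implementation would typically allocate the result with the
// original input sizes and then invoke the dispatched kernel, e.g.
//
//   auto grad_input = at::zeros(input_sizes, grad.options());
//   unfold_backward_stub(
//     grad.device().type(), grad_input, grad, dim, size, step
//   );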

namespace {

// Note on naming: the names here are unconventional.
// grad_in does not mean that it is a gradient with respect to the input;
// grad_in/grad_out are simply the input/output of the unfold_backward kernel.
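//
// Illustrative shapes (hypothetical sizes): for an input of sizes [7] unfolded
// with dim = 0, size = 3, step = 2, the forward output has
// (7 - 3) / 2 + 1 = 3 folds of length 3, so grad_in has sizes [3, 3], while
// grad_out (the tensor being accumulated into) has the original sizes [7].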
static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
  Tensor& grad_out,
  const Tensor& grad_in,
  int64_t dim,
  int64_t size,
  int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // the last dim of grad_in stores the folds

  auto grad_out_dim_size = ensure_nonempty_size(grad_out, dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);

  // dictates the number of elements to iterate over
  // in dimension `dim`
  auto iter_dim_size = std::min(
    grad_out_dim_size,
    (grad_in_dim_size - 1) * step + size
  );
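
  // Worked example (hypothetical numbers): with grad_out_dim_size = 7,
  // size = 3, step = 2, grad_in holds (7 - 3) / 2 + 1 = 3 folds along `dim`,
  // so the folds reach up to index (3 - 1) * 2 + 3 - 1 = 6 of grad_out and
  // iter_dim_size = min(7, (3 - 1) * 2 + 3) = 7.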

  /* prepare grad_out for TensorIterator { */
  auto grad_out_strides = ensure_nonempty_vec(grad_out.strides().vec());
  auto grad_out_sizes = ensure_nonempty_vec(grad_out.sizes().vec());

  grad_out_sizes[dim] = iter_dim_size;
  auto grad_out_restrided = grad_out.as_strided(
    grad_out_sizes, grad_out_strides
  );
  /* } */

  /* prepare grad_in for TensorIterator { */
  auto grad_in_strides = ensure_nonempty_vec(grad_in.strides().vec());
  auto grad_in_sizes = ensure_nonempty_vec(grad_in.sizes().vec());

  // set the stride for `dim` to 0 and its size to 1,
  // because this dimension is indexed inside the kernel
  grad_in_strides[dim] = 0;
  grad_in_sizes[dim] = 1;

  grad_in_strides.pop_back();
  grad_in_sizes.pop_back();

  auto grad_in_restrided = grad_in.squeeze(-1).as_strided(
    grad_in_sizes, grad_in_strides
  );
  /* } */
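
  // Shape sketch (hypothetical sizes): for grad_out of sizes [4, 7] with
  // dim = 1, size = 3, step = 2, grad_in has sizes [4, 3, 3]. The view built
  // above has sizes [4, 1]: `dim` is collapsed with a zero stride and the
  // trailing fold dimension is dropped, since the kernel addresses both of
  // them itself from the original strides of grad_in.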

  // During the TensorIterator iteration we have to know
  // i_dim in grad_out[i_1,...,i_dim,...,i_n];
  // idx_dim stores this information.
  /* prepare idx_dim for TensorIterator { */
  auto idx_dim = at::arange(
    0, iter_dim_size, grad_in.options().dtype(at::kLong)
  );

  auto grad_out_dim = ensure_nonempty_dim(grad_out.dim());

  auto idx_dim_strides = std::vector<int64_t>(grad_out_dim, 0);
  auto idx_dim_sizes = std::vector<int64_t>(grad_out_dim, 1);

  idx_dim_strides[dim] = 1;
  idx_dim_sizes[dim] = iter_dim_size;

  // idx_dim will broadcast over the sizes determined by grad_out in TensorIterator
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  /* } */
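
  // Broadcast sketch (hypothetical sizes): with grad_out of sizes [4, 7] and
  // dim = 1, idx_dim is arange(0, 7) viewed with sizes [1, 7] and strides
  // [0, 1], so each element visited along `dim` sees its own index i_dim
  // while all other dimensions broadcast.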

  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_owned_output(grad_out_restrided)
    .add_owned_input(grad_in_restrided)
    .add_owned_input(idx_dim_restrided)
    .build();

  return iter;
}
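
// Consumption sketch (simplified, hypothetical pseudocode; the real kernels
// behind unfold_backward_stub are more involved and additionally receive the
// original grad_in strides along `dim` and along the fold dimension, written
// here as grad_in_dim_stride and grad_in_last_dim_stride). For each iterator
// element (grad_out_val, grad_in_base, i_dim), a kernel accumulates every
// fold that covers position i_dim:
//
//   for (int64_t fold = 0; fold < grad_in_dim_size; ++fold) {
//     auto offset = i_dim - fold * step;  // position of i_dim inside this fold
//     if (offset >= 0 && offset < size) {
//       grad_out_val += grad_in_base[fold * grad_in_dim_stride
//                                    + offset * grad_in_last_dim_stride];
//     }
//   }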

} // anonymous namespace

}} // namespace at::native