im2col.h 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/Parallel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <functional>
  9. namespace at {
  10. namespace native {
  11. template <typename T>
  12. static void im2col(
  13. const T* data_im,
  14. const int64_t channels,
  15. const int64_t height,
  16. const int64_t width,
  17. const int64_t output_height,
  18. const int64_t output_width,
  19. const int64_t kernel_h,
  20. const int64_t kernel_w,
  21. const int64_t pad_h,
  22. const int64_t pad_w,
  23. const int64_t stride_h,
  24. const int64_t stride_w,
  25. const int64_t dilation_h,
  26. const int64_t dilation_w,
  27. T* data_col,
  28. bool is_channels_last = false) {
  29. const int64_t height_col = output_height;
  30. const int64_t width_col = output_width;
  31. const int64_t channels_col = channels * kernel_h * kernel_w;
  32. if (is_channels_last) {
  33. at::parallel_for(0, height_col * width_col, 0, [&](int64_t begin, int64_t end) {
  34. int64_t h_col{0}, w_col{0};
  35. data_index_init(begin, h_col, height_col, w_col, width_col);
  36. for (const auto i_col : c10::irange(begin, end)) {
  37. for (const auto h_offset : c10::irange(kernel_h)) {
  38. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  39. for (const auto w_offset : c10::irange(kernel_w)) {
  40. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  41. const T* slice_im = data_im + (h_im * width + w_im) * channels;
  42. T* slice_col = data_col + (i_col * kernel_h * kernel_w + h_offset * kernel_w + w_offset) * channels;
  43. if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
  44. std::copy_n(slice_im, channels, slice_col);
  45. } else {
  46. std::fill_n(slice_col, channels, T(0));
  47. }
  48. }
  49. }
  50. // move the the next index
  51. data_index_step(h_col, height_col, w_col, width_col);
  52. }
  53. });
  54. } else {
  55. at::parallel_for(0, channels_col, 0, [&](int64_t begin, int64_t end) {
  56. int64_t c_im{0}, h_offset{0}, w_offset{0};
  57. data_index_init(begin, c_im, channels, h_offset, kernel_h, w_offset, kernel_w);
  58. for (const auto c_col : c10::irange(begin, end)) {
  59. for (const auto h_col : c10::irange(height_col)) {
  60. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  61. for (const auto w_col : c10::irange(width_col)) {
  62. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  63. data_col[(c_col * height_col + h_col) * width_col + w_col] =
  64. (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
  65. ? data_im[(c_im * height + h_im) * width + w_im]
  66. : static_cast<T>(0);
  67. }
  68. }
  69. // move to the next index
  70. data_index_step(c_im, channels, h_offset, kernel_h, w_offset, kernel_w);
  71. }
  72. });
  73. }
  74. }
  75. template <typename T>
  76. static void col2im(
  77. const T* data_col,
  78. const int64_t channels,
  79. const int64_t height,
  80. const int64_t width,
  81. const int64_t output_height,
  82. const int64_t output_width,
  83. const int64_t kernel_h,
  84. const int64_t kernel_w,
  85. const int64_t pad_h,
  86. const int64_t pad_w,
  87. const int64_t stride_h,
  88. const int64_t stride_w,
  89. const int64_t dilation_h,
  90. const int64_t dilation_w,
  91. T* data_im,
  92. bool is_channels_last = false) {
  93. std::fill_n(data_im, height * width * channels, T(0));
  94. const int64_t height_col = output_height;
  95. const int64_t width_col = output_width;
  96. const int64_t channels_col = channels * kernel_h * kernel_w;
  97. if (is_channels_last) {
  98. for (const auto h_col : c10::irange(height_col)) {
  99. for (const auto w_col : c10::irange(width_col)) {
  100. for (const auto h_offset : c10::irange(kernel_h)) {
  101. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  102. for (const auto w_offset : c10::irange(kernel_w)) {
  103. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  104. T* slice_im = data_im + (h_im * width + w_im) * channels;
  105. const T* slice_col = data_col + ((h_col * width_col + w_col) * kernel_h * kernel_w
  106. + h_offset * kernel_w + w_offset) * channels;
  107. if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) {
  108. std::transform(slice_col, slice_col + channels, slice_im, slice_im, std::plus<T>());
  109. }
  110. }
  111. }
  112. }
  113. }
  114. } else {
  115. for (const auto c_col : c10::irange(channels_col)) {
  116. int64_t w_offset = c_col % kernel_w;
  117. int64_t h_offset = (c_col / kernel_w) % kernel_h;
  118. int64_t c_im = c_col / kernel_h / kernel_w;
  119. for (const auto h_col : c10::irange(height_col)) {
  120. int64_t h_im = h_col * stride_h - pad_h + h_offset * dilation_h;
  121. for (const auto w_col : c10::irange(width_col)) {
  122. int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
  123. if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width)
  124. data_im[(c_im * height + h_im) * width + w_im] +=
  125. data_col[(c_col * height_col + h_col) * width_col + w_col];
  126. }
  127. }
  128. }
  129. }
  130. }
  131. } // native
  132. } // at