vol2col.cuh

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <c10/macros/Macros.h>

// Standard headers for std::min and std::numeric_limits used below.
#include <algorithm>
#include <limits>

namespace at {
namespace native {

using namespace at::cuda::detail;

// Kernel for fast unfold+copy on volumes
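// Each thread handles one output location (channel_in, t_out, h_out, w_out)
// and copies the corresponding ksize_t * ksize_h * ksize_w input window
// (zero-filled where it falls outside the padded volume) into the column
// buffer. data_col is laid out as a row-major
// [channels * ksize_t * ksize_h * ksize_w, depth_col * height_col * width_col]
// matrix, so successive window elements written by one thread are
// depth_col * height_col * width_col entries apart.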
template <typename T>
__global__ void vol2col_kernel(
    const int64_t n,
    const T* data_vol,
    const int depth,
    const int height,
    const int width,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    const int depth_col,
    const int height_col,
    const int width_col,
    T* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    auto w_out = index % width_col;
    index /= width_col;
    auto h_out = index % height_col;
    index /= height_col;
    auto t_out = index % depth_col;
    auto channel_in = index / depth_col;
    auto channel_out = channel_in * ksize_t * ksize_h * ksize_w;
    auto t_in = t_out * stride_t - pad_t;
    auto h_in = h_out * stride_h - pad_h;
    auto w_in = w_out * stride_w - pad_w;
    data_col +=
        ((channel_out * depth_col + t_out) * height_col + h_out) * width_col +
        w_out;
    data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in;
    for (int i = 0; i < ksize_t; ++i) {
      for (int j = 0; j < ksize_h; ++j) {
        for (int k = 0; k < ksize_w; ++k) {
          auto t = t_in + i * dilation_t;
          auto h = h_in + j * dilation_h;
          auto w = w_in + k * dilation_w;
          *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height &&
                       w < width)
              ? data_vol
                    [i * dilation_t * height * width + j * dilation_h * width +
                     k * dilation_w]
              : static_cast<T>(0);
          data_col += depth_col * height_col * width_col;
        }
      }
    }
  }
}
template <typename T>
void vol2col(
    cudaStream_t stream,
    const T* data_vol,
    const int channels,
    const int depth,
    const int height,
    const int width,
    const int depth_col,
    const int height_col,
    const int width_col,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    T* data_col) {
  // We are going to launch channels * depth_col * height_col * width_col
  // threads, each thread responsible for copying one kernel window of a
  // single-channel grid into the column buffer.
  // We cast an operand to int64_t so that the product does not overflow.
  const auto num_kernels = static_cast<int64_t>(channels) * depth_col * height_col * width_col;
  // Launch
  vol2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
      num_kernels,
      data_vol,
      depth,
      height,
      width,
      ksize_t,
      ksize_h,
      ksize_w,
      pad_t,
      pad_h,
      pad_w,
      stride_t,
      stride_h,
      stride_w,
      dilation_t,
      dilation_h,
      dilation_w,
      depth_col,
      height_col,
      width_col,
      data_col);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
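
// Illustrative usage sketch (not part of the original header): how a caller
// might size the column buffer and invoke vol2col. The names input_ptr and
// columns_ptr and the float instantiation are hypothetical; the output-extent
// formula is the usual one for a padded, strided, dilated convolution.
//
//   const int depth_col =
//       (depth + 2 * pad_t - dilation_t * (ksize_t - 1) - 1) / stride_t + 1;
//   const int height_col =
//       (height + 2 * pad_h - dilation_h * (ksize_h - 1) - 1) / stride_h + 1;
//   const int width_col =
//       (width + 2 * pad_w - dilation_w * (ksize_w - 1) - 1) / stride_w + 1;
//   // columns_ptr must hold channels * ksize_t * ksize_h * ksize_w rows of
//   // depth_col * height_col * width_col elements each.
//   vol2col<float>(
//       at::cuda::getCurrentCUDAStream(),
//       input_ptr, channels, depth, height, width,
//       depth_col, height_col, width_col,
//       ksize_t, ksize_h, ksize_w,
//       pad_t, pad_h, pad_w,
//       stride_t, stride_h, stride_w,
//       dilation_t, dilation_h, dilation_w,
//       columns_ptr);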
template <typename T, typename accT>
__global__ void vol2im_kernel(
    const int64_t n,
    const T* data_col,
    const unsigned depth,
    const unsigned height,
    const unsigned width,
    const unsigned channels,
    const unsigned kernel_t,
    const unsigned kernel_h,
    const unsigned kernel_w,
    const unsigned pad_t,
    const unsigned pad_h,
    const unsigned pad_w,
    const unsigned stride_t,
    const unsigned stride_h,
    const unsigned stride_w,
    const unsigned dilation_t,
    const unsigned dilation_h,
    const unsigned dilation_w,
    const unsigned depth_col,
    const unsigned height_col,
    const unsigned width_col,
    T* data_vol) {
  CUDA_KERNEL_LOOP(index, n) {
    accT val = static_cast<accT>(0);
    const auto w_im = index % width + pad_w;
    const auto h_im = (index / width) % height + pad_h;
    const auto t_im = (index / width / height) % depth + pad_t;
    const auto c_im = index / (width * height * depth);
    auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
    auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
    auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1;
    // compute the start and end of the output
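    // An input element at (t_im, h_im, w_im) (already shifted by pad_* above)
    // is read by output column w_col along the width axis only if
    // 0 <= w_im - w_col * stride_w < kernel_extent_w, i.e. the element falls
    // inside that column's dilated kernel window. Solving for w_col gives the
    // half-open range [w_col_start, w_col_end) below; the height and depth
    // axes are handled the same way. Offsets that are not multiples of the
    // dilation are filtered out inside the loop.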
    const auto w_col_start =
        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
    const auto w_col_end = std::min(w_im / stride_w + 1, width_col);
    const auto h_col_start =
        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
    const auto h_col_end = std::min(h_im / stride_h + 1, height_col);
    const auto t_col_start =
        (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1;
    const auto t_col_end = std::min(t_im / stride_t + 1, depth_col);
    // TODO: use LCM of stride and dilation to avoid unnecessary loops
    for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) {
      for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) {
        for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) {
          uint64_t t_k = (t_im - t_col * stride_t);
          uint64_t h_k = (h_im - h_col * stride_h);
          uint64_t w_k = (w_im - w_col * stride_w);
          if (t_k % dilation_t == 0 && h_k % dilation_h == 0 &&
              w_k % dilation_w == 0) {
            t_k /= dilation_t;
            h_k /= dilation_h;
            w_k /= dilation_w;
            const int64_t idx_k =
                ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k;
            const int64_t data_col_index =
                ((idx_k * depth_col + t_col) *
                  height_col + h_col) *
                width_col + w_col;
            val += data_col[data_col_index];
          }
        }
      }
    }
    data_vol[index] = static_cast<T>(val);
  }
}
template <typename T, typename accT>
void col2vol(
    cudaStream_t stream,
    const T* data_col,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width,
    const int64_t output_depth,
    const int64_t output_height,
    const int64_t output_width,
    const int64_t patch_t,
    const int64_t patch_h,
    const int64_t patch_w,
    const int64_t pad_t,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t stride_t,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t dilation_t,
    const int64_t dilation_h,
    const int64_t dilation_w,
    T* data_vol) {
  const auto num_kernels = channels * depth * height * width;
  auto check_fits_in_unsigned =
      [](int64_t val, const char* name) {
        constexpr auto umax = std::numeric_limits<unsigned>::max();
        TORCH_CHECK(val >= 0 && val <= umax,
                    name, " must fit in a 32-bit unsigned value");
      };
  check_fits_in_unsigned(num_kernels, "input size");
  check_fits_in_unsigned(
      channels * patch_t * patch_h * patch_w, "channels x kernel size");
  // To avoid atomic operations, we launch one thread per input (bottom)
  // element; each thread then accumulates, inside the kernel, every column
  // (top) entry that maps back to its element.
  vol2im_kernel<T, accT>
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels,
          data_col,
          depth,
          height,
          width,
          channels,
          patch_t,
          patch_h,
          patch_w,
          pad_t,
          pad_h,
          pad_w,
          stride_t,
          stride_h,
          stride_w,
          dilation_t,
          dilation_h,
          dilation_w,
          output_depth,
          output_height,
          output_width,
          data_vol);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
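
// Illustrative usage sketch (not part of the original header): folding a
// column buffer back into a volume, e.g. to invert a previous vol2col call.
// grad_columns_ptr, grad_input_ptr, and the at::Half/float instantiation are
// hypothetical; accT is the type the kernel accumulates in, so callers
// typically choose a type at least as wide as T.
//
//   col2vol<at::Half, float>(
//       at::cuda::getCurrentCUDAStream(),
//       grad_columns_ptr, channels, depth, height, width,
//       depth_col, height_col, width_col,
//       ksize_t, ksize_h, ksize_w,
//       pad_t, pad_h, pad_w,
//       stride_t, stride_h, stride_w,
//       dilation_t, dilation_h, dilation_w,
//       grad_input_ptr);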
} // namespace native
} // namespace at