// im2col.cuh
#pragma once
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/macros/Macros.h>

namespace at {
namespace native {

using namespace at::cuda::detail;

// Kernel for fast unfold+copy
// (borrowed from Caffe:
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// CUDA_NUM_THREADS = 1024

template <typename dt>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void im2col_kernel(
    const int64_t n,
    const dt* data_im,
    const int64_t height,
    const int64_t width,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the linear index into (channel_in, h_out, w_out).
    int64_t w_out = index % width_col;
    int64_t idx = index / width_col;
    int64_t h_out = idx % height_col;
    int64_t channel_in = idx / height_col;
    int64_t channel_out = channel_in * kernel_height * kernel_width;
    int64_t h_in = h_out * stride_height - pad_height;
    int64_t w_in = w_out * stride_width - pad_width;

    dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
    const dt* im = data_im + (channel_in * height + h_in) * width + w_in;

    // Each thread copies one kernel_height * kernel_width patch, writing
    // zeros for locations that fall into the padding.
    for (int64_t i = 0; i < kernel_height; ++i) {
      for (int64_t j = 0; j < kernel_width; ++j) {
        int64_t h = h_in + i * dilation_height;
        int64_t w = w_in + j * dilation_width;
        *col = (h >= 0 && w >= 0 && h < height && w < width)
            ? im[i * dilation_height * width + j * dilation_width]
            : static_cast<dt>(0);
        col += height_col * width_col;
      }
    }
  }
}
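// Both the kernel above and the host wrappers below take height_col /
// width_col as inputs and assume they were derived from the standard
// convolution output-size formula. The helper below is only an illustrative
// sketch (not part of the original header; the name is hypothetical) of that
// calculation, used by the usage sketches further down.
inline int64_t im2col_output_size(
    const int64_t input_size,
    const int64_t kernel_size,
    const int64_t pad,
    const int64_t stride,
    const int64_t dilation) {
  // (input + 2 * pad - effective kernel extent) / stride, then + 1
  return (input_size + 2 * pad - (dilation * (kernel_size - 1) + 1)) / stride + 1;
}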
template <typename dt>
void im2col(
    cudaStream_t stream,
    const dt* data_im,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t kernel_height,
    const int64_t kernel_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_col) {
  // We are going to launch channels * height_col * width_col threads, each
  // thread responsible for copying a single patch of the input.
  int64_t num_kernels = channels * height_col * width_col;
  // Launch with CUDA_NUM_THREADS = 1024 threads per block.
  im2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>(
      num_kernels,
      data_im,
      height,
      width,
      kernel_height,
      kernel_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width,
      dilation_height,
      dilation_width,
      height_col,
      width_col,
      data_col);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
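// Usage sketch (illustrative only, not part of the original header; buffer
// names, sizes, and the example function are hypothetical): unfold a
// contiguous [3, 32, 32] float image into a [3 * 3 * 3, height_col * width_col]
// column buffer on a given stream, assuming both device buffers were
// allocated with the right number of elements.
inline void im2col_example(
    cudaStream_t stream,
    const float* d_image,   // device buffer of 3 * 32 * 32 floats
    float* d_columns) {     // device buffer of (3 * 3 * 3) * (32 * 32) floats
  const int64_t channels = 3, height = 32, width = 32;
  const int64_t kernel_h = 3, kernel_w = 3;
  const int64_t pad = 1, stride = 1, dilation = 1;
  const int64_t height_col =
      im2col_output_size(height, kernel_h, pad, stride, dilation); // 32
  const int64_t width_col =
      im2col_output_size(width, kernel_w, pad, stride, dilation);  // 32
  im2col<float>(
      stream,
      d_image,
      channels,
      height,
      width,
      height_col,
      width_col,
      kernel_h,
      kernel_w,
      pad,
      pad,
      stride,
      stride,
      dilation,
      dilation,
      d_columns);
}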
template <typename accT, typename dt>
__forceinline__ __device__ void col2im_device(
    const int64_t index,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
  accT val = static_cast<accT>(0);
  const int64_t w_im = index % width + pad_width;
  const int64_t h_im = (index / width) % height + pad_height;
  const int64_t c_im = index / (width * height);
  int64_t kernel_extent_w = (kernel_w - 1) * dilation_width + 1;
  int64_t kernel_extent_h = (kernel_h - 1) * dilation_height + 1;
  // compute the start and end of the output
  const int64_t w_col_start = (w_im < kernel_extent_w)
      ? 0
      : (w_im - kernel_extent_w) / stride_width + 1;
  const int64_t w_col_end = ::min(w_im / stride_width + 1, width_col);
  const int64_t h_col_start = (h_im < kernel_extent_h)
      ? 0
      : (h_im - kernel_extent_h) / stride_height + 1;
  const int64_t h_col_end = ::min(h_im / stride_height + 1, height_col);
  // TODO: use LCM of stride and dilation to avoid unnecessary loops
  // Gather: accumulate every column entry whose receptive field covers this
  // image pixel.
  for (int64_t h_col = h_col_start; h_col < h_col_end; h_col += 1) {
    for (int64_t w_col = w_col_start; w_col < w_col_end; w_col += 1) {
      int64_t h_k = (h_im - h_col * stride_height);
      int64_t w_k = (w_im - w_col * stride_width);
      if (h_k % dilation_height == 0 && w_k % dilation_width == 0) {
        h_k /= dilation_height;
        w_k /= dilation_width;
        int64_t data_col_index =
            (((c_im * kernel_h + h_k) * kernel_w + w_k) * height_col +
             h_col) *
                width_col +
            w_col;
        val += data_col[data_col_index];
      }
    }
  }
  data_im[index] = static_cast<dt>(val);
}
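// For reference, an illustrative CPU sketch (not part of the original header;
// the function name is hypothetical) of the equivalent scatter-style col2im:
// every column entry is added back to the image pixel it was sampled from.
// col2im_device above computes the same sums, but as a per-pixel gather so
// that no atomics are needed on the GPU, and it accumulates in accT.
template <typename T>
void col2im_cpu_reference(
    const T* data_col,   // [channels * kernel_h * kernel_w, height_col * width_col]
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    T* data_im) {        // [channels, height, width]
  for (int64_t i = 0; i < channels * height * width; ++i) {
    data_im[i] = static_cast<T>(0);
  }
  for (int64_t c_col = 0; c_col < channels * kernel_h * kernel_w; ++c_col) {
    const int64_t w_offset = c_col % kernel_w;
    const int64_t h_offset = (c_col / kernel_w) % kernel_h;
    const int64_t c_im = c_col / (kernel_h * kernel_w);
    for (int64_t h_col = 0; h_col < height_col; ++h_col) {
      const int64_t h_im =
          h_col * stride_height - pad_height + h_offset * dilation_height;
      for (int64_t w_col = 0; w_col < width_col; ++w_col) {
        const int64_t w_im =
            w_col * stride_width - pad_width + w_offset * dilation_width;
        if (h_im >= 0 && h_im < height && w_im >= 0 && w_im < width) {
          data_im[(c_im * height + h_im) * width + w_im] +=
              data_col[(c_col * height_col + h_col) * width_col + w_col];
        }
      }
    }
  }
}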
template <typename dt, typename accT>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im) {
  CUDA_KERNEL_LOOP(index, n) {
    col2im_device<accT>(
        index,
        data_col,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im);
  }
}
template <typename dt, typename accT>
void col2im(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im) {
  int64_t num_kernels = channels * height * width;
  // To avoid involving atomic operations, we will launch one thread per
  // bottom (image) element, and then in the kernel add up the top (column)
  // contributions. Launch with 512 threads per block to match
  // C10_LAUNCH_BOUNDS_1(512) on col2im_kernel.
  col2im_kernel<dt, accT>
      <<<GET_BLOCKS(num_kernels, 512), 512, 0, stream>>>(
          num_kernels,
          data_col,
          height,
          width,
          channels,
          patch_height,
          patch_width,
          pad_height,
          pad_width,
          stride_height,
          stride_width,
          dilation_height,
          dilation_width,
          height_col,
          width_col,
          data_im);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
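// Usage sketch (illustrative only, not part of the original header; buffer
// names, sizes, and the example function are hypothetical): fold the column
// buffer produced for a [3, 32, 32] image back into image layout, assuming
// valid device buffers. For float inputs the accumulation type is float;
// for reduced-precision inputs one would typically pass a wider accT.
inline void col2im_example(
    cudaStream_t stream,
    const float* d_columns, // device buffer of (3 * 3 * 3) * (32 * 32) floats
    float* d_image) {       // device buffer of 3 * 32 * 32 floats
  const int64_t channels = 3, height = 32, width = 32;
  const int64_t kernel_h = 3, kernel_w = 3;
  const int64_t pad = 1, stride = 1, dilation = 1;
  const int64_t height_col =
      im2col_output_size(height, kernel_h, pad, stride, dilation);
  const int64_t width_col =
      im2col_output_size(width, kernel_w, pad, stride, dilation);
  col2im<float, /*accT=*/float>(
      stream,
      d_columns,
      channels,
      height,
      width,
      height_col,
      width_col,
      kernel_h,
      kernel_w,
      pad,
      pad,
      stride,
      stride,
      dilation,
      dilation,
      d_image);
}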
template <typename dt>
C10_LAUNCH_BOUNDS_1(512)
__global__ void col2im_batched_kernel(
    const int64_t n,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t height,
    const int64_t width,
    const int64_t channels,
    const int64_t kernel_h,
    const int64_t kernel_w,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    const int64_t height_col,
    const int64_t width_col,
    dt* data_im,
    const int64_t im_batch_stride) {
  using accT = at::acc_type<dt, /*is_cuda*/true>;
  const auto im_numel = n * nbatch;

  CUDA_KERNEL_LOOP_TYPE(index, im_numel, int64_t) {
    const auto ibatch = index / n;
    const auto slice_index = index % n;

    col2im_device<accT>(
        slice_index,
        data_col + ibatch * col_batch_stride,
        height,
        width,
        channels,
        kernel_h,
        kernel_w,
        pad_height,
        pad_width,
        stride_height,
        stride_width,
        dilation_height,
        dilation_width,
        height_col,
        width_col,
        data_im + ibatch * im_batch_stride);
  }
}
template <typename dt>
void col2im_batched(
    cudaStream_t stream,
    const dt* data_col,
    const int64_t col_batch_stride,
    const int64_t nbatch,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t height_col,
    const int64_t width_col,
    const int64_t patch_height,
    const int64_t patch_width,
    const int64_t pad_height,
    const int64_t pad_width,
    const int64_t stride_height,
    const int64_t stride_width,
    const int64_t dilation_height,
    const int64_t dilation_width,
    dt* data_im,
    const int64_t im_batch_stride) {
  const int64_t num_kernels = channels * height * width;
  const int64_t output_numel = nbatch * num_kernels;
  if (output_numel == 0) {
    return; // No work to do
  }

  // To avoid involving atomic operations, we will launch one thread per
  // bottom (image) element, and then in the kernel add up the top (column)
  // contributions. Launch with 512 threads per block to match
  // C10_LAUNCH_BOUNDS_1(512) on col2im_batched_kernel.
  col2im_batched_kernel<<<GET_BLOCKS(output_numel, 512), 512, 0, stream>>>(
      num_kernels,
      data_col,
      col_batch_stride,
      nbatch,
      height,
      width,
      channels,
      patch_height,
      patch_width,
      pad_height,
      pad_width,
      stride_height,
      stride_width,
      dilation_height,
      dilation_width,
      height_col,
      width_col,
      data_im,
      im_batch_stride);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
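// Usage sketch (illustrative only, not part of the original header; buffer
// names, sizes, and the example function are hypothetical): run the batched
// variant over a contiguous batch, where the per-sample strides are simply
// the per-sample element counts of the column and image buffers.
inline void col2im_batched_example(
    cudaStream_t stream,
    const float* d_columns, // device buffer of nbatch * (3 * 3 * 3) * (32 * 32) floats
    float* d_images) {      // device buffer of nbatch * 3 * 32 * 32 floats
  const int64_t nbatch = 4;
  const int64_t channels = 3, height = 32, width = 32;
  const int64_t kernel_h = 3, kernel_w = 3;
  const int64_t pad = 1, stride = 1, dilation = 1;
  const int64_t height_col =
      im2col_output_size(height, kernel_h, pad, stride, dilation);
  const int64_t width_col =
      im2col_output_size(width, kernel_w, pad, stride, dilation);
  // Contiguous layout: each sample's columns and image directly follow the
  // previous sample's, so the batch strides equal the per-sample sizes.
  const int64_t col_batch_stride =
      channels * kernel_h * kernel_w * height_col * width_col;
  const int64_t im_batch_stride = channels * height * width;
  col2im_batched<float>(
      stream,
      d_columns,
      col_batch_stride,
      nbatch,
      channels,
      height,
      width,
      height_col,
      width_col,
      kernel_h,
      kernel_w,
      pad,
      pad,
      stride,
      stride,
      dilation,
      dilation,
      d_images,
      im_batch_stride);
}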
} // namespace native
} // namespace at