// MultiTensorApply.cuh

#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

namespace at { namespace native {

namespace {
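
// Tuning constants: kILP is the per-thread vector width used by the aligned
// load/store helpers below, kChunkSize is the number of elements each block
// processes per (tensor, chunk) assignment, and kBlockSize is the thread count
// per block for multi_tensor_apply_kernel.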
static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;
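
// Helpers for vectorized memory access: is_aligned checks whether a pointer can
// be read or written kILP elements at a time, and load_store copies one such
// aligned_vector (kILP contiguous elements) between two buffers.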
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (kILP * sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  using LT = at::native::memory::aligned_vector<T, kILP>;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
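
// The metadata structs below are passed to the kernel by value. They map each
// launched block to a (tensor, chunk) pair: block_to_tensor/block_to_chunk give
// the assignment for blockIdx.x, addresses/numel_for_tensor describe the tensors
// in this launch, and the tables above cap how many tensors and blocks fit under
// the 4KB argument limit for a given depth (number of tensor lists).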
template<int n> struct TensorListMetadata
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int numel_for_tensor[depth_to_max_tensors[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
  int start_tensor_this_launch;
};

// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`,
// each element of which is a 1-element `at::Tensor` tracking the number of `step`s taken so far.
template<int n> struct FusedOptimizerTensorListMetadata
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int numel_for_tensor[depth_to_max_tensors[n-1]];
  void* state_steps_addresses[depth_to_max_tensors_scalarlist[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
  int start_tensor_this_launch;
};

template<typename scalar_vals_t, int n> struct TensorListScalarListMetadata
{
  void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
  int numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]];
  scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
};

// note(mkozuki): 96 tensors (the depth-1 default above) with a `scalar_vals_t` of
// `c10::complex<double>` would exceed the 4KB CUDA kernel argument size limit;
// 80 tensors per launch stays within it.
template<> struct TensorListScalarListMetadata<c10::complex<double>, 1>
{
  void* addresses[1][80];
  int numel_for_tensor[80];
  c10::complex<double> scalar_vals[80];
  unsigned char block_to_tensor[depth_to_max_blocks[1-1]];
  int block_to_chunk[depth_to_max_blocks[1-1]];
};
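
// Illustrative only (not part of the upstream header): a sanity sketch of the 4KB
// kernel-argument constraint referenced in the comments above. The metadata struct is
// passed to multi_tensor_apply_kernel by value, so its size (plus the callable and any
// extra args) must stay under the launch-argument limit. A hypothetical compile-time
// check could look like this:
static_assert(sizeof(TensorListMetadata<1>) <= 4096,
              "TensorListMetadata<1> must fit in the 4KB kernel argument space");
static_assert(sizeof(TensorListMetadata<5>) <= 4096,
              "TensorListMetadata<5> must fit in the 4KB kernel argument space");
static_assert(sizeof(TensorListScalarListMetadata<c10::complex<double>, 1>) <= 4096,
              "the complex<double> specialization exists precisely to satisfy this bound");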

template<typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void
multi_tensor_apply_kernel(
    T tensorListMeta,
    U callable,
    ArgTypes... args) {
  // Hand the chunk information to the user-supplied functor to process however it likes.
  callable(kChunkSize, tensorListMeta, args...);
}
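
// Illustrative only (not part of the upstream header): a minimal sketch of the functor
// contract multi_tensor_apply_kernel expects. Each block reads its (tensor, chunk)
// assignment from the metadata, recovers the pointer and element count for that chunk,
// and processes at most chunk_size elements. The name `ExampleZeroFunctor` and its
// in-place zeroing are hypothetical; real functors (e.g. the foreach kernels) follow
// the same shape and typically add kILP-wide vectorized paths via is_aligned/load_store.
template<typename scalar_t>
struct ExampleZeroFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      TensorListMetadata<1>& tl) {
    // Which tensor, and which chunk of it, this block is responsible for.
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    auto* ptr = static_cast<scalar_t*>(tl.addresses[0][tensor_loc]) +
        static_cast<int64_t>(chunk_idx) * chunk_size;
    // Elements remaining in this chunk (the final chunk of a tensor may be short).
    int64_t n = tl.numel_for_tensor[tensor_loc] -
        static_cast<int64_t>(chunk_idx) * chunk_size;
    if (n > chunk_size) {
      n = chunk_size;
    }
    // Block-stride loop over the chunk.
    for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
      ptr[i] = scalar_t(0);
    }
  }
};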

template<int depth, typename scalar_T, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::ArrayRef<Scalar> scalars,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
  size_t n_tensors = tensor_lists[0].size();
  using scalar_vals_t = typename T::opmath_t;
  TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
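
  // Pack tensors and their chunks into the metadata. Whenever the tensor slots or the
  // block table fill up (or we reach the last chunk of the last tensor), launch the
  // kernel over the accumulated blocks; a tensor whose chunks straddle a launch is
  // carried over into slot 0 of the next one.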
  for (size_t t = 0; t < n_tensors; t++) {
    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t].to<scalar_T>();
    tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    }
    loc_tensor_info++;

    int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      bool tensors_full = (loc_tensor_info == depth_to_max_tensors_scalarlist[depth-1] &&
                           chunk == chunks - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
      bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);
      if (tensors_full || blocks_full || last_chunk) {
        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta,
            callable,
            args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        }
        else {
          tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1];
          tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
  size_t n_tensors = tensor_lists[0].size();
  TensorListMetadata<depth> tensorListMeta;
  tensorListMeta.start_tensor_this_launch = 0;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
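
  // Same packing/flush pattern as the scalar-list overload above, except that
  // start_tensor_this_launch records which tensor each launch begins with.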
  for (size_t t = 0; t < n_tensors; t++) {
    tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    }
    loc_tensor_info++;

    int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < chunks; chunk++) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
                           chunk == chunks - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
      bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);
      if (tensors_full || blocks_full || last_chunk) {
        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta,
            callable,
            args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
          tensorListMeta.start_tensor_this_launch = t + 1;
        }
        else {
          tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info-1];
          for (int d = 0; d < depth; d++) {
            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
          }
          loc_tensor_info = 1;
          tensorListMeta.start_tensor_this_launch = t;
        }
      }
    }
  }
}
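
// Illustrative only: how a host-side caller might drive the overload above with the
// hypothetical ExampleZeroFunctor sketched earlier, assuming `params` is a vector of
// same-dtype CUDA tensors on the current device:
//
//   std::vector<std::vector<at::Tensor>> lists{params};
//   multi_tensor_apply<1>(lists, ExampleZeroFunctor<float>{});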

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
  const auto num_tensors = tensor_lists[0].size();
  FusedOptimizerTensorListMetadata<depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
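
  // Same packing/flush pattern again, but each tensor also carries a pointer to its
  // scalar `step` tensor (state_steps), which is carried over on a mid-tensor flush
  // just like the data pointers.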
  for (const auto& tensor_index : c10::irange(num_tensors)) {
    tensorListMeta.state_steps_addresses[loc_tensor_info] = state_steps[tensor_index].data_ptr();
    tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][tensor_index].numel();
    for (const auto& d : c10::irange(depth)) {
      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][tensor_index].data_ptr();
    }
    loc_tensor_info++;

    const auto chunks = (tensor_lists[0][tensor_index].numel() + kChunkSize - 1) / kChunkSize;
    for (const auto& chunk : c10::irange(chunks)) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const auto tensor_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks - 1);
      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
      const auto last_chunk = (tensor_index == num_tensors - 1 && chunk == chunks - 1);
      if (tensor_full || blocks_full || last_chunk) {
        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta,
            callable,
            args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else {
          tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.state_steps_addresses[0] = tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
          for (const auto& d : c10::irange(depth)) {
            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }
}

} // namespace
}} // at::native