PersistentSoftmax.cuh

#pragma once

#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/DeviceUtils.cuh>

namespace {

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}
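// For example, log2_ceil(1) == 0, log2_ceil(32) == 5, and log2_ceil(33) == 6:
// the result is the smallest L such that (1 << L) >= value.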

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
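// Butterfly (XOR) reduction sketch: with WARP_SIZE = 4 and ReduceOp = Add, the
// first pass (offset 2) combines lanes {0,2} and {1,3}; the second pass
// (offset 1) combines lanes {0,1} and {2,3}. After log2(WARP_SIZE) passes every
// participating lane holds the full reduction, so no broadcast step is needed.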

// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
// One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples.
// WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small.
// A "WARP" contains "C10_WARP_SIZE" threads, and these threads are guaranteed to belong to the same warp.
// This is important because it means only __shfl_ instructions are required for reductions.
// Note that this means WARP_SIZE must be a power of two and <= architecture warp size.
// CUDA warp size is 32 for all existing GPU architectures, but there is no guarantee this will not change for future architectures.
// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future architectures.
// is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed.
// is_masked is a flag indicating whether SoftMax or MaskedSoftMax should be computed.
// The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t.
// This allows SoftMax to be fused with a cast immediately following the SoftMax.
// The mask should have the same shape as the input, with a boolean indicating whether each value is masked out.
// head_chunk_size is only used for the transformer mask softmax and equals H * D * D.
// For instance:
// input_t=half,  acc_t=float, output_t=half  => read half tensor, float accumulators, write half tensor.
// input_t=half,  acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor.
// input_t=float, acc_t=float, output_t=half  => read float tensor, float accumulators, write half tensor.
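// Numerically, each sample is reduced with the standard max-shifted formulation:
//   softmax(x)_i     = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
//   log_softmax(x)_i = (x_i - max(x)) - log(sum_j exp(x_j - max(x)))
// Illustrative instantiation (a sketch only; the pointers and launch geometry are
// hypothetical): a half-in/half-out forward over rows of 512 elements,
// accumulating in float, would use log2_elements = 9:
//   softmax_warp_forward<at::Half, at::Half, float, /*log2_elements=*/9,
//                        /*is_log_softmax=*/false, /*is_masked=*/false>
//       <<<blocks, threads, 0, stream>>>(dst, src, batch_count, stride, /*element_count=*/512);
// In practice dispatch_softmax_forward below picks log2_elements and the launch
// configuration, so callers rarely instantiate the kernel directly.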
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values computed in dispatch_softmax_forward below.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    int idx_offset = first_batch * stride + local_idx;

    src += idx_offset;
    dst += idx_offset;

    if (is_transformer_mask) {
        mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
    } else {
        mask += idx_offset;
    }

    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
    // the nested loops.
    // This should have no impact on performance because the loops are unrolled anyway.

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                elements[i][it] = src[i*element_count+it*WARP_SIZE];
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        bool is_meaningful_max = false;
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (is_masked) {
                int idx = it*WARP_SIZE;
                if ((idx + local_idx) < batch_element_count) {
                    if (!is_transformer_mask) {
                        idx += i*element_count;
                    }
                    if (!mask[idx]) {
                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
                        is_meaningful_max = true;
                    }
                }
            } else {
                max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
            }
        }
        if (is_masked) {
            if (!is_meaningful_max) {
                max_value[i] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (!is_masked) {
                if (is_log_softmax) {
                    sum[i] += std::exp(elements[i][it] - max_value[i]);
                } else {
                    elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                    sum[i] += elements[i][it];
                }
            } else {
                int idx = it*WARP_SIZE;
                bool valid = (idx + local_idx) < batch_element_count;
                if (!is_transformer_mask) {
                    idx += i*element_count;
                }
                if (valid) {
                    if (!mask[idx]) {
                        if (is_log_softmax) {
                            sum[i] += std::exp(elements[i][it] - max_value[i]);
                        } else {
                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                            sum[i] += elements[i][it];
                        }
                    } else {
                        if (!is_log_softmax) {
                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
                            elements[i][it] = 0;
                        }
                    }
                } else {
                    if (!is_log_softmax) {
                        elements[i][it] = 0.;
                    }
                }
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        if (is_log_softmax) sum[i] = std::log(sum[i]);
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_log_softmax) {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
                } else if (sum[i] == 0) {
                    dst[i*element_count+it*WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
                } else {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i];
                }
            } else {
                break;
            }
        }
    }
}
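
// Gradient formulas implemented by softmax_warp_backward (a brief summary, in
// index notation over one sample):
//   log-softmax: dX_i = dY_i - exp(Y_i) * sum_j dY_j
//   softmax:     dX_i = g_i  - Y_i      * sum_j g_j
// where Y is the forward output. For the plain-softmax case this matches the
// usual Jacobian-vector product Y_i * (dY_i - sum_j Y_j * dY_j) when the
// incoming buffer g already holds dY * Y elementwise; whether the caller
// performs that pre-multiplication is outside this header, so treat it as an
// assumption of this note rather than a guarantee of the kernel.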
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp values computed in dispatch_softmax_backward below.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x % WARP_SIZE;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;
    if (is_masked) {
        mask += thread_offset;
    }

    // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified to one loop,
    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
    // the nested loops.
    // This should have no impact on performance because the loops are unrolled anyway.

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE];
                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
            } else {
                grad_reg[i][it] = acc_t(0);
                output_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (!is_masked || !mask[i*element_count+it*WARP_SIZE]) {
                sum[i] += grad_reg[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_masked && mask[i*element_count+it*WARP_SIZE]) {
                    gradInput[i*element_count+it*WARP_SIZE] = 0;
                }
                // compute gradients
                else if (is_log_softmax) {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
                } else {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
                }
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = at::cuda::warp_size();
        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);

        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \
                softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, \
                        src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \
                C10_CUDA_KERNEL_LAUNCH_CHECK(); \
                break;

            LAUNCH_SOFTMAX_WARP_FORWARD(0);  // 1
            LAUNCH_SOFTMAX_WARP_FORWARD(1);  // 2
            LAUNCH_SOFTMAX_WARP_FORWARD(2);  // 4
            LAUNCH_SOFTMAX_WARP_FORWARD(3);  // 8
            LAUNCH_SOFTMAX_WARP_FORWARD(4);  // 16
            LAUNCH_SOFTMAX_WARP_FORWARD(5);  // 32
            LAUNCH_SOFTMAX_WARP_FORWARD(6);  // 64
            LAUNCH_SOFTMAX_WARP_FORWARD(7);  // 128
            LAUNCH_SOFTMAX_WARP_FORWARD(8);  // 256
            LAUNCH_SOFTMAX_WARP_FORWARD(9);  // 512
            LAUNCH_SOFTMAX_WARP_FORWARD(10); // 1024
            default:
                break;
        }
    }
}
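
// Example call (a hedged sketch; d_out, d_in, and batch are hypothetical, not
// part of this header). For a [batch, 1000] float input, computing row-wise
// softmax with float accumulation over contiguous rows (stride == elements):
//
//   const int elements = 1000;
//   dispatch_softmax_forward<float, float, float,
//                            /*is_log_softmax=*/false, /*is_masked=*/false>(
//       d_out, d_in, /*softmax_elements=*/elements,
//       /*softmax_elements_stride=*/elements, /*batch_count=*/batch);
//
// Internally this rounds 1000 up to next_power_of_two = 1024 (log2_elements = 10),
// so each row is handled by a full warp with WARP_BATCH = 1.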

template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = at::cuda::warp_size();
        warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);

        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \
                softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> \
                    (grad_input, grad, output, batch_count, softmax_elements_stride, \
                     softmax_elements, mask); \
                C10_CUDA_KERNEL_LAUNCH_CHECK(); \
                break;

            LAUNCH_SOFTMAX_WARP_BACKWARD(0);  // 1
            LAUNCH_SOFTMAX_WARP_BACKWARD(1);  // 2
            LAUNCH_SOFTMAX_WARP_BACKWARD(2);  // 4
            LAUNCH_SOFTMAX_WARP_BACKWARD(3);  // 8
            LAUNCH_SOFTMAX_WARP_BACKWARD(4);  // 16
            LAUNCH_SOFTMAX_WARP_BACKWARD(5);  // 32
            LAUNCH_SOFTMAX_WARP_BACKWARD(6);  // 64
            LAUNCH_SOFTMAX_WARP_BACKWARD(7);  // 128
            LAUNCH_SOFTMAX_WARP_BACKWARD(8);  // 256
            LAUNCH_SOFTMAX_WARP_BACKWARD(9);  // 512
            LAUNCH_SOFTMAX_WARP_BACKWARD(10); // 1024
            default:
                break;
        }
    }
}
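
// Example call (a hedged sketch; d_grad_in, d_grad_out, d_out, and batch are
// hypothetical device pointers/values). For the log-softmax backward of a
// [batch, 1000] float problem with contiguous rows:
//
//   dispatch_softmax_backward<float, float, float,
//                             /*is_log_softmax=*/true, /*is_masked=*/false>(
//       d_grad_in, d_grad_out, d_out, /*softmax_elements=*/1000,
//       /*softmax_elements_stride=*/1000, /*batch_count=*/batch);
//
// For the non-log softmax path, see the note above softmax_warp_backward about
// the expected contents of the grad buffer.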