// Normalization.cuh

#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

constexpr unsigned MAX_GRID_SIZE = 65535u;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename scalar_t, typename accscalar_t>
struct Float2 {
  accscalar_t v1, v2;
  __device__ Float2() {}
  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
    a += b;
    return a;
  }
};

template <typename scalar_t, typename accscalar_t, typename PTA>
struct GradOp {
  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
    : mean(m), input(i), grad_output(g) {}
  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
    accscalar_t g = grad_output[batch][plane][n];
    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
    return Float2<scalar_t, accscalar_t>(g, g * c);
  }
  const accscalar_t mean;
  const PTA& input;
  const PTA& grad_output;
};

template <typename acc_t>
struct SumReduceOp {
  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};

template <typename scalar_t, typename accscalar_t>
struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
  using acc_t = Float2<scalar_t, accscalar_t>;
  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
  }
};

// Sum across (batch, x/y/z) applying Op() pointwise
// this works by first having each thread sum its part
// of the data. Then there is a double-shuffling reduction.
// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
// data to the "warp leader", who writes its value into shared memory.
// Then a single warp reads the remaining (at most C10_WARP_SIZE) items
// and reduces them using another warpSum.
// The implicit assumption is that there are no more
// than C10_WARP_SIZE**2 threads.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // first the reductions each thread does separately
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    shared[0] = sum;
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole grad_input
  return shared[0];
}
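
// Illustrative usage sketch (not part of the original file): `reduce` is typically
// instantiated with GradOp so that, per plane, sum(dy) and sum(dy * (x - mean)) are
// accumulated in a single pass, e.g.
//
//   GradOp<input_scalar_t, stat_accscalar_t, decltype(input)> g(mean, input, grad_output);
//   auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
//   // res.v1 == sum(dy), res.v2 == sum(dy * (x - mean))
//
// which is exactly how batch_norm_backward_kernel and
// batch_norm_backward_reduce_kernel below use it.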

constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
constexpr int ELEMENTS_PER_THREAD = 16;
constexpr int OPTIMAL_TILE_W = 32;
constexpr int MAX_H_BLOCK = 128;

__host__ void flexible_launch_configs(
    const int reduction,
    const int stride,
    dim3 &block,
    dim3 &grid,
    const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(lastPow2(at::ceil_div(reduction, ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = at::ceil_div(stride, block_x);
  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}
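
// Worked example (hypothetical sizes, assuming the non-ROCm MAX_BLOCK_SIZE of 512):
// for reduction = 8192 and stride = 64,
//   block_x = min(lastPow2(64), 32)                      = 32
//   block_y = min(lastPow2(ceil_div(8192, 16)), 512/32)  = min(512, 16) = 16
//   grid_x  = ceil_div(64, 32)                           = 2
//   grid_y  = min(ceil_div(8192, 16 * 16), 128)          = 32
// i.e. block = (32, 16), grid = (2, 32); with coop_flag = true, grid_y stays 32
// since it is >= 8.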

template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& count_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  T factor = T(1.0) / ::max(1, (count + count_new));
  T delta0 = mean - mean_new;
  mean = (mean_new * count_new + mean * count) * factor;
  m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
  count += count_new;
}
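
// The merge above is the standard parallel-variance (Chan et al.) update:
//   n    = n_a + n_b
//   mean = (n_a * mean_a + n_b * mean_b) / n
//   M2   = M2_a + M2_b + (mean_a - mean_b)^2 * n_a * n_b / n
// so (count, mean, M2) triples computed on disjoint chunks can be combined
// without a second pass over the data.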

// merge mean/m2n among threadIdx.y within block
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

  #pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset*2) {
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto count_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];
      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
    stat_accscalar_t epsilon) {
  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
  stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
  stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
  stat_accscalar_t invstd;
  if (train) {
    invstd = var_or_invstd[plane];
  } else {
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
  }

  index_t bs = input.size(0);
  index_t fs = input.size(2);
  index_t bstep = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto o = output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
    }
  }
}
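
// The kernel above applies the usual batch-norm affine transform per plane,
//   y = gamma * (x - mean) * invstd + beta
// In training mode (train == true) `var_or_invstd` already holds invstd as
// produced by batch_norm_collect_statistics_kernel; in eval mode it holds the
// running variance, and invstd = 1/sqrt(var + eps) is recomputed here.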

struct InvStd {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
    }
    return invstd;
  }
};

struct Var {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    return var;
  }
};

template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    const stat_accscalar_t epsilon,
    const stat_accscalar_t momentum,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];

  int plane = blockIdx.x;
  int N = input.size(0) * input.size(2);
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

  // Compute the mean and variance across (batch, x/y/z)
  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  // and the parallel algorithm on the same page.
  // We use two shuffles to reduce across the entire block.
  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];

  // first the reductions each thread does separately
  stat_accscalar_t avg = 0;
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
      n++;
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
  }

  // first warpSum to get one value per thread to
  // one value per warp
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // this writes each warp's item into shared memory
  // there are at most C10_WARP_SIZE items left because
  // there are at most C10_WARP_SIZE**2 threads at the beginning
  __syncthreads();
  if (tid % C10_WARP_SIZE == 0) {
    shared_n[tid / C10_WARP_SIZE] = n;
    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
  }
  __syncthreads();

  // now have a second warpSum to reduce the intermediate values
  // from shared memory to a single number. The very first
  // thread writes it to shared memory.
  if (tid < C10_WARP_SIZE) {
    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
    avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
    var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
  }
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // Save the mean, variance, and moving averages
  if (tid == 0) {
    if (save_mean.data() != NULL) {
      save_mean[plane] = avg;
    }
    if (save_transformed_var.data() != NULL) {
      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
    }
  }
}
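
// Summary of the kernel above: each block owns one plane. Every thread runs a
// serial Welford update producing a partial (n, avg, var_n); the two shuffle
// stages then merge those partials with the same count-weighted update as
// welford_merge_element, and thread 0 stores the plane mean plus
// VarTransform(var_n / N, epsilon) (InvStd yields 1/sqrt(var + eps), Var returns
// the biased variance unchanged).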

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
    bool train,
    stat_accscalar_t epsilon) {
  index_t plane = blockIdx.x;
  index_t N = grad_output.size(0) * grad_output.size(2);

  stat_accscalar_t mean, invstd;
  if (train) {
    mean = save_mean[plane];
    invstd = save_invstd[plane];
  } else {
    mean = static_cast<stat_accscalar_t>(running_mean[plane]);
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
  }

  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  stat_accscalar_t norm = stat_accscalar_t(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(grad_output)
  // 2. DotProduct(input - mean, grad_output)
  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
  stat_accscalar_t grad_output_sum = res.v1;
  stat_accscalar_t dot_p = res.v2;

  stat_accscalar_t grad_mean = grad_output_sum * norm;
  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  stat_accscalar_t grad_scale = invstd * weight_val;

  if (grad_input.data() != NULL) {
    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
        input_scalar_t go = grad_output[batch][plane][x];
        if (train) {
          stat_accscalar_t inp = input[batch][plane][x];
          stat_accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
        }
      }
    }
  }

  if (grad_weight.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
    }
  }

  if (grad_bias.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
    }
  }
}
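
// In training mode this realizes the usual batch-norm backward formulas, with
// x_hat = (x - mean) * invstd and per-plane sums over N = batch * spatial elements:
//   dL/dgamma = sum(dy * x_hat)                      (== dot_p * invstd)
//   dL/dbeta  = sum(dy)                              (== grad_output_sum)
//   dL/dx     = gamma * invstd * (dy - mean(dy) - x_hat * mean(dy * x_hat))
// In eval mode mean/invstd are treated as constants, so dL/dx reduces to
// dy * gamma * invstd.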

template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_reduce_statistics_kernel(
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
    const accscalar_t epsilon,
    const accscalar_t momentum,
    const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {
  int feature_size = vec_mean.size(1);
  int world_size = vec_mean.size(0);

  int bid = blockIdx.x;
  int tid = threadIdx.x;

  // first the reductions each thread does separately
  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
    accscalar_t avg = 0;
    accscalar_t var_n = 0;
    index_t n = 0;
    for (int j = 0; j < world_size; j++) {
      scalar_t count = counts[j];
      accscalar_t m = vec_mean[j][i];
      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
      v = (v * v - epsilon) * count;
      accscalar_t factor = 1.0 / (n + count);
      var_n += v + (avg - m) * (avg - m) * n * count * factor;
      avg = n * factor * avg + count * factor * m;
      n += count;
    }
    mean[i] = avg;
    invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
    if (running_mean.data() != NULL) {
      running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
    }
    accscalar_t unbiasedVar = var_n / (n - 1);
    if (running_var.data() != NULL) {
      running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
    }
  }
}
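
// Each per-rank invstd is first turned back into a sum of squared deviations
// (var = 1/invstd^2 - eps, scaled by that rank's count); the per-rank
// (count, mean, M2) triples are then folded together with the same
// parallel-variance update used in welford_merge_element. The merged biased
// variance M2/n feeds the returned invstd, while the running_var update uses
// the unbiased estimate M2/(n - 1).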

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_reduce_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
  index_t plane = blockIdx.x;

  stat_accscalar_t r_mean = mean[plane];
  stat_accscalar_t factor = invstd[plane];

  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  if (threadIdx.x == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
    }
    if (sum_dy.size(0) > 0) {
      sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
    }
    if (sum_dy_xmu.size(0) > 0) {
      sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t m_c = mean[plane];
  stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
  stat_accscalar_t factor_1_c = invstd[plane];
  stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  factor_2_c *= factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;

  index_t bs = input.size(0);
  index_t fs = input.size(2);
  index_t bstep = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto g_i = grad_input[batch][plane];
    auto g_o = grad_output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
    }
  }
}
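
// With norm_fct = 1/N (N being the total element count behind sum_dy/sum_dy_xmu),
// the element-wise pass above computes the same training-mode gradient as
// batch_norm_backward_kernel, just with the reductions supplied externally:
//   dL/dx = gamma * invstd * (dy - sum_dy/N - (x - mean) * invstd^2 * sum_dy_xmu/N)
// This is the shape used by the synchronized batch-norm backward, where sum_dy and
// sum_dy_xmu have been reduced across ranks beforehand.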

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const int* __restrict__ numel, const int world_size) {
  int64_t total_numel = 0;
  for (int i = 0; i < world_size; i++) {
    total_numel += numel[i];
  }

  const stat_accscalar_t norm_fct =
      static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_elemt_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    const stat_accscalar_t norm_fct) {
  batch_norm_backward_elemt_kernel_impl(
      input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
    const Tensor& t, c10::string_view var_name) {
  constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value;
  const auto actual_type = t.scalar_type();
  TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
              " to have type ", expect_type, " but got ", actual_type);
  return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
}

template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const Tensor& t, c10::string_view var_name) {
  if (!t.defined()) {
    const std::array<index_t, dim> zeros{{0}};
    return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
  }
  return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
}
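
// Note: for an undefined (optional) tensor, packed_accessor_or_dummy hands the
// kernels an accessor with a null data pointer and all-zero sizes, which is why
// the kernels above guard optional inputs/outputs with checks like
// `weight.size(0) > 0` or `grad_input.data() != NULL` instead of separate flags.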

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
    const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
    bool train, double epsilon, std::array<bool,3> grad_input_mask) {
  using accscalar_t = at::acc_type<stat_scalar_t, true>;
  Tensor grad_input_;
  Tensor grad_input_reshaped;
  Tensor grad_weight_;
  Tensor grad_bias_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());

  if (grad_input_mask[0]) {
    grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
  }
  if (grad_input_mask[1]) {
    grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  if (grad_input_mask[2]) {
    grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }

  auto input = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  auto grad_output = get_packed_accessor<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  auto grad_input = packed_accessor_or_dummy<
      input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  auto grad_weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  auto grad_bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  auto running_mean = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
  auto running_var = packed_accessor_or_dummy<
      stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
  auto save_mean = packed_accessor_or_dummy<
      accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
  auto save_invstd = packed_accessor_or_dummy<
      accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");

  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));

  batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
      (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
       save_mean, save_invstd, train, epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}

template<typename scalar_t, typename index_t, typename VarTransform>
void batch_norm_stats_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
  using accscalar_t = at::acc_type<scalar_t, true>;
  int64_t n_input = input_.size(1);
  Tensor dummy_mean_;
  Tensor dummy_var_;
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions

  resize_output(out_mean, {n_input});
  resize_output(out_invstd, {n_input});
  auto input = get_packed_accessor<
      scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  auto mean = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
  auto invstd = packed_accessor_or_dummy<
      accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  dim3 blocks(input.size(1));
  int tf = getNumThreads(input.size(2));
  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
  batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
      (input, epsilon, 0.0, mean, invstd);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
    const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
  using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  int64_t n_input = input_.size(1);
  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});

  auto input = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
  auto output = get_packed_accessor<
      input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
  auto weight = packed_accessor_or_dummy<
      stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
  auto bias = packed_accessor_or_dummy<
      stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
  auto mean = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
  auto invstd = packed_accessor_or_dummy<
      stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
  auto stream = at::cuda::getCurrentCUDAStream();

  // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
  const double dummy_epsilon = 1e-5;

  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  // and good occupancy. Quite likely, we could go with even more blocks than 1024.
  // The various planes are independent, so we use blocks for them.
  int tf = std::max<int>(getNumThreads(input.size(2)/4),
                         std::min<int>(getNumThreads(input.size(2)), 64));
  int tb = std::max<int>(64/tf, 1);
  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
                                                 (input.size(0)+tb-1)/tb)));
  blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  dim3 threads_trans(tf, tb);
  batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
      (input, output, mean, invstd, weight, bias, dummy_epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
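
// Illustrative sketch only, not taken from this file: callers of these templates
// (presumably in Normalization.cu) are expected to dispatch on scalar type and to
// pick index_t based on whether 32-bit indexing suffices, roughly along the lines of
//
//   AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(),
//       "batch_norm_elemt_cuda", [&] {
//     if (at::cuda::detail::canUse32BitIndexMath(input)) {
//       batch_norm_elemt_cuda_template<scalar_t, scalar_t, int32_t>(
//           output, input, weight, bias, mean, invstd);
//     } else {
//       batch_norm_elemt_cuda_template<scalar_t, scalar_t, int64_t>(
//           output, input, weight, bias, mean, invstd);
//     }
//   });
//
// The actual dispatch may also mix input and stat scalar types; this is only meant
// to show how the index_t/scalar_t template parameters are intended to be chosen.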
  646. template<typename scalar_t, typename accscalar_t, typename index_t>
  647. std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
  648. const Tensor& running_mean_, const Tensor& running_var_,
  649. double momentum, double epsilon, const Tensor& counts_) {
  650. Tensor save_mean_;
  651. Tensor save_invstd_;
  652. auto features = mean_.size(1);
  653. auto input_options = mean_.options();
  654. if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
  655. input_options = input_options.dtype(ScalarType::Float);
  656. }
  657. save_mean_ = at::empty({features}, input_options);
  658. save_invstd_ = at::empty({features}, input_options);
  659. auto mean = packed_accessor_or_dummy<
  660. accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
  661. auto invstd = packed_accessor_or_dummy<
  662. accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
  663. auto running_mean = packed_accessor_or_dummy<
  664. scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
  665. auto running_var = packed_accessor_or_dummy<
  666. scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean");
  667. auto counts = packed_accessor_or_dummy<
  668. scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
  669. auto save_mean = get_packed_accessor<
  670. accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
  671. auto save_invstd = get_packed_accessor<
  672. accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
  673. auto stream = at::cuda::getCurrentCUDAStream();
  674. int block = getNumThreads(features);
  675. int grid = std::max<int>(1, features/block);
  676. batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
  677. (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
  678. C10_CUDA_KERNEL_LAUNCH_CHECK();
  679. return std::make_tuple(save_mean_, save_invstd_);
  680. }
  681. template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
  682. std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
  683. const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
  684. const bool input_g, const bool weight_g, const bool bias_g) {
  685. using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  686. int64_t n_input = input_.size(1);
  687. Tensor sum_dy_;
  688. Tensor sum_dy_xmu_;
  689. Tensor grad_weight_;
  690. Tensor grad_bias_;
  691. auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  692. auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  693. if (input_g) {
  694. sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  695. sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  696. }
  697. if (weight_g) {
  698. grad_weight_ = at::empty({n_input}, weight_.options());
  699. }
  700. if (bias_g) {
  701. grad_bias_ = at::empty({n_input}, weight_.options());
  702. }
  703. auto input = get_packed_accessor<
  704. input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  705. auto grad_output = get_packed_accessor<
  706. input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  707. auto grad_weight = packed_accessor_or_dummy<
  708. stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
  709. auto grad_bias = packed_accessor_or_dummy<
  710. stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
  711. auto mean = packed_accessor_or_dummy<
  712. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  713. auto invstd = packed_accessor_or_dummy<
  714. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  715. auto sum_dy = packed_accessor_or_dummy<
  716. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  717. auto sum_dy_xmu = packed_accessor_or_dummy<
  718. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
  719. auto batch_size = input_reshaped.size(0);
  720. auto feature_size = input_reshaped.size(2);
  721. auto stream = at::cuda::getCurrentCUDAStream();
  722. int warp_size = at::cuda::warp_size();
  723. int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
  724. // We want block_x to be at least a warp width
  725. int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
  726. const dim3 block(block_x, block_y);
  727. const dim3 grid(n_input);
  728. batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
  729. (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
  730. C10_CUDA_KERNEL_LAUNCH_CHECK();
  731. return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
  732. }
  733. template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
  734. Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
  735. const Tensor& mean_, const Tensor& invstd_,
  736. const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
  737. using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  738. int64_t n_input = input_.size(1);
  739. auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  740. auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  741. auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  742. auto input = get_packed_accessor<
  743. input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  744. auto grad_input = get_packed_accessor<
  745. input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  746. auto grad_output = get_packed_accessor<
  747. input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  748. auto mean = packed_accessor_or_dummy<
  749. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  750. auto invstd = packed_accessor_or_dummy<
  751. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  752. auto weight = packed_accessor_or_dummy<
  753. stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  754. auto sum_dy = packed_accessor_or_dummy<
  755. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  756. auto sum_dy_xmu = packed_accessor_or_dummy<
  757. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
  758. auto stream = at::cuda::getCurrentCUDAStream();
  759. // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  760. // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  761. // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
  762. // The various planes are independent, so we use blocks for them.
  763. int tf = std::max<int>(getNumThreads(input.size(2)/4),
  764. std::min<int>(getNumThreads(input.size(2)), 64));
  765. int tb = std::max<int>(64/tf, 1);
  766. dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
  767. (input.size(0)+tb-1)/tb)));
  768. blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  769. dim3 threads_trans(tf, tb);
  770. auto reduction_size = input_.numel() / n_input;
  771. auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
  772. batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
  773. <<<blocks_trans, threads_trans, 0, stream>>>
  774. (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
  775. C10_CUDA_KERNEL_LAUNCH_CHECK();
  776. return grad_input_reshaped.view(input_.sizes());
  777. }
  778. template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
  779. Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
  780. const Tensor& mean_, const Tensor& invstd_,
  781. const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {
  782. using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
  783. int64_t n_input = input_.size(1);
  784. auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
  785. auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
  786. auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  787. auto input = get_packed_accessor<
  788. input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
  789. auto grad_input = get_packed_accessor<
  790. input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
  791. auto grad_output = get_packed_accessor<
  792. input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
  793. auto mean = packed_accessor_or_dummy<
  794. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
  795. auto invstd = packed_accessor_or_dummy<
  796. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
  797. auto weight = packed_accessor_or_dummy<
  798. stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
  799. auto sum_dy = packed_accessor_or_dummy<
  800. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
  801. auto sum_dy_xmu = packed_accessor_or_dummy<
  802. stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
  803. auto stream = at::cuda::getCurrentCUDAStream();
  804. // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
  805. // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
  806. // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
  807. // The various planes are independent, so we use blocks for them.
  808. int tf = std::max<int>(getNumThreads(input.size(2)/4),
  809. std::min<int>(getNumThreads(input.size(2)), 64));
  810. int tb = std::max<int>(64/tf, 1);
  811. dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
  812. (input.size(0)+tb-1)/tb)));
  813. blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
  814. dim3 threads_trans(tf, tb);
  815. batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
  816. (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.data_ptr<int>(), count.numel());
  817. C10_CUDA_KERNEL_LAUNCH_CHECK();
  818. return grad_input_reshaped.view(input_.sizes());
  819. }
  820. // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
  821. // original apex name: welford_kernel_c_last
  822. template
  823. <typename VarTransform,
  824. typename scalar_t,
  825. typename accscalar_t,
  826. int PARALLEL_LOADS>
  827. __global__ void
  828. batch_norm_collect_statistics_channels_last_kernel(
  829. const scalar_t* __restrict__ input,
  830. accscalar_t* __restrict__ out_mean,
  831. accscalar_t* __restrict__ out_invstd,
  832. volatile accscalar_t* staging_data,
  833. int* semaphores,
  834. const int reduction_size,
  835. const int stride,
  836. accscalar_t epsilon) {
  837. // hide latency with concurrency
  838. accscalar_t x_mean[PARALLEL_LOADS];
  839. accscalar_t m_2_n[PARALLEL_LOADS];
  840. int count[PARALLEL_LOADS];
  841. #pragma unroll
  842. for (int i = 0; i < PARALLEL_LOADS; i++) {
  843. x_mean[i] = accscalar_t(0);
  844. m_2_n[i] = accscalar_t(0);
  845. count[i] = accscalar_t(0);
  846. }
  847. // tensor dimension (m,c)
  848. // loop along m dimension
  849. int inner_loop_stride = blockDim.y * gridDim.y;
  850. // offset along m dimension
  851. int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  852. int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  853. int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  854. int address_base = m_offset * stride + c_offset;
  855. int address_increment = inner_loop_stride * stride;
  856. for (int i = 0; i < loop_count; i++) {
  857. accscalar_t x_math[PARALLEL_LOADS];
  858. accscalar_t x_count_inv[PARALLEL_LOADS];
  859. accscalar_t is_valid[PARALLEL_LOADS];
  860. // load multiple data in
  861. #pragma unroll
  862. for (int j = 0; j < PARALLEL_LOADS; j++) {
  863. if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate mean/m2n with welford
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }

  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }

  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];

  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);

  if (gridDim.y > 1) {
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }

    __threadfence();
    __syncthreads(); // ensure writes to staging_ are visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    if (is_last_block_done) {
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int count_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);

        welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
      }

      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
        out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
      out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
    }
  }
}
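
// Illustrative sketch (not load-bearing code): the reduction above happens in
// three stages -- a per-thread fold of the PARALLEL_LOADS Welford states,
// welford_merge_block_vertical across threadIdx.y via shared memory, and (when
// gridDim.y > 1) a last-block merge of the per-block partials staged in global
// memory. The pairwise combine used throughout is the standard parallel Welford
// update; a scalar sketch of the same math, for reference only:
//
//   // merge state B (count_b, mean_b, m2n_b) into state A, in place
//   template <typename T>
//   void welford_merge_scalar(int& count_a, T& mean_a, T& m2n_a,
//                             int count_b, T mean_b, T m2n_b) {
//     T delta = mean_b - mean_a;
//     int n = count_a + count_b;
//     T nb_over_n = n > 0 ? T(count_b) / n : T(0);
//     mean_a += delta * nb_over_n;
//     m2n_a  += m2n_b + delta * delta * count_a * nb_over_n;
//     count_a = n;
//   }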
// elementwise BN kernel
// original apex name: batchnorm_forward_c_last_kernel
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batch_norm_transform_input_channels_last_kernel(
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ z,
    const accscalar_t* __restrict__ mean,
    const accscalar_t* __restrict__ inv_std,
    const layerscalar_t* __restrict__ weight,
    const layerscalar_t* __restrict__ shift,
    scalar_t* __restrict__ out,
    const int reduction_size,
    const int stride,
    const bool fuse_relu) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  if (c_offset >= stride || m_offset >= reduction_size) {
    return;
  }

  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c) * inv_std_c + s_c;
        if (z != nullptr) {
          tmp += z[address_base];
        }
        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}
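
// Illustrative only: what the transform kernel above computes per element, with
// gamma = weight[c] (or 1) and beta = shift[c] (or 0):
//
//   y[m, c]  = gamma * (x[m, c] - mean[c]) * inv_std[c] + beta
//   y[m, c] += z[m, c]             // if z != nullptr (residual added pre-activation)
//   y[m, c]  = max(y[m, c], 0)     // if fuse_relu
//
// A minimal scalar reference of the affine-normalize step, assuming float math:
//
//   inline float bn_transform_ref(float x, float mean, float inv_std,
//                                 float gamma, float beta) {
//     return gamma * (x - mean) * inv_std + beta;
//   }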
template<typename T>
__device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
    T& sum_dy_xmu,
    T* shmem_sum_dy,
    T* shmem_sum_dy_xmu) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset*2) {
      shmem_sum_dy[address_base] = sum_dy;
      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
    }
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      sum_dy += shmem_sum_dy[address];
      sum_dy_xmu += shmem_sum_dy_xmu[address];
    }
  }
}
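
// Sketch of the reduction pattern above (assumes blockDim.y is a power of two,
// as produced by flexible_launch_configs): at each step the active thread rows
// stage their running sums in shared memory, the lower half adds in its
// partner's value, and the active range halves until row 0 holds the block-wide
// sums. A serial equivalent over a hypothetical array partial[n], n a power of two:
//
//   for (int offset = n / 2; offset > 0; offset >>= 1) {
//     for (int y = 0; y < offset; y++) {
//       partial[y] += partial[y + offset];
//     }
//   }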
// batchnorm backward kernel for c last tensor
// original apex name: reduce_bn_c_last_kernel
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_reduce_channels_last_kernel(
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ grad_output,
    const accscalar_t* __restrict__ mean,
    const accscalar_t* __restrict__ inv_std,
    accscalar_t* __restrict__ sum_dy_o,
    accscalar_t* __restrict__ sum_dy_xmu_o,
    layerscalar_t* __restrict__ grad_weight,
    layerscalar_t* __restrict__ grad_bias,
    volatile accscalar_t* staging_data,
    int* semaphores,
    const int reduction_size,
    const int stride) {
  // hide latency with concurrency
  accscalar_t sum_dy[PARALLEL_LOADS];
  accscalar_t sum_dy_xmu[PARALLEL_LOADS];

#pragma unroll
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    sum_dy[i] = accscalar_t(0);
    sum_dy_xmu[i] = accscalar_t(0);
  }

  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  if (c_offset >= stride || m_offset >= reduction_size) {
    return;
  }

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  auto r_mean = mean[c_offset];
  auto factor = inv_std[c_offset];

  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_input[PARALLEL_LOADS];
    accscalar_t x_grad_output[PARALLEL_LOADS];

    // load multiple data in
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_input[j] = input[address_base];
        x_grad_output[j] = grad_output[address_base];
      } else {
        x_input[j] = accscalar_t(0);
        x_grad_output[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }

    // calculate sum_dy / sum_dy_xmu
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      sum_dy[j] += x_grad_output[j];
      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
    }
  }

  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
#pragma unroll
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    sum_dy[0] += sum_dy[j];
    sum_dy_xmu[0] += sum_dy_xmu[j];
  }

  // release array of registers
  auto sum_dy_th = sum_dy[0];
  auto sum_dy_xmu_th = sum_dy_xmu[0];

  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];

  merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);

  if (gridDim.y > 1) {
    volatile accscalar_t* staging_sum_dy = staging_data;
    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];

    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_sum_dy[address_base] = sum_dy_th;
      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
    }

    __threadfence();
    __syncthreads(); // ensure writes to staging_ are visible to all blocks

    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }

    __syncthreads();

    // check that all data is now available in global memory
    if (is_last_block_done) {
      sum_dy_th = accscalar_t(0.0);
      sum_dy_xmu_th = accscalar_t(0.0);

      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
      }

      merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
      if (threadIdx.y == 0 && c_offset < stride) {
        if (grad_bias != nullptr) {
          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
        }
        if (grad_weight != nullptr) {
          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
        }
        //mean_dy[c_offset] = sum_dy_th / reduction_size;
        //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
        sum_dy_o[c_offset] = sum_dy_th;
        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
      }
    }
  } else {
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      if (grad_bias != nullptr) {
        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
      }
      if (grad_weight != nullptr) {
        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
      }
      //mean_dy[c_offset] = sum_dy_th / reduction_size;
      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
      sum_dy_o[c_offset] = sum_dy_th;
      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
    }
  }
}
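
// Per-channel quantities produced above (illustrative summary):
//
//   sum_dy_o[c]     = sum_m grad_output[m, c]
//   sum_dy_xmu_o[c] = sum_m grad_output[m, c] * (input[m, c] - mean[c])
//   grad_bias[c]    = sum_dy_o[c]
//   grad_weight[c]  = sum_dy_xmu_o[c] * inv_std[c]
//
// Scalar sketch of one element's contribution, assuming float accumulation:
//
//   inline void bn_backward_reduce_ref(float x, float dy, float mean,
//                                      float& sum_dy, float& sum_dy_xmu) {
//     sum_dy     += dy;
//     sum_dy_xmu += dy * (x - mean);
//   }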
// elementwise BN kernel
// original apex name: batchnorm_backward_c_last_kernel
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ input,
    const accscalar_t* __restrict__ mean,
    const accscalar_t* __restrict__ inv_std,
    const layerscalar_t* __restrict__ weight,
    const accscalar_t* __restrict__ sum_dy,
    const accscalar_t* __restrict__ sum_dy_xmu,
    scalar_t* __restrict__ grad_input,
    const accscalar_t norm_fct,
    const int reduction_size,
    const int stride) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;

  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;

  if (c_offset >= stride || m_offset >= reduction_size) {
    return;
  }

  auto m_c = mean[c_offset];
  auto m_dy_c = sum_dy[c_offset] * norm_fct;
  auto factor_1_c = inv_std[c_offset];
  auto factor_2_c = (weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;

  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;

  for (int i = 0; i < loop_count; i++) {
#pragma unroll
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        grad_input[address_base] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
            * factor_2_c);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}
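
// Formula implemented above (illustrative), with gamma = weight[c] (or 1) and
// norm_fct typically 1/N over the reduced elements:
//
//   grad_input[m, c] = gamma * inv_std[c] * (grad_output[m, c]
//       - norm_fct * sum_dy[c]
//       - (input[m, c] - mean[c]) * inv_std[c]^2 * norm_fct * sum_dy_xmu[c])
//
// Scalar sketch of the same expression, assuming float accumulation:
//
//   inline float bn_backward_elemt_ref(float dy, float x, float mean, float inv_std,
//                                      float gamma, float sum_dy, float sum_dy_xmu,
//                                      float norm_fct) {
//     float m_dy     = sum_dy * norm_fct;
//     float factor_1 = inv_std * inv_std * sum_dy_xmu * norm_fct;
//     float factor_2 = gamma * inv_std;
//     return (dy - m_dy - (x - mean) * factor_1) * factor_2;
//   }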
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_elemt_channels_last_kernel(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ input,
    const accscalar_t* __restrict__ mean,
    const accscalar_t* __restrict__ inv_std,
    const layerscalar_t* __restrict__ weight,
    const accscalar_t* __restrict__ sum_dy,
    const accscalar_t* __restrict__ sum_dy_xmu,
    const int* __restrict__ numel,
    scalar_t* __restrict__ grad_input,
    const int64_t world_size,
    const int reduction_size,
    const int stride) {
  int64_t total_numel = 0;
  for (int i = 0; i < world_size; i++) {
    total_numel += numel[i];
  }

  auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
      grad_input, norm_fct, reduction_size, stride);
}
template <
    int PARALLEL_LOADS,
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t>
__global__ void batch_norm_backward_elemt_channels_last_kernel(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ input,
    const accscalar_t* __restrict__ mean,
    const accscalar_t* __restrict__ inv_std,
    const layerscalar_t* __restrict__ weight,
    const accscalar_t* __restrict__ sum_dy,
    const accscalar_t* __restrict__ sum_dy_xmu,
    scalar_t* __restrict__ grad_input,
    const accscalar_t norm_fct,
    const int reduction_size,
    const int stride) {
  batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
      grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
      grad_input, norm_fct, reduction_size, stride);
}
template<typename scalar_t, typename VarTransform>
void batch_norm_stats_channels_last_cuda_template(
    const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
  using accscalar_t = at::acc_type<scalar_t, true>;
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  resize_output(out_mean, {stride});
  resize_output(out_invstd, {stride});
  TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                        out_invstd.sizes()[0]);
  TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
                        out_mean.sizes()[0]);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({4*stride*grid.y}, out_mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }

  accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
  int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
  batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      input.data_ptr<scalar_t>(),
      out_mean.data_ptr<accscalar_t>(),
      out_invstd.data_ptr<accscalar_t>(),
      staging_data_ptr,
      semaphores_ptr,
      reduction_size,
      stride,
      epsilon);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
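
// Usage sketch (placeholders flagged, not part of this header's tested surface):
// `input_nhwc` is a hypothetical channels-last CUDA tensor; `VarTransformT`
// stands for the inverse-std functor declared earlier in this file, which the
// kernel invokes as VarTransformT{}(var, epsilon). out_mean/out_invstd may be
// passed empty; resize_output() above sizes them to {C}.
//
//   at::Tensor input_nhwc = at::randn({8, 64, 32, 32},
//       at::device(at::kCUDA).dtype(at::kFloat))
//       .contiguous(at::MemoryFormat::ChannelsLast);
//   at::Tensor out_mean   = at::empty({0}, input_nhwc.options());
//   at::Tensor out_invstd = at::empty({0}, input_nhwc.options());
//   batch_norm_stats_channels_last_cuda_template<float, VarTransformT>(
//       out_mean, out_invstd, input_nhwc, /*epsilon=*/1e-5);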
void batch_norm_elemt_channels_last_cuda_template(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& shift,  // bias of BN
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::optional<at::Tensor>& z = c10::nullopt,  // bias after BN
    const bool fuse_relu = false) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();
  const auto second_dtype = weight.defined() ? weight.scalar_type() :
      (shift.defined() ? shift.scalar_type() : input.scalar_type());

  if (input.scalar_type() != second_dtype) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<accscalar_t>() : nullptr,
          shift.defined() ? shift.data_ptr<accscalar_t>() : nullptr,
          output.data_ptr<scalar_t>(),
          reduction_size,
          stride,
          fuse_relu);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
          " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          shift.defined() ? shift.data_ptr<scalar_t>() : nullptr,
          output.data_ptr<scalar_t>(),
          reduction_size,
          stride,
          fuse_relu);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
}
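
// Usage sketch (tensor names are placeholders): `out` must already be allocated
// with the same sizes and memory format as `input`; `weight`/`shift` may be
// float even when `input` is half or bfloat16, which is the mixed-dtype branch
// selected above via second_dtype. z and fuse_relu default to "no residual, no ReLU".
//
//   auto out = at::empty_like(input);
//   batch_norm_elemt_channels_last_cuda_template(
//       out, input, weight, bias, saved_mean, saved_invstd);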
std::tuple<Tensor, Tensor, Tensor, Tensor>
batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const bool input_g, const bool weight_g, const bool bias_g) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  at::Tensor sumn_dy = at::empty({stride}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());

  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.defined()) {
    grad_weight = at::empty({stride}, weight.options());
    grad_bias = at::empty({stride}, weight.options());
  } else {
    // because I cannot return an uninitialized at::Tensor
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid, true);

  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({2*stride*grid.y}, mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();

  if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          grad_output.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          sumn_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_weight.data_ptr<accscalar_t>(),
          grad_bias.data_ptr<accscalar_t>(),
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
          " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
      batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          grad_output.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          sumn_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          weight.defined() ? grad_weight.data_ptr<scalar_t>() : nullptr,
          weight.defined() ? grad_bias.data_ptr<scalar_t>() : nullptr,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }

  return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias);
}
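
// Returned tuple above: (sum_dy, sum_dy_xmu, grad_weight, grad_bias), where the
// last two are 0-element tensors when `weight` is undefined. Usage sketch with
// placeholder tensor names:
//
//   auto [sum_dy, sum_dy_xmu, grad_weight, grad_bias] =
//       batch_norm_backward_reduce_cuda_channels_last_template(
//           grad_out, input, saved_mean, saved_invstd, weight,
//           /*input_g=*/true, /*weight_g=*/true, /*bias_g=*/true);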
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_dy_xmu,
    const at::Tensor& count) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;

  // Input is guaranteed to be channels-last compatible
  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.data_ptr<accscalar_t>(),
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          count.data_ptr<int>(),
          grad_input.data_ptr<scalar_t>(),
          count.numel(),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  } else {
    if (weight.defined()) {
      TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
          " is not supported with weight.scalar_type() ", weight.scalar_type());
    }
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          count.data_ptr<int>(),
          grad_input.data_ptr<scalar_t>(),
          count.numel(),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }

  return grad_input;
}
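
// The overload above is the distributed (SyncBN-style) path: `count` holds one
// int32 element count per participating rank, and the kernel derives
// norm_fct = 1 / sum(count) on device before applying the element-wise backward.
// Usage sketch with placeholder names:
//
//   at::Tensor counts;  // int32 CUDA tensor, one entry per rank, filled by the caller
//   auto grad_in = batch_norm_backward_elemt_channels_last_cuda_template(
//       grad_out, input, saved_mean, saved_invstd, weight,
//       sum_dy, sum_dy_xmu, counts);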
at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& mean,
    const at::Tensor& inv_std,
    const at::Tensor& weight,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_dy_xmu) {
  const auto stride = input.sizes()[1];
  const auto reduction_size = input.numel() / stride;
  auto norm_fct = 1.0 / reduction_size;

  // Input is guaranteed to be channels-last compatible
  at::Tensor grad_input = at::empty_like(input);

  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);

  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.data_ptr<accscalar_t>(),
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_input.data_ptr<scalar_t>(),
          static_cast<accscalar_t>(norm_fct),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
          <<<grid, block, 0, stream>>>(
          grad_output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          mean.data_ptr<accscalar_t>(),
          inv_std.data_ptr<accscalar_t>(),
          weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
          sum_dy.data_ptr<accscalar_t>(),
          sum_dy_xmu.data_ptr<accscalar_t>(),
          grad_input.data_ptr<scalar_t>(),
          static_cast<accscalar_t>(norm_fct),
          reduction_size,
          stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });

  return grad_input;
}

} } // namespace at::native