- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/Dispatch.h>
- #include <ATen/AccumulateType.h>
- #include <ATen/ceil_div.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <ATen/cuda/DeviceUtils.cuh>
- #include <ATen/native/cuda/block_reduce.cuh>
- #include <ATen/native/cuda/DeviceSqrt.cuh>
- #include <ATen/native/cuda/LaunchUtils.h>
- #include <c10/macros/Macros.h>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #else
- #include <ATen/ops/empty.h>
- #include <ATen/ops/empty_like.h>
- #include <ATen/ops/zeros.h>
- #endif
- namespace at { namespace native {
- // The maximum number of threads in a block
- #if defined(USE_ROCM)
- constexpr int MAX_BLOCK_SIZE = 256;
- #else
- constexpr int MAX_BLOCK_SIZE = 512;
- #endif
- constexpr unsigned MAX_GRID_SIZE = 65535u;
- // Returns the number of threads for a block, given an element count; capped at MAX_BLOCK_SIZE
- static int getNumThreads(int nElem) {
- #if defined(USE_ROCM)
- int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
- #else
- int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
- #endif
- for (int i = 0; i != 5; ++i) {
- if (nElem <= threadSizes[i]) {
- return threadSizes[i];
- }
- }
- return MAX_BLOCK_SIZE;
- }
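- // For illustration (assuming the non-ROCm table above): getNumThreads(100)
- // returns 128 and getNumThreads(600) returns MAX_BLOCK_SIZE (512).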
- // Returns the index of the most significant 1 bit in `val`.
- __device__ __forceinline__ int getMSB(int val) {
- return 31 - __clz(val);
- }
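- // e.g. getMSB(1) == 0 and getMSB(32) == 5, since __clz counts the leading
- // zero bits of a 32-bit integer.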
- template <typename scalar_t, typename accscalar_t>
- struct Float2 {
- accscalar_t v1, v2;
- __device__ Float2() {}
- __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
- __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
- __device__ Float2& operator+=(const Float2& a) {
- v1 += a.v1;
- v2 += a.v2;
- return *this;
- }
- __device__ friend Float2 operator+(Float2 a, const Float2& b) {
- a += b;
- return a;
- }
- };
- template <typename scalar_t, typename accscalar_t, typename PTA>
- struct GradOp {
- __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
- : mean(m), input(i), grad_output(g) {}
- __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
- accscalar_t g = grad_output[batch][plane][n];
- accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
- return Float2<scalar_t, accscalar_t>(g, g * c);
- }
- const accscalar_t mean;
- const PTA& input;
- const PTA& grad_output;
- };
- template <typename acc_t>
- struct SumReduceOp {
- __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
- __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
- return WARP_SHFL_DOWN(data, offset);
- }
- };
- template <typename scalar_t, typename accscalar_t>
- struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
- using acc_t = Float2<scalar_t, accscalar_t>;
- __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }
- __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
- return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
- }
- };
- // Sum across (batch, x/y/z) applying Op() pointwise
- // This works by first having each thread sum its part
- // of the data. Then there is a double-shuffling reduction.
- // First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
- // data to the "warp leader", who writes its value into shared memory.
- // Then a single warp reads the remaining (at most C10_WARP_SIZE) items
- // and reduces them using another warpSum.
- // The implicit assumption is that there are no more
- // than C10_WARP_SIZE**2 threads.
- template<typename scalar_t, typename Op, typename PTA>
- __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
- // first the reductions each thread does separately
- scalar_t sum = static_cast<scalar_t>(0);
- for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
- for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
- sum += op(batch, plane, x);
- }
- }
- __shared__ scalar_t shared[C10_WARP_SIZE];
- SumReduceOp<scalar_t> reduce_op;
- sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- shared[0] = sum;
- }
- __syncthreads();
- // Every thread picks the result up; it is effectively broadcast to the whole grad_input computation
- return shared[0];
- }
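- // Shape of the reduction above (an illustrative sketch, not part of the
- // implementation): with blockDim = (32, 16) and C10_WARP_SIZE == 32 there are
- // 512 threads in 16 warps. BlockReduce first shuffles within each warp,
- // leaving at most 16 partial sums in `shared`; a single warp then combines
- // those, and thread (0, 0) republishes the total through shared[0] so every
- // thread of the block reads the same value. Typical use (see
- // batch_norm_backward_kernel below):
- //   GradOp<scalar_t, accscalar_t, PTA> g(mean, input, grad_output);
- //   auto res = reduce<Float2<scalar_t, accscalar_t>>(g, grad_output, plane);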
- constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
- constexpr int ELEMENTS_PER_THREAD = 16;
- constexpr int OPTIMAL_TILE_W = 32;
- constexpr int MAX_H_BLOCK = 128;
- __host__ void flexible_launch_configs(
- const int reduction,
- const int stride,
- dim3 &block,
- dim3 &grid,
- const bool coop_flag = false) {
- int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
- int block_y = std::min(lastPow2(at::ceil_div(reduction , ELEMENTS_PER_THREAD)),
- MAX_BLOCK_SIZE / block_x);
- if (block_x * block_y != MAX_BLOCK_SIZE) {
- block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
- }
- int grid_x = at::ceil_div(stride, block_x);
- int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
- if (coop_flag) {
- // it's not worth having a grid reduction if the reduction dimension is not big enough
- grid_y = grid_y < 8 ? 1 : grid_y;
- }
- block.x = block_x;
- block.y = block_y;
- block.z = 1;
- grid.x = grid_x;
- grid.y = grid_y;
- grid.z = 1;
- }
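- // Worked example (assuming the non-ROCm MAX_BLOCK_SIZE of 512): for
- // reduction = 4096 and stride = 100,
- //   block_x = min(lastPow2(100), 32)                      = 32
- //   block_y = min(lastPow2(ceil_div(4096, 16)), 512 / 32) = 16
- //   grid_x  = ceil_div(100, 32)                           = 4
- //   grid_y  = min(ceil_div(4096, 16 * 16), 128)           = 16
- // i.e. block = (32, 16) and grid = (4, 16); with coop_flag set, grid_y stays
- // at 16 because it is >= 8.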
- template<typename T, typename C>
- __device__ __forceinline__ void welford_merge_element(C& count,
- T& mean,
- T& m2n,
- const C& count_new,
- const T& mean_new,
- const T& m2n_new) {
- T factor = T(1.0) / ::max(1, (count + count_new));
- T delta0 = mean - mean_new;
- mean = (mean_new * count_new + mean * count) * factor;
- m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
- count += count_new;
- }
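- // This is the pairwise merge step of the parallel variance algorithm
- // (Chan et al.): for partial results (n_a, mean_a, M2_a) and (n_b, mean_b, M2_b),
- //   mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b)
- //   M2   = M2_a + M2_b + (mean_a - mean_b)^2 * n_a * n_b / (n_a + n_b)
- // which is what the code above computes with factor = 1 / (n_a + n_b).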
- // merge mean/m2n among threadIdx.y within block
- template<typename T, typename C>
- __device__ __forceinline__ void welford_merge_block_vertical(C& count,
- T& mean,
- T& m2n,
- C* shmem_count,
- T* shmem_mean,
- T* shmem_m2n) {
- // write to shared memory
- auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
- #pragma unroll
- for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
- if (threadIdx.y < offset*2) {
- shmem_mean[address_base] = mean;
- shmem_m2n[address_base] = m2n;
- shmem_count[address_base] = count;
- }
- __syncthreads();
- if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
- auto address = address_base + offset * blockDim.x;
- // read shared memory back to register for reduction
- auto count_new = shmem_count[address];
- auto mean_new = shmem_mean[address];
- auto m2n_new = shmem_m2n[address];
- welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
- }
- }
- }
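- // Illustrative trace: with blockDim.y == 8 the loop runs offset = 4, 2, 1;
- // after each __syncthreads() the threads with threadIdx.y < offset fold the
- // (count, mean, m2n) of their partner at threadIdx.y + offset into their own,
- // so threadIdx.y == 0 ends up holding the combined statistics of its column.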
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
- __global__ void batch_norm_transform_input_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
- GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
- const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
- const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
- stat_accscalar_t epsilon) {
- index_t plane = blockIdx.x;
- if (plane >= input.size(1)) {
- return;
- }
- stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
- stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
- stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
- stat_accscalar_t invstd;
- if (train) {
- invstd = var_or_invstd[plane];
- } else {
- invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
- }
- index_t bs = input.size(0);
- index_t fs = input.size(2);
- index_t bstep = blockDim.y * gridDim.y;
- for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
- auto o = output[batch][plane];
- auto i = input[batch][plane];
- for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
- o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
- }
- }
- }
- struct InvStd {
- template <typename T>
- __device__ __forceinline__ T operator()(T var, double epsilon) const {
- T invstd = 0;
- if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
- invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
- }
- return invstd;
- }
- };
- struct Var {
- template <typename T>
- __device__ __forceinline__ T operator()(T var, double epsilon) const {
- return var;
- }
- };
- template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __global__ void batch_norm_collect_statistics_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
- const stat_accscalar_t epsilon,
- const stat_accscalar_t momentum,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
- __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];
- int plane = blockIdx.x;
- int N = input.size(0) * input.size(2);
- int tid = threadIdx.x + threadIdx.y * blockDim.x;
- // Compute the mean and variance across (batch, x/y/z)
- // This uses Welford's online algorithm (in the for loop) to accumulate each
- // thread's statistics, and the parallel variance algorithm to combine them across the block:
- // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
- // (the parallel algorithm is described on the same page).
- // We use two shuffles to reduce across the entire block.
- // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
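- // Numeric sketch of the per-thread loop below: for the values {1, 2, 3} the
- // state (n, avg, var_n) evolves (1, 1, 0) -> (2, 1.5, 0.5) -> (3, 2, 2), so
- // the biased variance is var_n / N = 2/3, matching the direct computation.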
- stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];
- // first the reductions each thread does separately
- stat_accscalar_t avg = 0;
- stat_accscalar_t var_n = 0;
- int n = 0;
- for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
- for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
- stat_accscalar_t v = input[batch][plane][x];
- stat_accscalar_t d1 = v - avg;
- n++;
- avg += d1 / n;
- var_n += d1 * (v - avg);
- }
- }
- // First warpSum: reduce from one value per thread
- // to one value per warp
- for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
- stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
- int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
- stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
- var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
- avg = (n * avg + o_n * o_avg) * factor;
- n += o_n;
- }
- // This writes each warp's item into shared memory
- // there are at most C10_WARP_SIZE items left because
- // there are at most C10_WARP_SIZE**2 threads at the beginning
- __syncthreads();
- if (tid % C10_WARP_SIZE == 0) {
- shared_n[tid / C10_WARP_SIZE] = n;
- shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
- shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
- }
- __syncthreads();
- // Now have a second warpSum to reduce the intermediate values
- // from shared memory to a single number, which the very first
- // thread then writes out to save_mean / save_transformed_var.
- if (tid < C10_WARP_SIZE) {
- n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
- avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
- var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
- }
- for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
- stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
- int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
- stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
- var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
- avg = (n * avg + o_n * o_avg) * factor;
- n += o_n;
- }
- // Save the mean and the transformed variance (invstd or variance)
- if (tid == 0) {
- if (save_mean.data() != NULL) {
- save_mean[plane] = avg;
- }
- if (save_transformed_var.data() != NULL) {
- save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
- }
- }
- }
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __global__ void batch_norm_backward_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
- GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
- GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
- GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
- bool train,
- stat_accscalar_t epsilon) {
- index_t plane = blockIdx.x;
- index_t N = grad_output.size(0) * grad_output.size(2);
- stat_accscalar_t mean, invstd;
- if (train) {
- mean = save_mean[plane];
- invstd = save_invstd[plane];
- } else {
- mean = static_cast<stat_accscalar_t>(running_mean[plane]);
- invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
- }
- stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
- stat_accscalar_t norm = stat_accscalar_t(1) / N;
- // Compute two values across (batch, x/y/z) in one pass:
- // 1. Sum(grad_output)
- // 2. DotProduct(input - mean, grad_output)
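- // In training mode the loop below therefore implements the standard
- // batch-norm backward formula
- //   dx = (gamma * invstd / N) * (N * dy - sum(dy) - (x - mean) * invstd^2 * sum(dy * (x - mean)))
- // via grad_mean = sum(dy) / N, proj_scale = dot_p * invstd^2 / N and
- // grad_scale = gamma * invstd computed a few lines further down.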
- GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
- auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
- stat_accscalar_t grad_output_sum = res.v1;
- stat_accscalar_t dot_p = res.v2;
- stat_accscalar_t grad_mean = grad_output_sum * norm;
- stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
- stat_accscalar_t grad_scale = invstd * weight_val;
- if (grad_input.data() != NULL) {
- for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
- for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
- input_scalar_t go = grad_output[batch][plane][x];
- if (train) {
- stat_accscalar_t inp = input[batch][plane][x];
- stat_accscalar_t proj = (inp - mean) * proj_scale;
- grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
- } else {
- grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
- }
- }
- }
- }
- if (grad_weight.size(0) > 0) {
- if (threadIdx.x == 0) {
- grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
- }
- }
- if (grad_bias.size(0) > 0) {
- if (threadIdx.x == 0) {
- grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
- }
- }
- }
- template <typename scalar_t, typename accscalar_t, typename index_t>
- __global__ void batch_norm_reduce_statistics_kernel(
- const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
- const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
- GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
- GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
- GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
- GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
- const accscalar_t epsilon,
- const accscalar_t momentum,
- const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {
- int feature_size = vec_mean.size(1);
- int world_size = vec_mean.size(0);
- int bid = blockIdx.x;
- int tid = threadIdx.x;
- // each thread merges the per-replica statistics for its feature columns
- for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
- accscalar_t avg = 0;
- accscalar_t var_n = 0;
- index_t n = 0;
- for (int j = 0; j < world_size; j++) {
- scalar_t count = counts[j];
- accscalar_t m = vec_mean[j][i];
- accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
- v = (v * v - epsilon) * count;
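- // 1 / invstd^2 - epsilon undoes the InvStd transform and recovers this
- // replica's biased variance; scaling by count turns it back into m2n so it
- // can be merged with the running (avg, var_n, n) via the usual Welford merge.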
- accscalar_t factor = 1.0 / (n + count);
- var_n += v + (avg - m) * (avg - m) * n * count * factor;
- avg = n * factor * avg + count * factor * m;
- n += count;
- }
- mean[i] = avg;
- invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
- if (running_mean.data() != NULL) {
- running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
- }
- accscalar_t unbiasedVar = var_n / (n - 1);
- if (running_var.data() != NULL) {
- running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
- }
- }
- }
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __global__ void batch_norm_backward_reduce_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
- GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
- GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
- GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {
- index_t plane = blockIdx.x;
- stat_accscalar_t r_mean = mean[plane];
- stat_accscalar_t factor = invstd[plane];
- GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
- auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);
- if (threadIdx.x == 0) {
- if (grad_weight.size(0) > 0) {
- grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
- }
- if (grad_bias.size(0) > 0) {
- grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
- }
- if (sum_dy.size(0) > 0) {
- sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
- }
- if (sum_dy_xmu.size(0) > 0) {
- sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
- }
- }
- }
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
- GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
- const stat_accscalar_t norm_fct) {
- index_t plane = blockIdx.x;
- if (plane >= input.size(1)) {
- return;
- }
- stat_accscalar_t m_c = mean[plane];
- stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct;
- stat_accscalar_t factor_1_c = invstd[plane];
- stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
- factor_2_c *= factor_1_c;
- factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct;
- index_t bs = input.size(0);
- index_t fs = input.size(2);
- index_t bstep = blockDim.y * gridDim.y;
- for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
- auto g_i = grad_input[batch][plane];
- auto g_o = grad_output[batch][plane];
- auto i = input[batch][plane];
- for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
- g_i[feature] = static_cast<input_scalar_t>((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c);
- }
- }
- }
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __global__ void batch_norm_backward_elemt_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
- GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
- const int* __restrict__ numel, const int world_size) {
- int64_t total_numel = 0;
- for (int i = 0; i < world_size; i ++) {
- total_numel += numel[i];
- }
- const stat_accscalar_t norm_fct =
- static_cast<stat_accscalar_t>(1) / static_cast<stat_accscalar_t>(total_numel);
- batch_norm_backward_elemt_kernel_impl(
- input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
- }
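- // This overload normalizes by the total element count gathered from all
- // replicas (the SyncBatchNorm path); the overload below takes a precomputed
- // norm_fct instead (see the templates further down).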
- template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
- __global__ void batch_norm_backward_elemt_kernel(
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
- const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
- const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
- const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
- GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
- const stat_accscalar_t norm_fct) {
- batch_norm_backward_elemt_kernel_impl(
- input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
- }
- template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
- static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
- const Tensor& t, c10::string_view var_name) {
- constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value;
- const auto actual_type = t.scalar_type();
- TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
- " to have type ", expect_type, " but got ", actual_type);
- return t.generic_packed_accessor<scalar_t, dim, PtrTraits, index_t>();
- }
- template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
- static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
- const Tensor& t, c10::string_view var_name) {
- if (!t.defined()) {
- const std::array<index_t, dim> zeros{{0}};
- return GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
- }
- return get_packed_accessor<scalar_t, dim, PtrTraits, index_t>(t, var_name);
- }
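- // When `t` is undefined this returns an accessor backed by nullptr with all
- // sizes and strides zero, so kernels can safely test for optional tensors via
- // patterns like `weight.size(0) > 0` or `grad_input.data() != NULL`.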
- template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
- std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
- const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
- bool train, double epsilon, std::array<bool,3> grad_input_mask) {
- using accscalar_t = at::acc_type<stat_scalar_t, true>;
- Tensor grad_input_;
- Tensor grad_input_reshaped;
- Tensor grad_weight_;
- Tensor grad_bias_;
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
- auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
- if (grad_input_mask[0]) {
- grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
- }
- if (grad_input_mask[1]) {
- grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- }
- if (grad_input_mask[2]) {
- grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- }
- auto input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
- auto grad_output = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
- auto grad_input = packed_accessor_or_dummy<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
- auto weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
- auto grad_weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
- auto grad_bias = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
- auto running_mean = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean");
- auto running_var = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var");
- auto save_mean = packed_accessor_or_dummy<
- accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean");
- auto save_invstd = packed_accessor_or_dummy<
- accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd");
- auto stream = at::cuda::getCurrentCUDAStream();
- dim3 blocks(input.size(1));
- int tf = getNumThreads(input.size(2));
- dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
- batch_norm_backward_kernel<input_scalar_t, stat_scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
- (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
- save_mean, save_invstd, train, epsilon);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
- }
- template<typename scalar_t, typename index_t, typename VarTransform>
- void batch_norm_stats_cuda_template(
- const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) {
- using accscalar_t = at::acc_type<scalar_t, true>;
- int64_t n_input = input_.size(1);
- Tensor dummy_mean_;
- Tensor dummy_var_;
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
- resize_output(out_mean, {n_input});
- resize_output(out_invstd, {n_input});
- auto input = get_packed_accessor<
- scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
- TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
- out_invstd.sizes()[0]);
- TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
- out_mean.sizes()[0]);
- auto mean = packed_accessor_or_dummy<
- accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean");
- auto invstd = packed_accessor_or_dummy<
- accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd");
- auto stream = at::cuda::getCurrentCUDAStream();
- dim3 blocks(input.size(1));
- int tf = getNumThreads(input.size(2));
- dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
- batch_norm_collect_statistics_kernel<VarTransform, scalar_t, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
- (input, epsilon, 0.0, mean, invstd);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
- void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const Tensor& weight_,
- const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) {
- using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
- int64_t n_input = input_.size(1);
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
- auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
- auto input = get_packed_accessor<
- input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
- auto output = get_packed_accessor<
- input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
- auto weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
- auto bias = packed_accessor_or_dummy<
- stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
- auto mean = packed_accessor_or_dummy<
- stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
- auto invstd = packed_accessor_or_dummy<
- stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd");
- auto stream = at::cuda::getCurrentCUDAStream();
- // NOTE: We use transform_input_kernel in training mode, which ignores epsilon
- const double dummy_epsilon = 1e-5;
- // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
- // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
- // and good occupancy. Quite likely, we could go with even more blocks than 1024.
- // The various planes are independent, so we use blocks for them.
- int tf = std::max<int>(getNumThreads(input.size(2)/4),
- std::min<int>(getNumThreads(input.size(2)), 64));
- int tb = std::max<int>(64/tf, 1);
- dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
- (input.size(0)+tb-1)/tb)));
- blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
- dim3 threads_trans(tf, tb);
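- // Launch-shape sketch (assuming the non-ROCm thread table): for a reshaped
- // input of size (32, 64, 2048), tf = max(getNumThreads(512), 64) = 512 and
- // tb = 1, so blocks_trans = (64, 32) and threads_trans = (512, 1): one column
- // of blocks per plane, each block covering a slice of that plane's batches.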
- batch_norm_transform_input_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
- (input, output, mean, invstd, weight, bias, dummy_epsilon);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- template<typename scalar_t, typename accscalar_t, typename index_t>
- std::tuple<Tensor, Tensor> batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_,
- const Tensor& running_mean_, const Tensor& running_var_,
- double momentum, double epsilon, const Tensor& counts_) {
- Tensor save_mean_;
- Tensor save_invstd_;
- auto features = mean_.size(1);
- auto input_options = mean_.options();
- if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) {
- input_options = input_options.dtype(ScalarType::Float);
- }
- save_mean_ = at::empty({features}, input_options);
- save_invstd_ = at::empty({features}, input_options);
- auto mean = packed_accessor_or_dummy<
- accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean");
- auto invstd = packed_accessor_or_dummy<
- accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd");
- auto running_mean = packed_accessor_or_dummy<
- scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean");
- auto running_var = packed_accessor_or_dummy<
- scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_var");
- auto counts = packed_accessor_or_dummy<
- scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts");
- auto save_mean = get_packed_accessor<
- accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean");
- auto save_invstd = get_packed_accessor<
- accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd");
- auto stream = at::cuda::getCurrentCUDAStream();
- int block = getNumThreads(features);
- int grid = std::max<int>(1, features/block);
- batch_norm_reduce_statistics_kernel<scalar_t, accscalar_t, index_t> <<<grid, block, 0, stream>>>
- (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, momentum, counts);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- return std::make_tuple(save_mean_, save_invstd_);
- }
- template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
- std::tuple<Tensor, Tensor, Tensor, Tensor> batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_,
- const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_,
- const bool input_g, const bool weight_g, const bool bias_g) {
- using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
- int64_t n_input = input_.size(1);
- Tensor sum_dy_;
- Tensor sum_dy_xmu_;
- Tensor grad_weight_;
- Tensor grad_bias_;
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
- auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
- if (input_g) {
- sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- }
- if (weight_g) {
- grad_weight_ = at::empty({n_input}, weight_.options());
- }
- if (bias_g) {
- grad_bias_ = at::empty({n_input}, weight_.options());
- }
- auto input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
- auto grad_output = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
- auto grad_weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight");
- auto grad_bias = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias");
- auto mean = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
- auto invstd = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
- auto sum_dy = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
- auto sum_dy_xmu = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
- auto batch_size = input_reshaped.size(0);
- auto feature_size = input_reshaped.size(2);
- auto stream = at::cuda::getCurrentCUDAStream();
- int warp_size = at::cuda::warp_size();
- int block_y = std::min<int>(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size);
- // We want block_x to be at least a warp width
- int block_x = std::min<int>(std::max<int>(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y);
- const dim3 block(block_x, block_y);
- const dim3 grid(n_input);
- batch_norm_backward_reduce_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<grid, block, 0, stream>>>
- (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_);
- }
- template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
- Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
- const Tensor& mean_, const Tensor& invstd_,
- const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) {
- using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
- int64_t n_input = input_.size(1);
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
- auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
- auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- auto input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
- auto grad_input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
- auto grad_output = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
- auto mean = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
- auto invstd = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
- auto weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
- auto sum_dy = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
- auto sum_dy_xmu = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
- auto stream = at::cuda::getCurrentCUDAStream();
- // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
- // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
- // and good occupancy. Quite likely, we could go with even more blocks than 1024.
- // The various planes are independent, so we use blocks for them.
- int tf = std::max<int>(getNumThreads(input.size(2)/4),
- std::min<int>(getNumThreads(input.size(2)), 64));
- int tb = std::max<int>(64/tf, 1);
- dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
- (input.size(0)+tb-1)/tb)));
- blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
- dim3 threads_trans(tf, tb);
- auto reduction_size = input_.numel() / n_input;
- auto norm_fct = static_cast<stat_accscalar_t>(1.0 / reduction_size);
- batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t>
- <<<blocks_trans, threads_trans, 0, stream>>>
- (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- return grad_input_reshaped.view(input_.sizes());
- }
- template<typename input_scalar_t, typename stat_scalar_t, typename index_t>
- Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_,
- const Tensor& mean_, const Tensor& invstd_,
- const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) {
- using stat_accscalar_t = at::acc_type<stat_scalar_t, true>;
- int64_t n_input = input_.size(1);
- auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
- auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
- auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- auto input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input");
- auto grad_input = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input");
- auto grad_output = get_packed_accessor<
- input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output");
- auto mean = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean");
- auto invstd = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd");
- auto weight = packed_accessor_or_dummy<
- stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight");
- auto sum_dy = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy");
- auto sum_dy_xmu = packed_accessor_or_dummy<
- stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu");
- auto stream = at::cuda::getCurrentCUDAStream();
- // The kernel is pointwise, but we need to balance reading parameters (save_var/mean,
- // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
- // and good occupancy. Quite likely, we could go with even more blocks than 1024.
- // The various planes are independent, so we use blocks for them.
- int tf = std::max<int>(getNumThreads(input.size(2)/4),
- std::min<int>(getNumThreads(input.size(2)), 64));
- int tb = std::max<int>(64/tf, 1);
- dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
- (input.size(0)+tb-1)/tb)));
- blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE);
- dim3 threads_trans(tf, tb);
- batch_norm_backward_elemt_kernel<input_scalar_t, stat_scalar_t, stat_accscalar_t, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
- (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.data_ptr<int>(), count.numel());
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- return grad_input_reshaped.view(input_.sizes());
- }
- // Welford kernel for channels-last tensors, calculating mean / biased variance / unbiased variance
- // original apex name: welford_kernel_c_last
- template
- <typename VarTransform,
- typename scalar_t,
- typename accscalar_t,
- int PARALLEL_LOADS>
- __global__ void
- batch_norm_collect_statistics_channels_last_kernel(
- const scalar_t* __restrict__ input,
- accscalar_t* __restrict__ out_mean,
- accscalar_t* __restrict__ out_invstd,
- volatile accscalar_t* staging_data,
- int* semaphores,
- const int reduction_size,
- const int stride,
- accscalar_t epsilon) {
- // hide latency with concurrency
- accscalar_t x_mean[PARALLEL_LOADS];
- accscalar_t m_2_n[PARALLEL_LOADS];
- int count[PARALLEL_LOADS];
- #pragma unroll
- for (int i = 0; i < PARALLEL_LOADS; i++) {
- x_mean[i] = accscalar_t(0);
- m_2_n[i] = accscalar_t(0);
- count[i] = 0;
- }
- // tensor dimension (m,c)
- // loop along m dimension
- int inner_loop_stride = blockDim.y * gridDim.y;
- // offset along m dimension
- int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
- int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
- int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
- int address_base = m_offset * stride + c_offset;
- int address_increment = inner_loop_stride * stride;
- for (int i = 0; i < loop_count; i++) {
- accscalar_t x_math[PARALLEL_LOADS];
- accscalar_t x_count_inv[PARALLEL_LOADS];
- accscalar_t is_valid[PARALLEL_LOADS];
- // load multiple data in
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- if (c_offset < stride && m_offset < reduction_size) {
- x_math[j] = input[address_base];
- count[j]++;
- x_count_inv[j] = accscalar_t(1) / count[j];
- is_valid[j] = accscalar_t(1);
- } else {
- x_math[j] = accscalar_t(0);
- x_count_inv[j] = accscalar_t(0);
- is_valid[j] = accscalar_t(0);
- }
- m_offset += inner_loop_stride;
- address_base += address_increment;
- }
- // calculate mean/m2n with welford
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- accscalar_t delta0 = x_math[j] - x_mean[j];
- x_mean[j] += delta0 * x_count_inv[j];
- accscalar_t delta1 = x_math[j] - x_mean[j];
- m_2_n[j] += delta0 * delta1 * is_valid[j];
- }
- }
- // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
- #pragma unroll
- for (int j = 1; j < PARALLEL_LOADS; j++) {
- welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
- }
- // release x_mean / m_2_n
- auto mean_th = x_mean[0];
- auto m2_th = m_2_n[0];
- auto count_th = count[0];
- // block-wise reduction with shared memory (since reduction cannot be done within a warp)
- static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
- static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
- static __shared__ int shmem_count[MAX_BLOCK_SIZE];
- welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
- if (gridDim.y > 1) {
- volatile accscalar_t* staging_mean = staging_data;
- volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
- volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
- address_base = c_offset + blockIdx.y * stride;
- // write data to staging_data;
- if (threadIdx.y == 0 && c_offset < stride) {
- staging_mean[address_base] = mean_th;
- staging_m2n[address_base] = m2_th;
- staging_count[address_base] = count_th;
- }
- __threadfence();
- __syncthreads(); // ensure the writes to the staging_ buffers are visible to all blocks
- __shared__ bool is_last_block_done;
- // mark block done
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- int old = atomicAdd(&semaphores[blockIdx.x], 1);
- is_last_block_done = (old == (gridDim.y-1));
- }
- __syncthreads();
- // check that all data is now available in global memory
- if (is_last_block_done) {
- count_th = 0;
- mean_th = accscalar_t(0.0);
- m2_th = accscalar_t(0.0);
- for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
- address_base = c_offset + y * stride;
- int count_new = c_offset < stride ? staging_count[address_base] : 0;
- accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
- accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
- welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new);
- }
- welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
- if (threadIdx.y == 0 && c_offset < stride) {
- out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
- out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
- }
- }
- } else {
- if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
- out_mean[c_offset] = static_cast<accscalar_t>(mean_th);
- out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon);
- }
- }
- }
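- // Grid-level protocol used above (and mirrored in the backward-reduce kernel
- // below): each block writes its per-channel partial (mean, m2n, count) into
- // staging_data and atomically bumps semaphores[blockIdx.x]; the block that
- // observes the last increment re-reads all gridDim.y partials, merges them
- // with welford_merge_element / welford_merge_block_vertical, and writes the
- // final per-channel mean and invstd.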
- // elementwise BN forward kernel for channels-last tensors
- // original apex name: batchnorm_forward_c_last_kernel
- template <
- typename scalar_t,
- typename accscalar_t,
- typename layerscalar_t,
- int PARALLEL_LOADS>
- __global__ void batch_norm_transform_input_channels_last_kernel(
- const scalar_t* __restrict__ input,
- const scalar_t* __restrict__ z,
- const accscalar_t* __restrict__ mean,
- const accscalar_t* __restrict__ inv_std,
- const layerscalar_t* __restrict__ weight,
- const layerscalar_t* __restrict__ shift,
- scalar_t* __restrict__ out,
- const int reduction_size,
- const int stride,
- const bool fuse_relu) {
- // tensor dimension (m,c)
- // loop along m dimension
- int inner_loop_stride = blockDim.y * gridDim.y;
- // offset along m dimension
- int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
- int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
- if (c_offset >= stride || m_offset >= reduction_size) {
- return;
- }
- auto m_c = mean[c_offset];
- auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
- auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
- auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
- int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
- int address_base = m_offset * stride + c_offset;
- int address_increment = inner_loop_stride * stride;
- for (int i = 0; i < loop_count; i++) {
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- if (c_offset < stride && m_offset < reduction_size) {
- auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
- if (z != nullptr) {
- tmp += z[address_base];
- }
- out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
- }
- m_offset += inner_loop_stride;
- address_base += address_increment;
- }
- }
- }
- template<typename T>
- __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy,
- T& sum_dy_xmu,
- T* shmem_sum_dy,
- T* shmem_sum_dy_xmu) {
- // write to shared memory
- auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
- #pragma unroll
- for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
- if (threadIdx.y < offset*2) {
- shmem_sum_dy[address_base] = sum_dy;
- shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
- }
- __syncthreads();
- if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
- auto address = address_base + offset * blockDim.x;
- sum_dy += shmem_sum_dy[address];
- sum_dy_xmu += shmem_sum_dy_xmu[address];
- }
- }
- }
- // batchnorm backward reduce kernel for channels-last tensors
- // original apex name: reduce_bn_c_last_kernel
- template <
- int PARALLEL_LOADS,
- typename scalar_t,
- typename accscalar_t,
- typename layerscalar_t>
- __global__ void batch_norm_backward_reduce_channels_last_kernel(
- const scalar_t* __restrict__ input,
- const scalar_t* __restrict__ grad_output,
- const accscalar_t* __restrict__ mean,
- const accscalar_t* __restrict__ inv_std,
- accscalar_t* __restrict__ sum_dy_o,
- accscalar_t* __restrict__ sum_dy_xmu_o,
- layerscalar_t* __restrict__ grad_weight,
- layerscalar_t* __restrict__ grad_bias,
- volatile accscalar_t* staging_data,
- int* semaphores,
- const int reduction_size,
- const int stride) {
- // hide latency with concurrency
- accscalar_t sum_dy[PARALLEL_LOADS];
- accscalar_t sum_dy_xmu[PARALLEL_LOADS];
- #pragma unroll
- for (int i = 0; i < PARALLEL_LOADS; i++) {
- sum_dy[i] = accscalar_t(0);
- sum_dy_xmu[i] = accscalar_t(0);
- }
- // tensor dimension (m,c)
- // loop along m dimension
- int inner_loop_stride = blockDim.y * gridDim.y;
- // offset along m dimension
- int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
- int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
- if (c_offset >= stride || m_offset >= reduction_size) {
- return;
- }
- int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
- int address_base = m_offset * stride + c_offset;
- int address_increment = inner_loop_stride * stride;
- auto r_mean = mean[c_offset];
- auto factor = inv_std[c_offset];
- for (int i = 0; i < loop_count; i++) {
- accscalar_t x_input[PARALLEL_LOADS];
- accscalar_t x_grad_output[PARALLEL_LOADS];
- // load multiple data in
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- if (c_offset < stride && m_offset < reduction_size) {
- x_input[j] = input[address_base];
- x_grad_output[j] = grad_output[address_base];
- } else {
- x_input[j] = accscalar_t(0);
- x_grad_output[j] = accscalar_t(0);
- }
- m_offset += inner_loop_stride;
- address_base += address_increment;
- }
- // calculate sum_dy / sum_dy_xmu
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- sum_dy[j] += x_grad_output[j];
- sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
- }
- }
- // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
- #pragma unroll
- for (int j = 1; j < PARALLEL_LOADS; j++) {
- sum_dy[0] += sum_dy[j];
- sum_dy_xmu[0] += sum_dy_xmu[j];
- }
- // fold into scalars so the PARALLEL_LOADS register arrays can be retired
- auto sum_dy_th = sum_dy[0];
- auto sum_dy_xmu_th = sum_dy_xmu[0];
- // block-wise reduction with shared memory (the reduction runs across threadIdx.y, so it cannot use warp-level primitives)
- static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
- static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
- merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
- if (gridDim.y > 1) {
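- // grid-level reduction: each y-block writes its per-channel partials to staging_data;
- // the semaphore for this blockIdx.x then elects the last arriving y-block to fold all
- // partials and write the final results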
- volatile accscalar_t* staging_sum_dy = staging_data;
- volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
- address_base = c_offset + blockIdx.y * stride;
- // write data to staging_data;
- if (threadIdx.y == 0 && c_offset < stride) {
- staging_sum_dy[address_base] = sum_dy_th;
- staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
- }
- __threadfence(); // make this block's writes to staging_data visible to all other blocks
- __syncthreads();
- __shared__ bool is_last_block_done;
- // mark block done
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- int old = atomicAdd(&semaphores[blockIdx.x], 1);
- is_last_block_done = (old == (gridDim.y-1));
- }
- __syncthreads();
- // check that all data is now available in global memory
- if (is_last_block_done) {
- sum_dy_th = accscalar_t(0.0);
- sum_dy_xmu_th = accscalar_t(0.0);
- for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
- address_base = c_offset + y * stride;
- sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
- sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
- }
- merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
- if (threadIdx.y == 0 && c_offset < stride) {
- if (grad_bias != nullptr) {
- grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
- }
- if (grad_weight != nullptr) {
- grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
- }
- //mean_dy[c_offset] = sum_dy_th / reduction_size;
- //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
- sum_dy_o[c_offset] = sum_dy_th;
- sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
- }
- }
- } else {
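- // gridDim.y == 1: this block already holds the full per-channel sums, so write them out directly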
- if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
- if (grad_bias != nullptr) {
- grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
- }
- if (grad_weight != nullptr) {
- grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
- }
- //mean_dy[c_offset] = sum_dy_th / reduction_size;
- //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
- sum_dy_o[c_offset] = sum_dy_th;
- sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
- }
- }
- }
- // elementwise BN backward kernel
- // original apex name: batchnorm_backward_c_last_kernel
- template <
- int PARALLEL_LOADS,
- typename scalar_t,
- typename accscalar_t,
- typename layerscalar_t>
- __device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl(
- const scalar_t* __restrict__ grad_output,
- const scalar_t* __restrict__ input,
- const accscalar_t* __restrict__ mean,
- const accscalar_t* __restrict__ inv_std,
- const layerscalar_t* __restrict__ weight,
- const accscalar_t* __restrict__ sum_dy,
- const accscalar_t* __restrict__ sum_dy_xmu,
- scalar_t* __restrict__ grad_input,
- const accscalar_t norm_fct,
- const int reduction_size,
- const int stride) {
- // tensor dimension (m,c)
- // loop along m dimension
- int inner_loop_stride = blockDim.y * gridDim.y;
- // offset along m dimension
- int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
- int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
- if (c_offset >= stride || m_offset >= reduction_size) {
- return;
- }
- auto m_c = mean[c_offset];
- auto m_dy_c = sum_dy[c_offset] * norm_fct;
- auto factor_1_c = inv_std[c_offset];
- auto factor_2_c = (weight == nullptr? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
- factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct;
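- // grad_input = weight * inv_std * (dy - sum_dy * norm_fct - (x - mean) * inv_std^2 * sum_dy_xmu * norm_fct),
- // with factor_2_c = weight * inv_std and factor_1_c = inv_std^2 * sum_dy_xmu * norm_fct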
- int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
- int address_base = m_offset * stride + c_offset;
- int address_increment = inner_loop_stride * stride;
- for (int i = 0; i < loop_count; i++) {
- #pragma unroll
- for (int j = 0; j < PARALLEL_LOADS; j++) {
- if (c_offset < stride && m_offset < reduction_size) {
- grad_input[address_base] = static_cast<scalar_t>(
- (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
- (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
- * factor_2_c);
- }
- m_offset += inner_loop_stride;
- address_base += address_increment;
- }
- }
- }
- template <
- int PARALLEL_LOADS,
- typename scalar_t,
- typename accscalar_t,
- typename layerscalar_t>
- __global__ void batch_norm_backward_elemt_channels_last_kernel(
- const scalar_t* __restrict__ grad_output,
- const scalar_t* __restrict__ input,
- const accscalar_t* __restrict__ mean,
- const accscalar_t* __restrict__ inv_std,
- const layerscalar_t* __restrict__ weight,
- const accscalar_t* __restrict__ sum_dy,
- const accscalar_t* __restrict__ sum_dy_xmu,
- const int* __restrict__ numel,
- scalar_t* __restrict__ grad_input,
- const int64_t world_size,
- const int reduction_size,
- const int stride) {
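- // this overload receives per-rank element counts (numel, of length world_size) and derives
- // norm_fct from the global count, e.g. for synchronized BN across processes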
- int64_t total_numel = 0;
- for (int i = 0; i < world_size; i++) {
- total_numel += numel[i];
- }
- auto norm_fct = static_cast<accscalar_t>(1) / static_cast<accscalar_t>(total_numel);
- batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
- grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
- grad_input, norm_fct, reduction_size, stride);
- }
- template <
- int PARALLEL_LOADS,
- typename scalar_t,
- typename accscalar_t,
- typename layerscalar_t>
- __global__ void batch_norm_backward_elemt_channels_last_kernel(
- const scalar_t* __restrict__ grad_output,
- const scalar_t* __restrict__ input,
- const accscalar_t* __restrict__ mean,
- const accscalar_t* __restrict__ inv_std,
- const layerscalar_t* __restrict__ weight,
- const accscalar_t* __restrict__ sum_dy,
- const accscalar_t* __restrict__ sum_dy_xmu,
- scalar_t* __restrict__ grad_input,
- const accscalar_t norm_fct,
- const int reduction_size,
- const int stride) {
- batch_norm_backward_elemt_channels_last_kernel_impl<PARALLEL_LOADS>(
- grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu,
- grad_input, norm_fct, reduction_size, stride);
- }
- template<typename scalar_t, typename VarTransform>
- void batch_norm_stats_channels_last_cuda_template(
- const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) {
- using accscalar_t = at::acc_type<scalar_t, true>;
- const auto stride = input.sizes()[1];
- const auto reduction_size = input.numel() / stride;
- resize_output(out_mean, {stride});
- resize_output(out_invstd, {stride});
- TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
- out_invstd.sizes()[0]);
- TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
- out_mean.sizes()[0]);
- dim3 block;
- dim3 grid;
- flexible_launch_configs(reduction_size, stride, block, grid, true);
- at::Tensor staging_data;
- at::Tensor semaphores;
- if (grid.y > 1) {
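- // staging_data holds four per-channel partial arrays of length stride * grid.y for the
- // inter-block reduction; semaphores holds one counter per blockIdx.x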
- staging_data = at::empty({4*stride*grid.y}, out_mean.options());
- semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
- }
- accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
- int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
- batch_norm_collect_statistics_channels_last_kernel<VarTransform, scalar_t, accscalar_t, ELEMENTS_PER_ITER>
- <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
- input.data_ptr<scalar_t>(),
- out_mean.data_ptr<accscalar_t>(),
- out_invstd.data_ptr<accscalar_t>(),
- staging_data_ptr,
- semaphores_ptr,
- reduction_size,
- stride,
- epsilon);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- void batch_norm_elemt_channels_last_cuda_template(
- const at::Tensor& output,
- const at::Tensor& input,
- const at::Tensor& weight,
- const at::Tensor& shift, // bias of BN
- const at::Tensor& mean,
- const at::Tensor& inv_std,
- const at::optional<at::Tensor>& z = c10::nullopt, // optional residual added after BN, before the fused ReLU
- const bool fuse_relu = false) {
- const auto stride = input.sizes()[1];
- const auto reduction_size = input.numel() / stride;
- dim3 block;
- dim3 grid;
- flexible_launch_configs(reduction_size, stride, block, grid);
- auto stream = at::cuda::getCurrentCUDAStream();
- const auto second_dtype = weight.defined() ? weight.scalar_type() :
- (shift.defined() ? shift.scalar_type() : input.scalar_type());
- if (input.scalar_type() != second_dtype) {
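- // mixed-precision path: weight/shift have a different dtype than the input and are read as accscalar_t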
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- input.data_ptr<scalar_t>(),
- z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.defined() ? weight.data_ptr<accscalar_t>() : nullptr,
- shift.defined() ? shift.data_ptr<accscalar_t>() : nullptr,
- output.data_ptr<scalar_t>(),
- reduction_size,
- stride,
- fuse_relu);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- } else {
- if (weight.defined()){
- TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(),
- " is not supported with weight.scalar_type() ", weight.scalar_type());
- }
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- batch_norm_transform_input_channels_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- input.data_ptr<scalar_t>(),
- z.has_value() ? z.value().data_ptr<scalar_t>() : nullptr,
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
- shift.defined() ? shift.data_ptr<scalar_t>(): nullptr,
- output.data_ptr<scalar_t>(),
- reduction_size,
- stride,
- fuse_relu);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- }
- }
- std::tuple<Tensor, Tensor, Tensor, Tensor>
- batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output,
- const at::Tensor& input,
- const at::Tensor& mean,
- const at::Tensor& inv_std,
- const at::Tensor& weight,
- const bool input_g, const bool weight_g, const bool bias_g) {
- const auto stride = input.sizes()[1];
- const auto reduction_size = input.numel() / stride;
- at::Tensor sum_dy = at::empty({stride}, mean.options());
- at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
- at::Tensor grad_weight;
- at::Tensor grad_bias;
- if (weight.defined()) {
- grad_weight = at::empty({stride}, weight.options());
- grad_bias = at::empty({stride}, weight.options());
- } else {
- // an undefined at::Tensor cannot be returned, so use empty placeholders
- grad_weight = at::empty({0}, mean.options());
- grad_bias = at::empty({0}, mean.options());
- }
- dim3 block;
- dim3 grid;
- flexible_launch_configs(reduction_size, stride, block, grid, true);
- at::Tensor staging_data;
- at::Tensor semaphores;
- if (grid.y > 1) {
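- // two per-channel staging buffers (sum_dy and sum_dy_xmu), one slice of length stride per y-block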
- staging_data = at::empty({2*stride*grid.y}, mean.options());
- semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
- }
- auto stream = at::cuda::getCurrentCUDAStream();
- if (weight.defined() && input.scalar_type() != weight.scalar_type()) {
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
- int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
- batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- input.data_ptr<scalar_t>(),
- grad_output.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- grad_weight.data_ptr<accscalar_t>(),
- grad_bias.data_ptr<accscalar_t>(),
- staging_data_ptr,
- semaphores_ptr,
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- } else {
- if (weight.defined()) {
- TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(),
- " is not supported with weight.scalar_type() ", weight.scalar_type());
- }
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr<accscalar_t>() : nullptr;
- int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr<int>() : nullptr;
- batch_norm_backward_reduce_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- input.data_ptr<scalar_t>(),
- grad_output.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- weight.defined() ? grad_weight.data_ptr<scalar_t>() : nullptr,
- weight.defined() ? grad_bias.data_ptr<scalar_t>() : nullptr,
- staging_data_ptr,
- semaphores_ptr,
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- }
- return std::make_tuple(sum_dy, sum_dy_xmu, grad_weight, grad_bias);
- }
- at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
- const at::Tensor& grad_output,
- const at::Tensor& input,
- const at::Tensor& mean,
- const at::Tensor& inv_std,
- const at::Tensor& weight,
- const at::Tensor& sum_dy,
- const at::Tensor& sum_dy_xmu,
- const at::Tensor& count) {
- const auto stride = input.sizes()[1];
- const auto reduction_size = input.numel() / stride;
- // Input is guaranteed to be channels-last compatible
- at::Tensor grad_input = at::empty_like(input);
- dim3 block;
- dim3 grid;
- flexible_launch_configs(reduction_size, stride, block, grid);
- auto stream = at::cuda::getCurrentCUDAStream();
- if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- grad_output.data_ptr<scalar_t>(),
- input.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.data_ptr<accscalar_t>(),
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- count.data_ptr<int>(),
- grad_input.data_ptr<scalar_t>(),
- count.numel(),
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- } else {
- if (weight.defined()) {
- TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(),
- " is not supported with weight.scalar_type() ", weight.scalar_type());
- }
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- grad_output.data_ptr<scalar_t>(),
- input.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- count.data_ptr<int>(),
- grad_input.data_ptr<scalar_t>(),
- count.numel(),
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- });
- }
- return grad_input;
- }
- at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
- const at::Tensor& grad_output,
- const at::Tensor& input,
- const at::Tensor& mean,
- const at::Tensor& inv_std,
- const at::Tensor& weight,
- const at::Tensor& sum_dy,
- const at::Tensor& sum_dy_xmu) {
- const auto stride = input.sizes()[1];
- const auto reduction_size = input.numel() / stride;
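- // single-tensor variant: normalize by this tensor's own per-channel element count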
- auto norm_fct = 1.0 / reduction_size;
- // Input is guaranteed to be channels-last compatible
- at::Tensor grad_input = at::empty_like(input);
- dim3 block;
- dim3 grid;
- flexible_launch_configs(reduction_size, stride, block, grid);
- auto stream = at::cuda::getCurrentCUDAStream();
- AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] {
- using accscalar_t = at::acc_type<scalar_t, true>;
- if (weight.defined() && weight.scalar_type() != input.scalar_type()) {
- batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- grad_output.data_ptr<scalar_t>(),
- input.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.data_ptr<accscalar_t>(),
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- grad_input.data_ptr<scalar_t>(),
- static_cast<accscalar_t>(norm_fct),
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- } else {
- batch_norm_backward_elemt_channels_last_kernel<ELEMENTS_PER_ITER>
- <<<grid, block, 0, stream>>>(
- grad_output.data_ptr<scalar_t>(),
- input.data_ptr<scalar_t>(),
- mean.data_ptr<accscalar_t>(),
- inv_std.data_ptr<accscalar_t>(),
- weight.defined() ? weight.data_ptr<scalar_t>() : nullptr,
- sum_dy.data_ptr<accscalar_t>(),
- sum_dy_xmu.data_ptr<accscalar_t>(),
- grad_input.data_ptr<scalar_t>(),
- static_cast<accscalar_t>(norm_fct),
- reduction_size,
- stride);
- C10_CUDA_KERNEL_LAUNCH_CHECK();
- }
- });
- return grad_input;
- }
- } } // namespace at::native
|