#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/cuda/LaunchUtils.h>
#include <c10/macros/Macros.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

// The maximum number of threads in a block
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

constexpr unsigned MAX_GRID_SIZE = 65535u;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename scalar_t, typename accscalar_t>
struct Float2 {
  accscalar_t v1, v2;
  __device__ Float2() {}
  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
  __device__ friend Float2 operator+(Float2 a, const Float2& b) {
    a += b;
    return a;
  }
};

template <typename scalar_t, typename accscalar_t, typename PTA>
struct GradOp {
  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
    : mean(m), input(i), grad_output(g) {}
  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
    accscalar_t g = grad_output[batch][plane][n];
    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
    return Float2<scalar_t, accscalar_t>(g, g * c);
  }
  const accscalar_t mean;
  const PTA& input;
  const PTA& grad_output;
};

template <typename acc_t>
struct SumReduceOp {
  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};

template <typename scalar_t, typename accscalar_t>
struct SumReduceOp<Float2<scalar_t, accscalar_t>> {
  using acc_t = Float2<scalar_t, accscalar_t>;

  __device__ __forceinline__ acc_t combine(acc_t a, acc_t b) const { return a + b; }

  __device__ __forceinline__ acc_t warp_shfl_down(acc_t data, int offset) const {
    return {WARP_SHFL_DOWN(data.v1, offset), WARP_SHFL_DOWN(data.v2, offset)};
  }
};
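// Worked example (illustrative) of the Float2/GradOp pairing used by the backward kernels:
// for one plane with dy = {1, 2}, x = {3, 5} and mean = 4, GradOp yields the pairs (1, -1)
// and (2, 2), so reducing them with reduce() below gives Float2{3, 1}, i.e.
// sum(dy) = 3 and dot(x - mean, dy) = 1 accumulated in a single pass over the data.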
// Sum across (batch, x/y/z) applying Op() pointwise
// this works by first having each thread sum its part
// of the data. Then there is a double-shuffling reduction.
// First each warp (of C10_WARP_SIZE threads) uses warpSum to reduce its
// data to the "warp leader", who writes its value into shared memory.
// Then a single warp reads the remaining (at most C10_WARP_SIZE) items
// and reduces them using another warpSum.
// The implicit assumption is that there are no more
// than C10_WARP_SIZE**2 threads.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // first the reductions each thread does separately
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  __shared__ scalar_t shared[C10_WARP_SIZE];
  SumReduceOp<scalar_t> reduce_op;
  sum = cuda_utils::BlockReduce<scalar_t, SumReduceOp<scalar_t>, cuda_utils::Block2D>(sum, reduce_op, 0, shared);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    shared[0] = sum;
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole grad_input
  return shared[0];
}

constexpr int ELEMENTS_PER_ITER = 4; // enables concurrency within each thread to hide latency
constexpr int ELEMENTS_PER_THREAD = 16;
constexpr int OPTIMAL_TILE_W = 32;
constexpr int MAX_H_BLOCK = 128;

__host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  int block_x = std::min(lastPow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(lastPow2(at::ceil_div(reduction, ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    block_x = std::min(lastPow2(stride), MAX_BLOCK_SIZE / block_y);
  }

  int grid_x = at::ceil_div(stride, block_x);
  int grid_y = std::min(at::ceil_div(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }

  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}
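// Worked example (illustrative): on a non-ROCm build (MAX_BLOCK_SIZE == 512) with
// reduction = 8192 and stride = 64, block_x = min(64, 32) = 32 and
// block_y = min(512, 512 / 32) = 16, so the block is already full (32 * 16 == 512);
// grid_x = ceil_div(64, 32) = 2 and grid_y = min(ceil_div(8192, 16 * 16), 128) = 32,
// which stays 32 even when coop_flag is set (32 >= 8, so the grid reduction is kept).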
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& count_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  T factor = T(1.0) / ::max(1, (count + count_new));
  T delta0 = mean - mean_new;
  mean = (mean_new * count_new + mean * count) * factor;
  m2n += m2n_new + delta0 * delta0 * count_new * count * factor;
  count += count_new;
}

// merge mean/m2n among threadIdx.y within block
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;

#pragma unroll
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    if (threadIdx.y < offset*2) {
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto count_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];

      welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new);
    }
  }
}
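// Worked example (illustrative): merging the Welford partials of {0, 2} (count = 2, mean = 1,
// m2n = 2) and {2, 4} (count = 2, mean = 3, m2n = 2) gives factor = 1/4, delta0 = -2,
// mean = (3*2 + 1*2) / 4 = 2 and m2n = 2 + 2 + 4 * 2 * 2 / 4 = 8, which matches the
// m2n of {0, 2, 2, 4} computed directly about its mean of 2.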
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
__global__ void batch_norm_transform_input_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
    const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;

  if (plane >= input.size(1)) {
    return;
  }

  stat_accscalar_t gamma = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : static_cast<stat_accscalar_t>(1);
  stat_accscalar_t beta = bias.size(0) > 0 ? static_cast<stat_accscalar_t>(bias[plane]) : static_cast<stat_accscalar_t>(0);
  stat_accscalar_t mean = static_cast<stat_accscalar_t>(mean_[plane]);
  stat_accscalar_t invstd;
  if (train) {
    invstd = var_or_invstd[plane];
  } else {
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(var_or_invstd[plane]) + epsilon);
  }

  index_t bs = input.size(0);
  index_t fs = input.size(2);

  index_t bstep = blockDim.y * gridDim.y;
  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
    auto o = output[batch][plane];
    auto i = input[batch][plane];
    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
      o[feature] = static_cast<input_scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
    }
  }
}

struct InvStd {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    T invstd = 0;
    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
    }
    return invstd;
  }
};

struct Var {
  template <typename T>
  __device__ __forceinline__ T operator()(T var, double epsilon) const {
    return var;
  }
};

template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
    const stat_accscalar_t epsilon,
    const stat_accscalar_t momentum,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {

  __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE];

  int plane = blockIdx.x;
  int N = input.size(0) * input.size(2);
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

  // Compute the mean and variance across (batch, x/y/z)
  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
  // and the parallel algorithm on the same page.
  // We use two shuffles to reduce across the entire block.
  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
  stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE];

  // first the reductions each thread does separately
  stat_accscalar_t avg = 0;
  stat_accscalar_t var_n = 0;
  int n = 0;
  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
      stat_accscalar_t v = input[batch][plane][x];
      stat_accscalar_t d1 = v - avg;
      n++;
      avg += d1 / n;
      var_n += d1 * (v - avg);
    }
  }

  // first warpSum to get one value per thread to
  // one value per warp
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // this writes each warps item into shared memory
  // there are at most C10_WARP_SIZE items left because
  // there are at most C10_WARP_SIZE**2 threads at the beginning
  __syncthreads();
  if (tid % C10_WARP_SIZE == 0) {
    shared_n[tid / C10_WARP_SIZE] = n;
    shared_avg_var[tid / C10_WARP_SIZE * 2] = avg;
    shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n;
  }
  __syncthreads();
  // now have a second warpSum to reduce the intermediate values
  // from shared memory to a single number. The very first
  // thread writes it to shared memory.
  if (tid < C10_WARP_SIZE) {
    n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0);
    avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0));
    var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0));
  }
  for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) {
    stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE);
    int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE);
    stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n);
    var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
    avg = (n * avg + o_n * o_avg) * factor;
    n += o_n;
  }

  // Save the mean, variance, and moving averages
  if (tid == 0) {
    if (save_mean.data() != NULL) {
      save_mean[plane] = avg;
    }
    if (save_transformed_var.data() != NULL) {
      save_transformed_var[plane] = VarTransform{}(var_n / N, epsilon);
    }
  }
}
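// Summary of the math in the backward kernel below (restating what the code computes):
// in training mode, grad_input = (dy - sum(dy)/N - (x - mean) * invstd^2 * dot_p / N) * invstd * weight,
// with dot_p = sum(dy * (x - mean)); in eval mode only the invstd * weight scaling is applied.
// grad_weight = dot_p * invstd and grad_bias = sum(dy) in both modes.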
template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> weight,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
    const GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> running_var,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
    const GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
    bool train,
    stat_accscalar_t epsilon) {

  index_t plane = blockIdx.x;
  index_t N = grad_output.size(0) * grad_output.size(2);

  stat_accscalar_t mean, invstd;
  if (train) {
    mean = save_mean[plane];
    invstd = save_invstd[plane];
  } else {
    mean = static_cast<stat_accscalar_t>(running_mean[plane]);
    invstd = static_cast<stat_accscalar_t>(1) / device_sqrt(static_cast<stat_accscalar_t>(running_var[plane]) + epsilon);
  }

  stat_accscalar_t weight_val = weight.size(0) > 0 ? static_cast<stat_accscalar_t>(weight[plane]) : stat_accscalar_t(1);
  stat_accscalar_t norm = stat_accscalar_t(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(grad_output)
  // 2. DotProduct(input - mean, grad_output)
  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  stat_accscalar_t grad_output_sum = res.v1;
  stat_accscalar_t dot_p = res.v2;

  stat_accscalar_t grad_mean = grad_output_sum * norm;
  stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd;
  stat_accscalar_t grad_scale = invstd * weight_val;

  if (grad_input.data() != NULL) {
    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
        input_scalar_t go = grad_output[batch][plane][x];
        if (train) {
          stat_accscalar_t inp = input[batch][plane][x];
          stat_accscalar_t proj = (inp - mean) * proj_scale;
          grad_input[batch][plane][x] = static_cast<input_scalar_t>((go - proj - grad_mean) * grad_scale);
        } else {
          grad_input[batch][plane][x] = static_cast<input_scalar_t>(go * grad_scale);
        }
      }
    }
  }

  if (grad_weight.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(dot_p * invstd);
    }
  }

  if (grad_bias.size(0) > 0) {
    if (threadIdx.x == 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(grad_output_sum);
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_reduce_statistics_kernel(
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_mean,
    const GenericPackedTensorAccessor<accscalar_t, 2, RestrictPtrTraits, index_t> vec_invstd,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
    GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
    const accscalar_t epsilon,
    const accscalar_t momentum,
    const GenericPackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> counts) {

  int feature_size = vec_mean.size(1);
  int world_size = vec_mean.size(0);

  int bid = blockIdx.x;
  int tid = threadIdx.x;

  // first the reductions each thread does separately
  for (int i = bid*blockDim.x+tid; i < feature_size; i += gridDim.x*blockDim.x) {
    accscalar_t avg = 0;
    accscalar_t var_n = 0;
    index_t n = 0;
    for (int j = 0; j < world_size; j++) {
      scalar_t count = counts[j];
      accscalar_t m = vec_mean[j][i];
      accscalar_t v = accscalar_t(1.0) / (vec_invstd[j][i]);
      v = (v * v - epsilon) * count;
      accscalar_t factor = 1.0 / (n + count);
      var_n += v + (avg - m) * (avg - m) * n * count * factor;
      avg = n * factor * avg + count * factor * m;
      n += count;
    }
    mean[i] = avg;
    invstd[i] = static_cast<accscalar_t>(1) / device_sqrt(var_n / n + epsilon);
    if (running_mean.data() != NULL) {
      running_mean[i] = static_cast<scalar_t>((1 - momentum) * running_mean[i] + momentum * avg);
    }
    accscalar_t unbiasedVar = var_n / (n - 1);
    if (running_var.data() != NULL) {
      running_var[i] = static_cast<scalar_t>((1 - momentum) * running_var[i] + momentum * unbiasedVar);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__global__ void batch_norm_backward_reduce_kernel(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> mean,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> invstd,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy,
    GenericPackedTensorAccessor<stat_accscalar_t, 1, DefaultPtrTraits, index_t> sum_dy_xmu,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
    GenericPackedTensorAccessor<stat_scalar_t, 1, DefaultPtrTraits, index_t> grad_bias) {

  index_t plane = blockIdx.x;

  stat_accscalar_t r_mean = mean[plane];
  stat_accscalar_t factor = invstd[plane];

  GradOp<input_scalar_t, stat_accscalar_t, GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>> g(r_mean, input, grad_output);
  auto res = reduce<Float2<input_scalar_t, stat_accscalar_t>>(g, grad_output, plane);

  if (threadIdx.x == 0) {
    if (grad_weight.size(0) > 0) {
      grad_weight[plane] = static_cast<stat_scalar_t>(res.v2 * factor);
    }
    if (grad_bias.size(0) > 0) {
      grad_bias[plane] = static_cast<stat_scalar_t>(res.v1);
    }
    if (sum_dy.size(0) > 0) {
      sum_dy[plane] = static_cast<stat_accscalar_t>(res.v1);
    }
    if (sum_dy_xmu.size(0) > 0) {
      sum_dy_xmu[plane] = static_cast<stat_accscalar_t>(res.v2);
    }
  }
}

template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
__device__ __forceinline__ void batch_norm_backward_elemt_kernel_impl(
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t> input,
    const GenericPackedTensorAccessor<input_scalar_t, 3, DefaultPtrTraits, index_t>
grad_output, const GenericPackedTensorAccessor mean, const GenericPackedTensorAccessor invstd, const GenericPackedTensorAccessor weight, const GenericPackedTensorAccessor sum_dy, const GenericPackedTensorAccessor sum_dy_xmu, GenericPackedTensorAccessor grad_input, const stat_accscalar_t norm_fct) { index_t plane = blockIdx.x; if (plane >= input.size(1)) { return; } stat_accscalar_t m_c = mean[plane]; stat_accscalar_t m_dy_c = sum_dy[plane] * norm_fct; stat_accscalar_t factor_1_c = invstd[plane]; stat_accscalar_t factor_2_c = weight.size(0) > 0 ? static_cast(weight[plane]) : stat_accscalar_t(1); factor_2_c *= factor_1_c; factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[plane] * norm_fct; index_t bs = input.size(0); index_t fs = input.size(2); index_t bstep = blockDim.y * gridDim.y; for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { auto g_i = grad_input[batch][plane]; auto g_o = grad_output[batch][plane]; auto i = input[batch][plane]; for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { g_i[feature] = static_cast((g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * factor_2_c); } } } template __global__ void batch_norm_backward_elemt_kernel( const GenericPackedTensorAccessor input, const GenericPackedTensorAccessor grad_output, const GenericPackedTensorAccessor mean, const GenericPackedTensorAccessor invstd, const GenericPackedTensorAccessor weight, const GenericPackedTensorAccessor sum_dy, const GenericPackedTensorAccessor sum_dy_xmu, GenericPackedTensorAccessor grad_input, const int* __restrict__ numel, const int world_size) { int64_t total_numel = 0; for (int i = 0; i < world_size; i ++) { total_numel += numel[i]; } const stat_accscalar_t norm_fct = static_cast(1) / static_cast(total_numel); batch_norm_backward_elemt_kernel_impl( input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); } template __global__ void batch_norm_backward_elemt_kernel( const GenericPackedTensorAccessor input, const GenericPackedTensorAccessor grad_output, const GenericPackedTensorAccessor mean, const GenericPackedTensorAccessor invstd, const GenericPackedTensorAccessor weight, const GenericPackedTensorAccessor sum_dy, const GenericPackedTensorAccessor sum_dy_xmu, GenericPackedTensorAccessor grad_input, const stat_accscalar_t norm_fct) { batch_norm_backward_elemt_kernel_impl( input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); } template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> static GenericPackedTensorAccessor get_packed_accessor( const Tensor& t, c10::string_view var_name) { constexpr auto expect_type = c10::CppTypeToScalarType::value; const auto actual_type = t.scalar_type(); TORCH_CHECK(actual_type == expect_type, "Expected ", var_name, " to have type ", expect_type, " but got ", actual_type); return t.generic_packed_accessor(); } template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> static GenericPackedTensorAccessor packed_accessor_or_dummy( const Tensor& t, c10::string_view var_name) { if (!t.defined()) { const std::array zeros{{0}}; return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); } return get_packed_accessor(t, var_name); } template std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, bool train, double epsilon, std::array grad_input_mask) 
{ using accscalar_t = at::acc_type; Tensor grad_input_; Tensor grad_input_reshaped; Tensor grad_weight_; Tensor grad_bias_; auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); if (grad_input_mask[0]) { grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); } if (grad_input_mask[1]) { grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (grad_input_mask[2]) { grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } auto input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); auto grad_output = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); auto grad_input = packed_accessor_or_dummy< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); auto weight = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); auto grad_weight = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); auto grad_bias = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); auto running_mean = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_, "running_mean"); auto running_var = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(running_var_, "running_var"); auto save_mean = packed_accessor_or_dummy< accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_, "save_mean"); auto save_invstd = packed_accessor_or_dummy< accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_, "save_invstd"); auto stream = at::cuda::getCurrentCUDAStream(); dim3 blocks(input.size(1)); int tf = getNumThreads(input.size(2)); dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_backward_kernel <<>> (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(grad_input_, grad_weight_, grad_bias_); } template void batch_norm_stats_cuda_template( const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input_, double epsilon) { using accscalar_t = at::acc_type; int64_t n_input = input_.size(1); Tensor dummy_mean_; Tensor dummy_var_; auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions resize_output(out_mean, {n_input}); resize_output(out_invstd, {n_input}); auto input = get_packed_accessor< scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]); TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && out_mean.sizes()[0]); auto mean = packed_accessor_or_dummy< accscalar_t, 1, RestrictPtrTraits, index_t>(out_mean, "out_mean"); auto invstd = packed_accessor_or_dummy< accscalar_t, 1, RestrictPtrTraits, index_t>(out_invstd, "out_invstd"); auto stream = at::cuda::getCurrentCUDAStream(); dim3 blocks(input.size(1)); int tf = getNumThreads(input.size(2)); dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); batch_norm_collect_statistics_kernel <<>> (input, epsilon, 0.0, mean, invstd); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, const 
Tensor& weight_, const Tensor& bias_, const Tensor& mean_, const Tensor& invstd_) { using stat_accscalar_t = at::acc_type; int64_t n_input = input_.size(1); auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); auto input = get_packed_accessor< input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input"); auto output = get_packed_accessor< input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output"); auto weight = packed_accessor_or_dummy< stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight"); auto bias = packed_accessor_or_dummy< stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias"); auto mean = packed_accessor_or_dummy< stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean"); auto invstd = packed_accessor_or_dummy< stat_accscalar_t, 1, RestrictPtrTraits, index_t>(invstd_, "invstd"); auto stream = at::cuda::getCurrentCUDAStream(); // NOTE: We use transform_input_kernel in training mode, which ignores epsilon const double dummy_epsilon = 1e-5; // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks // and good occupancy. Quiet likely, we could go with even more blocks than 1024. // The various planes are independent, so we use blocks for them. int tf = std::max(getNumThreads(input.size(2)/4), std::min(getNumThreads(input.size(2)), 64)); int tb = std::max(64/tf, 1); dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), (input.size(0)+tb-1)/tb))); blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); dim3 threads_trans(tf, tb); batch_norm_transform_input_kernel <<>> (input, output, mean, invstd, weight, bias, dummy_epsilon); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template std::tuple batch_norm_gather_stats_cuda_template(const Tensor& mean_, const Tensor& invstd_, const Tensor& running_mean_, const Tensor& running_var_, double momentum, double epsilon, const Tensor& counts_) { Tensor save_mean_; Tensor save_invstd_; auto features = mean_.size(1); auto input_options = mean_.options(); if (mean_.scalar_type() == at::ScalarType::Half || mean_.scalar_type() == at::ScalarType::BFloat16) { input_options = input_options.dtype(ScalarType::Float); } save_mean_ = at::empty({features}, input_options); save_invstd_ = at::empty({features}, input_options); auto mean = packed_accessor_or_dummy< accscalar_t, 2, RestrictPtrTraits, index_t>(mean_, "mean"); auto invstd = packed_accessor_or_dummy< accscalar_t, 2, RestrictPtrTraits, index_t>(invstd_, "invstd"); auto running_mean = packed_accessor_or_dummy< scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_, "running_mean"); auto running_var = packed_accessor_or_dummy< scalar_t, 1, RestrictPtrTraits, index_t>(running_var_, "running_mean"); auto counts = packed_accessor_or_dummy< scalar_t, 1, RestrictPtrTraits, index_t>(counts_, "counts"); auto save_mean = get_packed_accessor< accscalar_t, 1, RestrictPtrTraits, index_t>(save_mean_, "save_mean"); auto save_invstd = get_packed_accessor< accscalar_t, 1, RestrictPtrTraits, index_t>(save_invstd_, "save_invstd"); auto stream = at::cuda::getCurrentCUDAStream(); int block = getNumThreads(features); int grid = std::max(1, features/block); batch_norm_reduce_statistics_kernel <<>> (mean, invstd, save_mean, save_invstd, running_mean, running_var, epsilon, 
momentum, counts); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(save_mean_, save_invstd_); } template std::tuple batch_norm_backward_reduce_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, const bool input_g, const bool weight_g, const bool bias_g) { using stat_accscalar_t = at::acc_type; int64_t n_input = input_.size(1); Tensor sum_dy_; Tensor sum_dy_xmu_; Tensor grad_weight_; Tensor grad_bias_; auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); if (input_g) { sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (weight_g) { grad_weight_ = at::empty({n_input}, weight_.options()); } if (bias_g) { grad_bias_ = at::empty({n_input}, weight_.options()); } auto input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); auto grad_output = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); auto grad_weight = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_, "grad_weight"); auto grad_bias = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_, "grad_bias"); auto mean = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); auto invstd = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); auto sum_dy = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); auto sum_dy_xmu = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); auto batch_size = input_reshaped.size(0); auto feature_size = input_reshaped.size(2); auto stream = at::cuda::getCurrentCUDAStream(); int warp_size = at::cuda::warp_size(); int block_y = std::min(lastPow2(batch_size), MAX_BLOCK_SIZE/warp_size); // We want block_x to be at least a warp width int block_x = std::min(std::max(getNumThreads(feature_size), warp_size), MAX_BLOCK_SIZE/block_y); const dim3 block(block_x, block_y); const dim3 grid(n_input); batch_norm_backward_reduce_kernel <<>> (input, grad_output, mean, invstd, sum_dy, sum_dy_xmu, grad_weight, grad_bias); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); } template Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_) { using stat_accscalar_t = at::acc_type; int64_t n_input = input_.size(1); auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); auto grad_input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); auto grad_output = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); auto mean = packed_accessor_or_dummy< stat_accscalar_t, 1, 
DefaultPtrTraits, index_t>(mean_, "mean"); auto invstd = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); auto weight = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); auto sum_dy = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); auto sum_dy_xmu = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); auto stream = at::cuda::getCurrentCUDAStream(); // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks // and good occupancy. Quiet likely, we could go with even more blocks than 1024. // The various planes are independent, so we use blocks for them. int tf = std::max(getNumThreads(input.size(2)/4), std::min(getNumThreads(input.size(2)), 64)); int tb = std::max(64/tf, 1); dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), (input.size(0)+tb-1)/tb))); blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); dim3 threads_trans(tf, tb); auto reduction_size = input_.numel() / n_input; auto norm_fct = static_cast(1.0 / reduction_size); batch_norm_backward_elemt_kernel <<>> (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct); C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input_reshaped.view(input_.sizes()); } template Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& mean_, const Tensor& invstd_, const Tensor& weight_, const Tensor& sum_dy_, const Tensor& sum_dy_xmu_, const Tensor& count) { using stat_accscalar_t = at::acc_type; int64_t n_input = input_.size(1); auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); auto grad_input_reshaped = at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(input_reshaped, "input"); auto grad_input = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped, "grad_input"); auto grad_output = get_packed_accessor< input_scalar_t, 3, DefaultPtrTraits, index_t>(grad_output_reshaped, "grad_output"); auto mean = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(mean_, "mean"); auto invstd = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(invstd_, "invstd"); auto weight = packed_accessor_or_dummy< stat_scalar_t, 1, DefaultPtrTraits, index_t>(weight_, "weight"); auto sum_dy = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_, "sum_dy"); auto sum_dy_xmu = packed_accessor_or_dummy< stat_accscalar_t, 1, DefaultPtrTraits, index_t>(sum_dy_xmu_, "sum_dy_xmu"); auto stream = at::cuda::getCurrentCUDAStream(); // The kernel is pointwise, but we need to balance reading parameters (save_var/mean, // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks // and good occupancy. Quiet likely, we could go with even more blocks than 1024. // The various planes are independent, so we use blocks for them. 
int tf = std::max(getNumThreads(input.size(2)/4), std::min(getNumThreads(input.size(2)), 64)); int tb = std::max(64/tf, 1); dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), (input.size(0)+tb-1)/tb))); blocks_trans.y = std::min(blocks_trans.y, MAX_GRID_SIZE); dim3 threads_trans(tf, tb); batch_norm_backward_elemt_kernel <<>> (input, grad_output, mean, invstd, weight, sum_dy, sum_dy_xmu, grad_input, count.data_ptr(), count.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_input_reshaped.view(input_.sizes()); } // welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance // original apex name: welford_kernel_c_last template __global__ void batch_norm_collect_statistics_channels_last_kernel( const scalar_t* __restrict__ input, accscalar_t* __restrict__ out_mean, accscalar_t* __restrict__ out_invstd, volatile accscalar_t* staging_data, int* semaphores, const int reduction_size, const int stride, accscalar_t epsilon) { // hide latency with concurrency accscalar_t x_mean[PARALLEL_LOADS]; accscalar_t m_2_n[PARALLEL_LOADS]; int count[PARALLEL_LOADS]; #pragma unroll for (int i = 0; i < PARALLEL_LOADS; i++) { x_mean[i] = accscalar_t(0); m_2_n[i] = accscalar_t(0); count[i] = accscalar_t(0); } // tensor dimension (m,c) // loop along m dimension int inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension int m_offset = blockIdx.y * blockDim.y + threadIdx.y; int c_offset = blockIdx.x * blockDim.x + threadIdx.x; int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; for (int i = 0; i < loop_count; i++) { accscalar_t x_math[PARALLEL_LOADS]; accscalar_t x_count_inv[PARALLEL_LOADS]; accscalar_t is_valid[PARALLEL_LOADS]; // load multiple data in #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { x_math[j] = input[address_base]; count[j]++; x_count_inv[j] = accscalar_t(1) / count[j]; is_valid[j] = accscalar_t(1); } else { x_math[j] = accscalar_t(0); x_count_inv[j] = accscalar_t(0); is_valid[j] = accscalar_t(0); } m_offset += inner_loop_stride; address_base += address_increment; } // calculate mean/m2n with welford #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { accscalar_t delta0 = x_math[j] - x_mean[j]; x_mean[j] += delta0 * x_count_inv[j]; accscalar_t delta1 = x_math[j] - x_mean[j]; m_2_n[j] += delta0 * delta1 * is_valid[j]; } } // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS #pragma unroll for (int j = 1; j < PARALLEL_LOADS; j++) { welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); } // release x_mean / m_2_n auto mean_th = x_mean[0]; auto m2_th = m_2_n[0]; auto count_th = count[0]; // block-wise reduction with shared memory (since reduction cannot be done within a warp) static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; static __shared__ int shmem_count[MAX_BLOCK_SIZE]; welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); if (gridDim.y > 1) { volatile accscalar_t* staging_mean = staging_data; volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; volatile int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); address_base = c_offset + blockIdx.y * stride; // write data to staging_data; if (threadIdx.y == 0 && c_offset < stride) { 
staging_mean[address_base] = mean_th; staging_m2n[address_base] = m2_th; staging_count[address_base] = count_th; } __threadfence(); __syncthreads(); // ensuring writes to staging_ is visible to all blocks __shared__ bool is_last_block_done; // mark block done if (threadIdx.x == 0 && threadIdx.y == 0) { int old = atomicAdd(&semaphores[blockIdx.x], 1); is_last_block_done = (old == (gridDim.y-1)); } __syncthreads(); // check that all data is now available in global memory if (is_last_block_done) { count_th = 0; mean_th = accscalar_t(0.0); m2_th = accscalar_t(0.0); for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { address_base = c_offset + y * stride; int count_new = c_offset < stride ? staging_count[address_base] : 0; accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); welford_merge_element(count_th, mean_th, m2_th, count_new, mean_new, m2n_new); } welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); if (threadIdx.y == 0 && c_offset < stride) { out_mean[c_offset] = static_cast(mean_th); out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); } } } else { if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { out_mean[c_offset] = static_cast(mean_th); out_invstd[c_offset] = VarTransform{}(m2_th/count_th, epsilon); } } } // elementwise BN kernel // original apex name: batchnorm_forward_c_last_kernel template < typename scalar_t, typename accscalar_t, typename layerscalar_t, int PARALLEL_LOADS> __global__ void batch_norm_transform_input_channels_last_kernel( const scalar_t* __restrict__ input, const scalar_t* __restrict__ z, const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std, const layerscalar_t* __restrict__ weight, const layerscalar_t* __restrict__ shift, scalar_t* __restrict__ out, const int reduction_size, const int stride, const bool fuse_relu) { // tensor dimension (m,c) // loop along m dimension int inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension int m_offset = blockIdx.y * blockDim.y + threadIdx.y; int c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } auto m_c = mean[c_offset]; auto inv_std_c = static_cast(inv_std[c_offset]); auto w_c = weight == nullptr ? accscalar_t(1.0) : static_cast(weight[c_offset]); auto s_c = shift == nullptr ? accscalar_t(0.0) : static_cast(shift[c_offset]); int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; for (int i = 0; i < loop_count; i++) { #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; if (z != nullptr) { tmp += z[address_base]; } out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? 
scalar_t(0.0) : static_cast(tmp)); } m_offset += inner_loop_stride; address_base += address_increment; } } } template __device__ __forceinline__ void merge_block_vertical_backward(T& sum_dy, T& sum_dy_xmu, T* shmem_sum_dy, T* shmem_sum_dy_xmu) { // write to shared memory auto address_base = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { if (threadIdx.y < offset*2) { shmem_sum_dy[address_base] = sum_dy; shmem_sum_dy_xmu[address_base] = sum_dy_xmu; } __syncthreads(); if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { auto address = address_base + offset * blockDim.x; sum_dy += shmem_sum_dy[address]; sum_dy_xmu += shmem_sum_dy_xmu[address]; } } } // batchnorm backward kernel for c last tensor // original apex name: reduce_bn_c_last_kernel template < int PARALLEL_LOADS, typename scalar_t, typename accscalar_t, typename layerscalar_t> __global__ void batch_norm_backward_reduce_channels_last_kernel( const scalar_t* __restrict__ input, const scalar_t* __restrict__ grad_output, const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std, accscalar_t* __restrict__ sum_dy_o, accscalar_t* __restrict__ sum_dy_xmu_o, layerscalar_t* __restrict__ grad_weight, layerscalar_t* __restrict__ grad_bias, volatile accscalar_t* staging_data, int* semaphores, const int reduction_size, const int stride) { // hide latency with concurrency accscalar_t sum_dy[PARALLEL_LOADS]; accscalar_t sum_dy_xmu[PARALLEL_LOADS]; #pragma unroll for (int i = 0; i < PARALLEL_LOADS; i++) { sum_dy[i] = accscalar_t(0); sum_dy_xmu[i] = accscalar_t(0); } // tensor dimension (m,c) // loop along m dimension int inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension int m_offset = blockIdx.y * blockDim.y + threadIdx.y; int c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; auto r_mean = mean[c_offset]; auto factor = inv_std[c_offset]; for (int i = 0; i < loop_count; i++) { accscalar_t x_input[PARALLEL_LOADS]; accscalar_t x_grad_output[PARALLEL_LOADS]; // load multiple data in #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { x_input[j] = input[address_base]; x_grad_output[j] = grad_output[address_base]; } else { x_input[j] = accscalar_t(0); x_grad_output[j] = accscalar_t(0); } m_offset += inner_loop_stride; address_base += address_increment; } // calculate sum_dy / sum_dy_xmu #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { sum_dy[j] += x_grad_output[j]; sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); } } // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS #pragma unroll for (int j = 1; j < PARALLEL_LOADS; j++) { sum_dy[0] += sum_dy[j]; sum_dy_xmu[0] += sum_dy_xmu[j]; } // release array of registers auto sum_dy_th = sum_dy[0]; auto sum_dy_xmu_th = sum_dy_xmu[0]; // block-wise reduction with shared memory (since reduction cannot be done within a warp) static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); if (gridDim.y > 1) { volatile accscalar_t* staging_sum_dy = staging_data; volatile accscalar_t* staging_sum_dy_xmu = 
&staging_data[stride*gridDim.y]; address_base = c_offset + blockIdx.y * stride; // write data to staging_data; if (threadIdx.y == 0 && c_offset < stride) { staging_sum_dy[address_base] = sum_dy_th; staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; } __threadfence(); __syncthreads(); // ensuring writes to staging_ is visible to all blocks __shared__ bool is_last_block_done; // mark block done if (threadIdx.x == 0 && threadIdx.y == 0) { int old = atomicAdd(&semaphores[blockIdx.x], 1); is_last_block_done = (old == (gridDim.y-1)); } __syncthreads(); // check that all data is now available in global memory if (is_last_block_done) { sum_dy_th = accscalar_t(0.0); sum_dy_xmu_th = accscalar_t(0.0); for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { address_base = c_offset + y * stride; sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); } merge_block_vertical_backward(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); if (threadIdx.y == 0 && c_offset < stride) { if (grad_bias != nullptr) { grad_bias[c_offset] = static_cast(sum_dy_th); } if (grad_weight != nullptr) { grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); } //mean_dy[c_offset] = sum_dy_th / reduction_size; //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; sum_dy_o[c_offset] = sum_dy_th; sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; } } } else { if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { if (grad_bias != nullptr) { grad_bias[c_offset] = static_cast(sum_dy_th); } if (grad_weight != nullptr) { grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); } //mean_dy[c_offset] = sum_dy_th / reduction_size; //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; sum_dy_o[c_offset] = sum_dy_th; sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; } } } // elementwise BN kernel // original apex name: batchnorm_backward_c_last_kernel template < int PARALLEL_LOADS, typename scalar_t, typename accscalar_t, typename layerscalar_t> __device__ __forceinline__ void batch_norm_backward_elemt_channels_last_kernel_impl( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std, const layerscalar_t* __restrict__ weight, const accscalar_t* __restrict__ sum_dy, const accscalar_t* __restrict__ sum_dy_xmu, scalar_t* __restrict__ grad_input, const accscalar_t norm_fct, const int reduction_size, const int stride) { // tensor dimension (m,c) // loop along m dimension int inner_loop_stride = blockDim.y * gridDim.y; // offset along m dimension int m_offset = blockIdx.y * blockDim.y + threadIdx.y; int c_offset = blockIdx.x * blockDim.x + threadIdx.x; if (c_offset >= stride || m_offset >= reduction_size) { return; } auto m_c = mean[c_offset]; auto m_dy_c = sum_dy[c_offset] * norm_fct; auto factor_1_c = inv_std[c_offset]; auto factor_2_c = (weight == nullptr? 
accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] * norm_fct; int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); int address_base = m_offset * stride + c_offset; int address_increment = inner_loop_stride * stride; for (int i = 0; i < loop_count; i++) { #pragma unroll for (int j = 0; j < PARALLEL_LOADS; j++) { if (c_offset < stride && m_offset < reduction_size) { grad_input[address_base] = static_cast( (static_cast(grad_output[address_base]) - m_dy_c - (static_cast(input[address_base]) - m_c) * factor_1_c) * factor_2_c); } m_offset += inner_loop_stride; address_base += address_increment; } } } template < int PARALLEL_LOADS, typename scalar_t, typename accscalar_t, typename layerscalar_t> __global__ void batch_norm_backward_elemt_channels_last_kernel( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std, const layerscalar_t* __restrict__ weight, const accscalar_t* __restrict__ sum_dy, const accscalar_t* __restrict__ sum_dy_xmu, const int* __restrict__ numel, scalar_t* __restrict__ grad_input, const int64_t world_size, const int reduction_size, const int stride) { int64_t total_numel = 0; for (int i = 0; i < world_size; i++) { total_numel += numel[i]; } auto norm_fct = static_cast(1) / static_cast(total_numel); batch_norm_backward_elemt_channels_last_kernel_impl( grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct, reduction_size, stride); } template < int PARALLEL_LOADS, typename scalar_t, typename accscalar_t, typename layerscalar_t> __global__ void batch_norm_backward_elemt_channels_last_kernel( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ input, const accscalar_t* __restrict__ mean, const accscalar_t* __restrict__ inv_std, const layerscalar_t* __restrict__ weight, const accscalar_t* __restrict__ sum_dy, const accscalar_t* __restrict__ sum_dy_xmu, scalar_t* __restrict__ grad_input, const accscalar_t norm_fct, const int reduction_size, const int stride) { batch_norm_backward_elemt_channels_last_kernel_impl( grad_output, input, mean, inv_std, weight, sum_dy, sum_dy_xmu, grad_input, norm_fct, reduction_size, stride); } template void batch_norm_stats_channels_last_cuda_template( const Tensor& out_mean, const Tensor& out_invstd, const Tensor& input, double epsilon) { using accscalar_t = at::acc_type; const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; resize_output(out_mean, {stride}); resize_output(out_invstd, {stride}); TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]); TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() && out_mean.sizes()[0]); dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); at::Tensor staging_data; at::Tensor semaphores; if (grid.y > 1) { staging_data = at::empty({4*stride*grid.y}, out_mean.options()); semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); } accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr() : nullptr; int* semaphores_ptr = grid.y > 1 ? 
semaphores.data_ptr() : nullptr; batch_norm_collect_statistics_channels_last_kernel <<>>( input.data_ptr(), out_mean.data_ptr(), out_invstd.data_ptr(), staging_data_ptr, semaphores_ptr, reduction_size, stride, epsilon); C10_CUDA_KERNEL_LAUNCH_CHECK(); } void batch_norm_elemt_channels_last_cuda_template( const at::Tensor& output, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& shift, // bias of BN const at::Tensor& mean, const at::Tensor& inv_std, const at::optional& z = c10::nullopt, // bias after BN const bool fuse_relu = false) { const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid); auto stream = at::cuda::getCurrentCUDAStream(); const auto second_dtype = weight.defined() ? weight.scalar_type() : (shift.defined() ? shift.scalar_type() : input.scalar_type()); if (input.scalar_type() != second_dtype) { AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { using accscalar_t = at::acc_type; batch_norm_transform_input_channels_last_kernel <<>>( input.data_ptr(), z.has_value() ? z.value().data_ptr() : nullptr, mean.data_ptr(), inv_std.data_ptr(), weight.defined() ? weight.data_ptr() : nullptr, shift.defined() ? shift.data_ptr() : nullptr, output.data_ptr(), reduction_size, stride, fuse_relu); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } else { if (weight.defined()){ TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_forward: input.scalar_type() ", input.scalar_type(), " is not supported with weight.scalar_type() ", weight.scalar_type()); } AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward", [&] { using accscalar_t = at::acc_type; batch_norm_transform_input_channels_last_kernel <<>>( input.data_ptr(), z.has_value() ? z.value().data_ptr() : nullptr, mean.data_ptr(), inv_std.data_ptr(), weight.defined() ? weight.data_ptr() : nullptr, shift.defined() ? shift.data_ptr(): nullptr, output.data_ptr(), reduction_size, stride, fuse_relu); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } } std::tuple batch_norm_backward_reduce_cuda_channels_last_template(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& mean, const at::Tensor& inv_std, const at::Tensor& weight, const bool input_g, const bool weight_g, const bool bias_g) { const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; at::Tensor sumn_dy = at::empty({stride}, mean.options()); at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); at::Tensor grad_weight; at::Tensor grad_bias; if (weight.defined()) { grad_weight = at::empty({stride}, weight.options()); grad_bias = at::empty({stride}, weight.options()); } else { // because I cannot return an uninitialized at::Tensor grad_weight = at::empty({0}, mean.options()); grad_bias = at::empty({0}, mean.options()); } dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid, true); at::Tensor staging_data; at::Tensor semaphores; if (grid.y > 1) { staging_data = at::empty({2*stride*grid.y}, mean.options()); semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); } auto stream = at::cuda::getCurrentCUDAStream(); if (weight.defined() && input.scalar_type() != weight.scalar_type()) { AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { using accscalar_t = at::acc_type; accscalar_t* staging_data_ptr = grid.y > 1 ? 
staging_data.data_ptr() : nullptr; int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr() : nullptr; batch_norm_backward_reduce_channels_last_kernel <<>>( input.data_ptr(), grad_output.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), sumn_dy.data_ptr(), sum_dy_xmu.data_ptr(), grad_weight.data_ptr(), grad_bias.data_ptr(), staging_data_ptr, semaphores_ptr, reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } else { if (weight.defined()) { TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_reduce: input.scalar_type() ", input.scalar_type(), " is not supported with weight.scalar_type() ", weight.scalar_type()); } AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_reduce", [&] { using accscalar_t = at::acc_type; accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data_ptr() : nullptr; int* semaphores_ptr = grid.y > 1 ? semaphores.data_ptr() : nullptr; batch_norm_backward_reduce_channels_last_kernel <<>>( input.data_ptr(), grad_output.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), sumn_dy.data_ptr(), sum_dy_xmu.data_ptr(), weight.defined() ? grad_weight.data_ptr() : nullptr, weight.defined() ? grad_bias.data_ptr() : nullptr, staging_data_ptr, semaphores_ptr, reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias); } at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& mean, const at::Tensor& inv_std, const at::Tensor& weight, const at::Tensor& sum_dy, const at::Tensor& sum_dy_xmu, const at::Tensor& count) { const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; // Input is guarunteed to be channels-last compatible at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid); auto stream = at::cuda::getCurrentCUDAStream(); if (weight.defined() && weight.scalar_type() != input.scalar_type()) { AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { using accscalar_t = at::acc_type; batch_norm_backward_elemt_channels_last_kernel <<>>( grad_output.data_ptr(), input.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), weight.data_ptr(), sum_dy.data_ptr(), sum_dy_xmu.data_ptr(), count.data_ptr(), grad_input.data_ptr(), count.numel(), reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } else { if (weight.defined()) { TORCH_CHECK(input.scalar_type() == weight.scalar_type(), "batchnorm_backward_element: input.scalar_type() ", input.scalar_type(), " is not supported with weight.scalar_type() ", weight.scalar_type()); } AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { using accscalar_t = at::acc_type; batch_norm_backward_elemt_channels_last_kernel <<>>( grad_output.data_ptr(), input.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), weight.defined() ? 
weight.data_ptr() : nullptr, sum_dy.data_ptr(), sum_dy_xmu.data_ptr(), count.data_ptr(), grad_input.data_ptr(), count.numel(), reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } return grad_input; } at::Tensor batch_norm_backward_elemt_channels_last_cuda_template( const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& mean, const at::Tensor& inv_std, const at::Tensor& weight, const at::Tensor& sum_dy, const at::Tensor& sum_dy_xmu) { const auto stride = input.sizes()[1]; const auto reduction_size = input.numel() / stride; auto norm_fct = 1.0 / reduction_size; // Input is guarunteed to be channels-last compatible at::Tensor grad_input = at::empty_like(input); dim3 block; dim3 grid; flexible_launch_configs(reduction_size, stride, block, grid); auto stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "batchnorm_backward_element", [&] { using accscalar_t = at::acc_type; if (weight.defined() && weight.scalar_type() != input.scalar_type()) { batch_norm_backward_elemt_channels_last_kernel <<>>( grad_output.data_ptr(), input.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), weight.data_ptr(), sum_dy.data_ptr(), sum_dy_xmu.data_ptr(), grad_input.data_ptr(), static_cast(norm_fct), reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { batch_norm_backward_elemt_channels_last_kernel <<>>( grad_output.data_ptr(), input.data_ptr(), mean.data_ptr(), inv_std.data_ptr(), weight.defined() ? weight.data_ptr() : nullptr, sum_dy.data_ptr(), sum_dy_xmu.data_ptr(), grad_input.data_ptr(), static_cast(norm_fct), reduction_size, stride); C10_CUDA_KERNEL_LAUNCH_CHECK(); } }); return grad_input; } } } // namespace at::native