import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):
    """Autograd function for synchronized batch normalization: per-process batch
    statistics are gathered across ``process_group`` so that every rank normalizes
    with the same global mean and invstd."""

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not (
            input.is_contiguous(memory_format=torch.channels_last) or
            input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For an empty input, set the stats and the count to zero. Stats with a
            # zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use all_gather instead of all_reduce because the count can differ across
        # ranks; a plain all_reduce cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd
        # based on all gathered means, invstds and counts.
        # For the NCCL backend, use the optimized version of all_gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not torch.cuda.is_current_stream_capturing():
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the resulting count_all depends on the values in the mask
            # tensor. Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
    @staticmethod
    def backward(self, grad_output):
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last) or
            grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronize the stats used to calculate the input gradient
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )

            # Synchronization of grad_weight / grad_bias is not needed here, as
            # distributed training handles that all_reduce itself.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which the
            # autograd engine interprets as tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
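

# Usage sketch (illustration only, not part of the original module): roughly how a
# SyncBatchNorm-style module would dispatch to the autograd Function above during
# distributed training. It assumes torch.distributed has already been initialized with
# a CUDA-capable backend such as NCCL, that it runs on every rank, and that ``bn`` is a
# BatchNorm2d-like module carrying weight/bias/running stats; ``_sync_batch_norm_step``
# is a hypothetical helper name.
def _sync_batch_norm_step(input, bn, process_group=None):
    process_group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(process_group)
    # Function.apply takes every argument of forward() after the ctx/self parameter.
    return SyncBatchNorm.apply(
        input, bn.weight, bn.bias, bn.running_mean, bn.running_var,
        bn.eps, bn.momentum, process_group, world_size)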


class CrossMapLRN2d(Function):
    """Autograd function implementing local response normalization across channels
    (cross-map LRN) with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as a temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute the first feature map's normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for the next feature maps' normalization
        # by adding the next feature map and removing the previous one
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height,
                                 input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
                                                  channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None
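

# Usage sketch (illustration only, not part of the original module): applying the
# Function directly to a small CPU tensor. Each channel is scaled by
# (k + alpha / size * sum of squared activations over a window of neighbouring
# channels) ** (-beta), and calling backward() exercises the hand-written gradient
# above; ``_cross_map_lrn_example`` is a hypothetical helper name.
def _cross_map_lrn_example():
    x = torch.randn(2, 8, 5, 5, requires_grad=True)
    out = CrossMapLRN2d.apply(x, 3, 1e-4, 0.75, 1.0)  # size=3, default alpha/beta/k
    out.sum().backward()
    return out, x.grad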


class BackwardHookFunction(torch.autograd.Function):
    """Identity autograd function that routes tensors through an extra node in the
    autograd graph, giving backward hooks a place to attach."""

    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args
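

# Usage sketch (illustration only, not part of the original module): the Function
# returns its inputs unchanged, but because they pass through an extra autograd node,
# gradients still flow back through that node, which is where backward-hook machinery
# can attach its hooks. ``_backward_hook_function_example`` is a hypothetical helper
# name.
def _backward_hook_function_example():
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3)  # does not require grad, so it is marked non-differentiable
    out_x, out_y = BackwardHookFunction.apply(x, y)
    out_x.sum().backward()  # the gradient reaches x through the inserted node
    return x.grad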