# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
- """Tracks the running statistics per mini-batch instead of micro-batch."""
from typing import TypeVar, cast

import torch
from torch import Tensor, nn
from torch.nn.functional import batch_norm
from torch.nn.modules.batchnorm import _BatchNorm

from .checkpoint import is_recomputing

__all__ = ["DeferredBatchNorm"]


TModule = TypeVar("TModule", bound=nn.Module)

class DeferredBatchNorm(_BatchNorm):
    """A BatchNorm layer that tracks multiple micro-batches and updates its
    running statistics once per mini-batch.
    """

    sum: Tensor
    sum_squares: Tensor
    running_mean: Tensor
    running_var: Tensor
    num_batches_tracked: Tensor

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        chunks: int = 1,
    ) -> None:
        super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
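        # Per-channel accumulators for sum(x) and sum(x**2) over the
        # micro-batches seen so far; they are folded into the running
        # statistics once ``chunks`` micro-batches have been tracked.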
- self.register_buffer("sum", torch.zeros_like(self.running_mean))
- self.register_buffer("sum_squares", torch.zeros_like(self.running_var))
- self.counter = 0
- self.tracked = 0
- self.chunks = chunks

    def _check_input_dim(self, input: Tensor) -> None:
        # It's the typical _check_input_dim() implementation in PyTorch.
        if input.dim() <= 2:
            raise ValueError("expected at least 3D input (got %dD input)" % input.dim())

    def _track(self, input: Tensor) -> bool:
        """Tracks statistics of a micro-batch."""
        # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
        dim = [0]
        dim.extend(range(2, input.dim()))

        with torch.no_grad():
            self.sum += input.sum(dim)
            self.sum_squares += (input ** 2).sum(dim)

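        # Number of elements per channel in this micro-batch
        # (batch size times all spatial dimensions).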
        size = input.size().numel() // input.size(1)
        self.counter += size
        self.tracked += 1

        return self.tracked == self.chunks

    def _commit(self) -> None:
        """Updates the running statistics of a mini-batch."""
        exponential_average_factor = 0.0
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

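        # Mini-batch statistics from the accumulated per-channel sums:
        # mean = E[x], var = E[x**2] - E[x]**2 (the biased variance estimator).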
        mean = self.sum / self.counter
        var = self.sum_squares / self.counter - mean ** 2

        # Calculate the exponential moving average here.
        m = exponential_average_factor

        self.running_mean *= 1 - m
        self.running_mean += mean * m

        self.running_var *= 1 - m
        self.running_var += var * m

        self.sum.zero_()
        self.sum_squares.zero_()
        self.counter = 0
        self.tracked = 0

    def forward(self, input: Tensor) -> Tensor:
        if not self.training:
            # Don't train parameters in evaluation mode.
            return batch_norm(
                input,
                running_mean=self.running_mean,
                running_var=self.running_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        if not is_recomputing():
            # Track a micro-batch in training mode,
            # but not while under a recomputation.
            tracked_enough = self._track(input)

            # Update the running statistics for the mini-batch
            # once enough micro-batches have been tracked.
            if tracked_enough:
                self._commit()

        # Normalize the micro-batch and train the parameters.
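        # With running_mean and running_var set to None and training=True,
        # batch_norm normalizes with the statistics of the current
        # micro-batch and leaves the running buffers untouched; those are
        # updated once per mini-batch by _commit() above instead.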
        return batch_norm(
            input,
            running_mean=None,
            running_var=None,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )

    @classmethod
    def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
        """Converts a :class:`nn.BatchNorm` module or any :class:`nn.BatchNorm`
        submodules it contains into :class:`DeferredBatchNorm`::

            from torchvision.models.resnet import resnet101
            from torchpipe.batchnorm import DeferredBatchNorm

            model = resnet101()
            model = DeferredBatchNorm.convert_deferred_batch_norm(model)

        """
        if isinstance(module, DeferredBatchNorm) and module.chunks == chunks:
            return cast(TModule, module)

        module_output: nn.Module = module

        if isinstance(module, _BatchNorm) and module.track_running_stats:
            module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
            if module.affine:
                module_output.register_parameter("weight", module.weight)
                module_output.register_parameter("bias", module.bias)
            module_output.register_buffer("running_mean", module.running_mean)
            module_output.register_buffer("running_var", module.running_var)
            module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)

        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))

        return cast(TModule, module_output)
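

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the library API):
    # convert a toy model's BatchNorm layers, then feed one mini-batch
    # through as two micro-batches. Because of the relative import above,
    # run this as ``python -m <package>.batchnorm`` from the package root.
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
    model = DeferredBatchNorm.convert_deferred_batch_norm(model, chunks=2)

    mini_batch = torch.randn(4, 3, 16, 16)
    for micro_batch in mini_batch.chunk(2):
        model(micro_batch)  # running stats are committed after the 2nd chunk

    bn = model[1]
    assert isinstance(bn, DeferredBatchNorm)
    print("running_mean after one mini-batch:", bn.running_mean)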