from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
           'LazyBatchNorm3d', 'SyncBatchNorm']
class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()
    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime,
            # depending on whether self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            # this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
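
# Illustrative sketch (not part of this module's API; the helper name below is
# hypothetical): loading a checkpoint saved before state-dict version 2, which has no
# ``num_batches_tracked`` entry. The ``_load_from_state_dict`` hook above fills in a
# default of 0, so a strict load still succeeds.
def _example_load_legacy_state_dict() -> None:
    bn = BatchNorm1d(4)  # defined later in this file; resolved at call time
    bn.num_batches_tracked.fill_(3)
    # A plain dict carries no ``_metadata`` attribute, so its version is treated as
    # unknown (older than 2), just like a legacy checkpoint.
    legacy = {k: v.clone() for k, v in bn.state_dict().items() if k != "num_batches_tracked"}
    bn.load_state_dict(legacy, strict=True)
    assert int(bn.num_batches_tracked) == 0  # reset to the version-2 default
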
class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
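
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# the running statistics are an exponential moving average controlled by ``momentum``.
# With the default momentum of 0.1, after one training batch the running mean moves
# 10% of the way from its initial value (0) toward the batch mean.
def _example_running_stats_momentum() -> None:
    bn = BatchNorm1d(3)  # defined later in this file; momentum defaults to 0.1
    x = torch.randn(8, 3)
    bn.train()
    bn(x)
    assert int(bn.num_batches_tracked) == 1
    # new_running_mean = (1 - momentum) * 0 + momentum * batch_mean
    assert torch.allclose(bn.running_mean, 0.1 * x.mean(dim=0), atol=1e-6)
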
class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            self.reset_parameters()
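
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# a lazy batch-norm module starts with uninitialized parameters, infers ``num_features``
# from dimension 1 of the first input, and is then swapped to its ``cls_to_become`` class.
def _example_lazy_num_features_inference() -> None:
    m = LazyBatchNorm1d()  # defined later in this file; resolved at call time
    assert isinstance(m.weight, UninitializedParameter)
    x = torch.randn(16, 7)  # 7 channels in dimension 1
    m(x)  # the first forward materializes weight, bias and the running buffers
    assert m.num_features == 7 and m.weight.shape == (7,)
    assert isinstance(m, BatchNorm1d)  # cls_to_become replaced the lazy class
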
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The
    standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
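
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# in training mode, BatchNorm1d output equals per-channel normalization with the batch
# mean and the *biased* variance estimate, i.e. torch.var(..., unbiased=False).
def _example_batchnorm1d_matches_batch_stats() -> None:
    m = BatchNorm1d(5, affine=False)
    x = torch.randn(32, 5)
    m.train()
    out = m(x)
    expected = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + m.eps)
    assert torch.allclose(out, expected, atol=1e-5)
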
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(input.dim())
            )
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated
    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
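
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# with ``track_running_stats=False`` the running buffers are ``None``, so batch
# statistics are used in eval mode too and train()/eval() produce identical outputs.
def _example_batchnorm2d_without_running_stats() -> None:
    m = BatchNorm2d(3, affine=False, track_running_stats=False)
    x = torch.randn(4, 3, 8, 8)
    m.train()
    out_train = m(x)
    m.eval()
    out_eval = m(x)
    assert m.running_mean is None and m.running_var is None
    assert torch.allclose(out_train, out_eval)
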
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated
    via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
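
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# with ``momentum=None`` the running statistics are a cumulative moving average, i.e.
# a simple average of the per-batch statistics weighted by 1 / num_batches_tracked.
def _example_cumulative_moving_average() -> None:
    m = BatchNorm3d(2, momentum=None)
    x1 = torch.randn(4, 2, 3, 3, 3)
    x2 = torch.randn(4, 2, 3, 3, 3)
    m.train()
    m(x1)
    m(x2)
    # After two batches the running mean equals the plain average of the two batch means.
    batch_means = torch.stack([x1.mean(dim=(0, 2, 3, 4)), x2.mean(dim=(0, 2, 3, 4))])
    assert torch.allclose(m.running_mean, batch_means.mean(dim=0), atol=1e-6)
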
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0. The standard-deviation is calculated via the
    biased estimator, equivalent to `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(
                "expected at least 2D input (got {}D input)".format(input.dim())
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )
    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU input is supported
            if not input.is_cuda:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,
                world_size,
            )
    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
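
# Illustrative sketch (not part of this module's API; the helper name is hypothetical):
# ``convert_sync_batchnorm`` swaps every BatchNorm*D submodule for a SyncBatchNorm and
# carries the affine parameters and running statistics over to the new layer. Without an
# initialized process group the converted layers simply fall back to regular batch norm,
# so the conversion itself can be exercised on CPU; wrapping in DistributedDataParallel
# (sketched in the trailing comment) only makes sense inside an initialized distributed job.
def _example_convert_sync_batchnorm() -> None:
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )
    orig_weight = model[1].weight
    sync_model = SyncBatchNorm.convert_sync_batchnorm(model)
    assert isinstance(sync_model[1], SyncBatchNorm)
    assert sync_model[1].weight is orig_weight  # parameters are moved over, not copied
    # In a real multi-GPU job one would continue with something like:
    #   torch.distributed.init_process_group("nccl", ...)
    #   ddp_model = torch.nn.parallel.DistributedDataParallel(
    #       sync_model.cuda(), device_ids=[local_rank])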