batchnorm.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826
  1. from typing import Optional, Any
  2. import torch
  3. from torch import Tensor
  4. from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
  5. from .. import functional as F
  6. from .. import init
  7. from ._functions import SyncBatchNorm as sync_batch_norm
  8. from .lazy import LazyModuleMixin
  9. from .module import Module
  10. __all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
  11. 'LazyBatchNorm3d', 'SyncBatchNorm']
  12. class _NormBase(Module):
  13. """Common base of _InstanceNorm and _BatchNorm"""
  14. _version = 2
  15. __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
  16. num_features: int
  17. eps: float
  18. momentum: float
  19. affine: bool
  20. track_running_stats: bool
  21. # WARNING: weight and bias purposely not defined here.
  22. # See https://github.com/pytorch/pytorch/issues/39670
  23. def __init__(
  24. self,
  25. num_features: int,
  26. eps: float = 1e-5,
  27. momentum: float = 0.1,
  28. affine: bool = True,
  29. track_running_stats: bool = True,
  30. device=None,
  31. dtype=None
  32. ) -> None:
  33. factory_kwargs = {'device': device, 'dtype': dtype}
  34. super().__init__()
  35. self.num_features = num_features
  36. self.eps = eps
  37. self.momentum = momentum
  38. self.affine = affine
  39. self.track_running_stats = track_running_stats
  40. if self.affine:
  41. self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
  42. self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
  43. else:
  44. self.register_parameter("weight", None)
  45. self.register_parameter("bias", None)
  46. if self.track_running_stats:
  47. self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
  48. self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
  49. self.running_mean: Optional[Tensor]
  50. self.running_var: Optional[Tensor]
  51. self.register_buffer('num_batches_tracked',
  52. torch.tensor(0, dtype=torch.long,
  53. **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
  54. self.num_batches_tracked: Optional[Tensor]
  55. else:
  56. self.register_buffer("running_mean", None)
  57. self.register_buffer("running_var", None)
  58. self.register_buffer("num_batches_tracked", None)
  59. self.reset_parameters()
  60. def reset_running_stats(self) -> None:
  61. if self.track_running_stats:
  62. # running_mean/running_var/num_batches... are registered at runtime depending
  63. # if self.track_running_stats is on
  64. self.running_mean.zero_() # type: ignore[union-attr]
  65. self.running_var.fill_(1) # type: ignore[union-attr]
  66. self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
  67. def reset_parameters(self) -> None:
  68. self.reset_running_stats()
  69. if self.affine:
  70. init.ones_(self.weight)
  71. init.zeros_(self.bias)
  72. def _check_input_dim(self, input):
  73. raise NotImplementedError
  74. def extra_repr(self):
  75. return (
  76. "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
  77. "track_running_stats={track_running_stats}".format(**self.__dict__)
  78. )
  79. def _load_from_state_dict(
  80. self,
  81. state_dict,
  82. prefix,
  83. local_metadata,
  84. strict,
  85. missing_keys,
  86. unexpected_keys,
  87. error_msgs,
  88. ):
  89. version = local_metadata.get("version", None)
  90. if (version is None or version < 2) and self.track_running_stats:
  91. # at version 2: added num_batches_tracked buffer
  92. # this should have a default value of 0
  93. num_batches_tracked_key = prefix + "num_batches_tracked"
  94. if num_batches_tracked_key not in state_dict:
  95. state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
  96. super()._load_from_state_dict(
  97. state_dict,
  98. prefix,
  99. local_metadata,
  100. strict,
  101. missing_keys,
  102. unexpected_keys,
  103. error_msgs,
  104. )
  105. class _BatchNorm(_NormBase):
  106. def __init__(
  107. self,
  108. num_features: int,
  109. eps: float = 1e-5,
  110. momentum: float = 0.1,
  111. affine: bool = True,
  112. track_running_stats: bool = True,
  113. device=None,
  114. dtype=None
  115. ) -> None:
  116. factory_kwargs = {'device': device, 'dtype': dtype}
  117. super().__init__(
  118. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  119. )
  120. def forward(self, input: Tensor) -> Tensor:
  121. self._check_input_dim(input)
  122. # exponential_average_factor is set to self.momentum
  123. # (when it is available) only so that it gets updated
  124. # in ONNX graph when this node is exported to ONNX.
  125. if self.momentum is None:
  126. exponential_average_factor = 0.0
  127. else:
  128. exponential_average_factor = self.momentum
  129. if self.training and self.track_running_stats:
  130. # TODO: if statement only here to tell the jit to skip emitting this when it is None
  131. if self.num_batches_tracked is not None: # type: ignore[has-type]
  132. self.num_batches_tracked.add_(1) # type: ignore[has-type]
  133. if self.momentum is None: # use cumulative moving average
  134. exponential_average_factor = 1.0 / float(self.num_batches_tracked)
  135. else: # use exponential moving average
  136. exponential_average_factor = self.momentum
  137. r"""
  138. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  139. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  140. """
  141. if self.training:
  142. bn_training = True
  143. else:
  144. bn_training = (self.running_mean is None) and (self.running_var is None)
  145. r"""
  146. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  147. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  148. used for normalization (i.e. in eval mode when buffers are not None).
  149. """
  150. return F.batch_norm(
  151. input,
  152. # If buffers are not to be tracked, ensure that they won't be updated
  153. self.running_mean
  154. if not self.training or self.track_running_stats
  155. else None,
  156. self.running_var if not self.training or self.track_running_stats else None,
  157. self.weight,
  158. self.bias,
  159. bn_training,
  160. exponential_average_factor,
  161. self.eps,
  162. )
  163. class _LazyNormBase(LazyModuleMixin, _NormBase):
  164. weight: UninitializedParameter # type: ignore[assignment]
  165. bias: UninitializedParameter # type: ignore[assignment]
  166. def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
  167. device=None, dtype=None) -> None:
  168. factory_kwargs = {'device': device, 'dtype': dtype}
  169. super().__init__(
  170. # affine and track_running_stats are hardcoded to False to
  171. # avoid creating tensors that will soon be overwritten.
  172. 0,
  173. eps,
  174. momentum,
  175. False,
  176. False,
  177. **factory_kwargs,
  178. )
  179. self.affine = affine
  180. self.track_running_stats = track_running_stats
  181. if self.affine:
  182. self.weight = UninitializedParameter(**factory_kwargs)
  183. self.bias = UninitializedParameter(**factory_kwargs)
  184. if self.track_running_stats:
  185. self.running_mean = UninitializedBuffer(**factory_kwargs)
  186. self.running_var = UninitializedBuffer(**factory_kwargs)
  187. self.num_batches_tracked = torch.tensor(
  188. 0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
  189. def reset_parameters(self) -> None:
  190. if not self.has_uninitialized_params() and self.num_features != 0:
  191. super().reset_parameters()
  192. def initialize_parameters(self, input) -> None: # type: ignore[override]
  193. if self.has_uninitialized_params():
  194. self.num_features = input.shape[1]
  195. if self.affine:
  196. assert isinstance(self.weight, UninitializedParameter)
  197. assert isinstance(self.bias, UninitializedParameter)
  198. self.weight.materialize((self.num_features,))
  199. self.bias.materialize((self.num_features,))
  200. if self.track_running_stats:
  201. self.running_mean.materialize((self.num_features,)) # type:ignore[union-attr]
  202. self.running_var.materialize((self.num_features,)) # type:ignore[union-attr]
  203. self.reset_parameters()
  204. class BatchNorm1d(_BatchNorm):
  205. r"""Applies Batch Normalization over a 2D or 3D input as described in the paper
  206. `Batch Normalization: Accelerating Deep Network Training by Reducing
  207. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  208. .. math::
  209. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  210. The mean and standard-deviation are calculated per-dimension over
  211. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  212. of size `C` (where `C` is the number of features or channels of the input). By default, the
  213. elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. The
  214. standard-deviation is calculated via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
  215. Also by default, during training this layer keeps running estimates of its
  216. computed mean and variance, which are then used for normalization during
  217. evaluation. The running estimates are kept with a default :attr:`momentum`
  218. of 0.1.
  219. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  220. keep running estimates, and batch statistics are instead used during
  221. evaluation time as well.
  222. .. note::
  223. This :attr:`momentum` argument is different from one used in optimizer
  224. classes and the conventional notion of momentum. Mathematically, the
  225. update rule for running statistics here is
  226. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  227. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  228. new observed value.
  229. Because the Batch Normalization is done over the `C` dimension, computing statistics
  230. on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
  231. Args:
  232. num_features: number of features or channels :math:`C` of the input
  233. eps: a value added to the denominator for numerical stability.
  234. Default: 1e-5
  235. momentum: the value used for the running_mean and running_var
  236. computation. Can be set to ``None`` for cumulative moving average
  237. (i.e. simple average). Default: 0.1
  238. affine: a boolean value that when set to ``True``, this module has
  239. learnable affine parameters. Default: ``True``
  240. track_running_stats: a boolean value that when set to ``True``, this
  241. module tracks the running mean and variance, and when set to ``False``,
  242. this module does not track such statistics, and initializes statistics
  243. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  244. When these buffers are ``None``, this module always uses batch statistics.
  245. in both training and eval modes. Default: ``True``
  246. Shape:
  247. - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
  248. :math:`C` is the number of features or channels, and :math:`L` is the sequence length
  249. - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
  250. Examples::
  251. >>> # With Learnable Parameters
  252. >>> m = nn.BatchNorm1d(100)
  253. >>> # Without Learnable Parameters
  254. >>> m = nn.BatchNorm1d(100, affine=False)
  255. >>> input = torch.randn(20, 100)
  256. >>> output = m(input)
  257. """
  258. def _check_input_dim(self, input):
  259. if input.dim() != 2 and input.dim() != 3:
  260. raise ValueError(
  261. "expected 2D or 3D input (got {}D input)".format(input.dim())
  262. )
  263. class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
  264. r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
  265. the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
  266. from the ``input.size(1)``.
  267. The attributes that will be lazily initialized are `weight`, `bias`,
  268. `running_mean` and `running_var`.
  269. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  270. on lazy modules and their limitations.
  271. Args:
  272. eps: a value added to the denominator for numerical stability.
  273. Default: 1e-5
  274. momentum: the value used for the running_mean and running_var
  275. computation. Can be set to ``None`` for cumulative moving average
  276. (i.e. simple average). Default: 0.1
  277. affine: a boolean value that when set to ``True``, this module has
  278. learnable affine parameters. Default: ``True``
  279. track_running_stats: a boolean value that when set to ``True``, this
  280. module tracks the running mean and variance, and when set to ``False``,
  281. this module does not track such statistics, and initializes statistics
  282. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  283. When these buffers are ``None``, this module always uses batch statistics.
  284. in both training and eval modes. Default: ``True``
  285. """
  286. cls_to_become = BatchNorm1d # type: ignore[assignment]
  287. def _check_input_dim(self, input):
  288. if input.dim() != 2 and input.dim() != 3:
  289. raise ValueError(
  290. "expected 2D or 3D input (got {}D input)".format(input.dim())
  291. )
  292. class BatchNorm2d(_BatchNorm):
  293. r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
  294. with additional channel dimension) as described in the paper
  295. `Batch Normalization: Accelerating Deep Network Training by Reducing
  296. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  297. .. math::
  298. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  299. The mean and standard-deviation are calculated per-dimension over
  300. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  301. of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
  302. to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated
  303. via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
  304. Also by default, during training this layer keeps running estimates of its
  305. computed mean and variance, which are then used for normalization during
  306. evaluation. The running estimates are kept with a default :attr:`momentum`
  307. of 0.1.
  308. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  309. keep running estimates, and batch statistics are instead used during
  310. evaluation time as well.
  311. .. note::
  312. This :attr:`momentum` argument is different from one used in optimizer
  313. classes and the conventional notion of momentum. Mathematically, the
  314. update rule for running statistics here is
  315. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  316. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  317. new observed value.
  318. Because the Batch Normalization is done over the `C` dimension, computing statistics
  319. on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
  320. Args:
  321. num_features: :math:`C` from an expected input of size
  322. :math:`(N, C, H, W)`
  323. eps: a value added to the denominator for numerical stability.
  324. Default: 1e-5
  325. momentum: the value used for the running_mean and running_var
  326. computation. Can be set to ``None`` for cumulative moving average
  327. (i.e. simple average). Default: 0.1
  328. affine: a boolean value that when set to ``True``, this module has
  329. learnable affine parameters. Default: ``True``
  330. track_running_stats: a boolean value that when set to ``True``, this
  331. module tracks the running mean and variance, and when set to ``False``,
  332. this module does not track such statistics, and initializes statistics
  333. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  334. When these buffers are ``None``, this module always uses batch statistics.
  335. in both training and eval modes. Default: ``True``
  336. Shape:
  337. - Input: :math:`(N, C, H, W)`
  338. - Output: :math:`(N, C, H, W)` (same shape as input)
  339. Examples::
  340. >>> # With Learnable Parameters
  341. >>> m = nn.BatchNorm2d(100)
  342. >>> # Without Learnable Parameters
  343. >>> m = nn.BatchNorm2d(100, affine=False)
  344. >>> input = torch.randn(20, 100, 35, 45)
  345. >>> output = m(input)
  346. """
  347. def _check_input_dim(self, input):
  348. if input.dim() != 4:
  349. raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
  350. class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
  351. r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
  352. the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
  353. from the ``input.size(1)``.
  354. The attributes that will be lazily initialized are `weight`, `bias`,
  355. `running_mean` and `running_var`.
  356. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  357. on lazy modules and their limitations.
  358. Args:
  359. eps: a value added to the denominator for numerical stability.
  360. Default: 1e-5
  361. momentum: the value used for the running_mean and running_var
  362. computation. Can be set to ``None`` for cumulative moving average
  363. (i.e. simple average). Default: 0.1
  364. affine: a boolean value that when set to ``True``, this module has
  365. learnable affine parameters. Default: ``True``
  366. track_running_stats: a boolean value that when set to ``True``, this
  367. module tracks the running mean and variance, and when set to ``False``,
  368. this module does not track such statistics, and initializes statistics
  369. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  370. When these buffers are ``None``, this module always uses batch statistics.
  371. in both training and eval modes. Default: ``True``
  372. """
  373. cls_to_become = BatchNorm2d # type: ignore[assignment]
  374. def _check_input_dim(self, input):
  375. if input.dim() != 4:
  376. raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
  377. class BatchNorm3d(_BatchNorm):
  378. r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
  379. with additional channel dimension) as described in the paper
  380. `Batch Normalization: Accelerating Deep Network Training by Reducing
  381. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  382. .. math::
  383. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  384. The mean and standard-deviation are calculated per-dimension over
  385. the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  386. of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
  387. to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated
  388. via the biased estimator, equivalent to `torch.var(input, unbiased=False)`.
  389. Also by default, during training this layer keeps running estimates of its
  390. computed mean and variance, which are then used for normalization during
  391. evaluation. The running estimates are kept with a default :attr:`momentum`
  392. of 0.1.
  393. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  394. keep running estimates, and batch statistics are instead used during
  395. evaluation time as well.
  396. .. note::
  397. This :attr:`momentum` argument is different from one used in optimizer
  398. classes and the conventional notion of momentum. Mathematically, the
  399. update rule for running statistics here is
  400. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  401. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  402. new observed value.
  403. Because the Batch Normalization is done over the `C` dimension, computing statistics
  404. on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
  405. or Spatio-temporal Batch Normalization.
  406. Args:
  407. num_features: :math:`C` from an expected input of size
  408. :math:`(N, C, D, H, W)`
  409. eps: a value added to the denominator for numerical stability.
  410. Default: 1e-5
  411. momentum: the value used for the running_mean and running_var
  412. computation. Can be set to ``None`` for cumulative moving average
  413. (i.e. simple average). Default: 0.1
  414. affine: a boolean value that when set to ``True``, this module has
  415. learnable affine parameters. Default: ``True``
  416. track_running_stats: a boolean value that when set to ``True``, this
  417. module tracks the running mean and variance, and when set to ``False``,
  418. this module does not track such statistics, and initializes statistics
  419. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  420. When these buffers are ``None``, this module always uses batch statistics.
  421. in both training and eval modes. Default: ``True``
  422. Shape:
  423. - Input: :math:`(N, C, D, H, W)`
  424. - Output: :math:`(N, C, D, H, W)` (same shape as input)
  425. Examples::
  426. >>> # With Learnable Parameters
  427. >>> m = nn.BatchNorm3d(100)
  428. >>> # Without Learnable Parameters
  429. >>> m = nn.BatchNorm3d(100, affine=False)
  430. >>> input = torch.randn(20, 100, 35, 45, 10)
  431. >>> output = m(input)
  432. """
  433. def _check_input_dim(self, input):
  434. if input.dim() != 5:
  435. raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
  436. class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
  437. r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
  438. the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
  439. from the ``input.size(1)``.
  440. The attributes that will be lazily initialized are `weight`, `bias`,
  441. `running_mean` and `running_var`.
  442. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  443. on lazy modules and their limitations.
  444. Args:
  445. eps: a value added to the denominator for numerical stability.
  446. Default: 1e-5
  447. momentum: the value used for the running_mean and running_var
  448. computation. Can be set to ``None`` for cumulative moving average
  449. (i.e. simple average). Default: 0.1
  450. affine: a boolean value that when set to ``True``, this module has
  451. learnable affine parameters. Default: ``True``
  452. track_running_stats: a boolean value that when set to ``True``, this
  453. module tracks the running mean and variance, and when set to ``False``,
  454. this module does not track such statistics, and initializes statistics
  455. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  456. When these buffers are ``None``, this module always uses batch statistics.
  457. in both training and eval modes. Default: ``True``
  458. """
  459. cls_to_become = BatchNorm3d # type: ignore[assignment]
  460. def _check_input_dim(self, input):
  461. if input.dim() != 5:
  462. raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
  463. class SyncBatchNorm(_BatchNorm):
  464. r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
  465. with additional channel dimension) as described in the paper
  466. `Batch Normalization: Accelerating Deep Network Training by Reducing
  467. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  468. .. math::
  469. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  470. The mean and standard-deviation are calculated per-dimension over all
  471. mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
  472. are learnable parameter vectors of size `C` (where `C` is the input size).
  473. By default, the elements of :math:`\gamma` are sampled from
  474. :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
  475. The standard-deviation is calculated via the biased estimator, equivalent to
  476. `torch.var(input, unbiased=False)`.
  477. Also by default, during training this layer keeps running estimates of its
  478. computed mean and variance, which are then used for normalization during
  479. evaluation. The running estimates are kept with a default :attr:`momentum`
  480. of 0.1.
  481. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  482. keep running estimates, and batch statistics are instead used during
  483. evaluation time as well.
  484. .. note::
  485. This :attr:`momentum` argument is different from one used in optimizer
  486. classes and the conventional notion of momentum. Mathematically, the
  487. update rule for running statistics here is
  488. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  489. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  490. new observed value.
  491. Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
  492. statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
  493. Normalization or Spatio-temporal Batch Normalization.
  494. Currently :class:`SyncBatchNorm` only supports
  495. :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
  496. :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
  497. :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
  498. Network with DDP.
  499. Args:
  500. num_features: :math:`C` from an expected input of size
  501. :math:`(N, C, +)`
  502. eps: a value added to the denominator for numerical stability.
  503. Default: ``1e-5``
  504. momentum: the value used for the running_mean and running_var
  505. computation. Can be set to ``None`` for cumulative moving average
  506. (i.e. simple average). Default: 0.1
  507. affine: a boolean value that when set to ``True``, this module has
  508. learnable affine parameters. Default: ``True``
  509. track_running_stats: a boolean value that when set to ``True``, this
  510. module tracks the running mean and variance, and when set to ``False``,
  511. this module does not track such statistics, and initializes statistics
  512. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  513. When these buffers are ``None``, this module always uses batch statistics.
  514. in both training and eval modes. Default: ``True``
  515. process_group: synchronization of stats happen within each process group
  516. individually. Default behavior is synchronization across the whole
  517. world
  518. Shape:
  519. - Input: :math:`(N, C, +)`
  520. - Output: :math:`(N, C, +)` (same shape as input)
  521. .. note::
  522. Synchronization of batchnorm statistics occurs only while training, i.e.
  523. synchronization is disabled when ``model.eval()`` is set or if
  524. ``self.training`` is otherwise ``False``.
  525. Examples::
  526. >>> # xdoctest: +SKIP
  527. >>> # With Learnable Parameters
  528. >>> m = nn.SyncBatchNorm(100)
  529. >>> # creating process group (optional)
  530. >>> # ranks is a list of int identifying rank ids.
  531. >>> ranks = list(range(8))
  532. >>> r1, r2 = ranks[:4], ranks[4:]
  533. >>> # Note: every rank calls into new_group for every
  534. >>> # process group created, even if that rank is not
  535. >>> # part of the group.
  536. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
  537. >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
  538. >>> # Without Learnable Parameters
  539. >>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
  540. >>> input = torch.randn(20, 100, 35, 45, 10)
  541. >>> output = m(input)
  542. >>> # network is nn.BatchNorm layer
  543. >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
  544. >>> # only single gpu per process is currently supported
  545. >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
  546. >>> sync_bn_network,
  547. >>> device_ids=[args.local_rank],
  548. >>> output_device=args.local_rank)
  549. """
  550. def __init__(
  551. self,
  552. num_features: int,
  553. eps: float = 1e-5,
  554. momentum: float = 0.1,
  555. affine: bool = True,
  556. track_running_stats: bool = True,
  557. process_group: Optional[Any] = None,
  558. device=None,
  559. dtype=None
  560. ) -> None:
  561. factory_kwargs = {'device': device, 'dtype': dtype}
  562. super().__init__(
  563. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  564. )
  565. self.process_group = process_group
  566. def _check_input_dim(self, input):
  567. if input.dim() < 2:
  568. raise ValueError(
  569. "expected at least 2D input (got {}D input)".format(input.dim())
  570. )
  571. def _check_non_zero_input_channels(self, input):
  572. if input.size(1) == 0:
  573. raise ValueError(
  574. "SyncBatchNorm number of input channels should be non-zero"
  575. )
  576. def forward(self, input: Tensor) -> Tensor:
  577. self._check_input_dim(input)
  578. self._check_non_zero_input_channels(input)
  579. # exponential_average_factor is set to self.momentum
  580. # (when it is available) only so that it gets updated
  581. # in ONNX graph when this node is exported to ONNX.
  582. if self.momentum is None:
  583. exponential_average_factor = 0.0
  584. else:
  585. exponential_average_factor = self.momentum
  586. if self.training and self.track_running_stats:
  587. assert self.num_batches_tracked is not None
  588. self.num_batches_tracked.add_(1)
  589. if self.momentum is None: # use cumulative moving average
  590. exponential_average_factor = 1.0 / self.num_batches_tracked.item()
  591. else: # use exponential moving average
  592. exponential_average_factor = self.momentum
  593. r"""
  594. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  595. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  596. """
  597. if self.training:
  598. bn_training = True
  599. else:
  600. bn_training = (self.running_mean is None) and (self.running_var is None)
  601. r"""
  602. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  603. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  604. used for normalization (i.e. in eval mode when buffers are not None).
  605. """
  606. # If buffers are not to be tracked, ensure that they won't be updated
  607. running_mean = (
  608. self.running_mean if not self.training or self.track_running_stats else None
  609. )
  610. running_var = (
  611. self.running_var if not self.training or self.track_running_stats else None
  612. )
  613. # Don't sync batchnorm stats in inference mode (model.eval()).
  614. need_sync = (bn_training and self.training and
  615. torch.distributed.is_available() and torch.distributed.is_initialized())
  616. if need_sync:
  617. # currently only GPU input is supported
  618. if not input.is_cuda:
  619. raise ValueError("SyncBatchNorm expected input tensor to be on GPU")
  620. process_group = torch.distributed.group.WORLD
  621. if self.process_group:
  622. process_group = self.process_group
  623. world_size = torch.distributed.get_world_size(process_group)
  624. need_sync = world_size > 1
  625. # fallback to framework BN when synchronization is not necessary
  626. if not need_sync:
  627. return F.batch_norm(
  628. input,
  629. running_mean,
  630. running_var,
  631. self.weight,
  632. self.bias,
  633. bn_training,
  634. exponential_average_factor,
  635. self.eps,
  636. )
  637. else:
  638. assert bn_training
  639. return sync_batch_norm.apply(
  640. input,
  641. self.weight,
  642. self.bias,
  643. running_mean,
  644. running_var,
  645. self.eps,
  646. exponential_average_factor,
  647. process_group,
  648. world_size,
  649. )
  650. @classmethod
  651. def convert_sync_batchnorm(cls, module, process_group=None):
  652. r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
  653. :class:`torch.nn.SyncBatchNorm` layers.
  654. Args:
  655. module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
  656. process_group (optional): process group to scope synchronization,
  657. default is the whole world
  658. Returns:
  659. The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
  660. layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
  661. a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
  662. instead.
  663. Example::
  664. >>> # Network with nn.BatchNorm layer
  665. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  666. >>> module = torch.nn.Sequential(
  667. >>> torch.nn.Linear(20, 100),
  668. >>> torch.nn.BatchNorm1d(100),
  669. >>> ).cuda()
  670. >>> # creating process group (optional)
  671. >>> # ranks is a list of int identifying rank ids.
  672. >>> ranks = list(range(8))
  673. >>> r1, r2 = ranks[:4], ranks[4:]
  674. >>> # Note: every rank calls into new_group for every
  675. >>> # process group created, even if that rank is not
  676. >>> # part of the group.
  677. >>> # xdoctest: +SKIP("distributed")
  678. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
  679. >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
  680. >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
  681. """
  682. module_output = module
  683. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  684. module_output = torch.nn.SyncBatchNorm(
  685. module.num_features,
  686. module.eps,
  687. module.momentum,
  688. module.affine,
  689. module.track_running_stats,
  690. process_group,
  691. )
  692. if module.affine:
  693. with torch.no_grad():
  694. module_output.weight = module.weight
  695. module_output.bias = module.bias
  696. module_output.running_mean = module.running_mean
  697. module_output.running_var = module.running_var
  698. module_output.num_batches_tracked = module.num_batches_tracked
  699. if hasattr(module, "qconfig"):
  700. module_output.qconfig = module.qconfig
  701. for name, child in module.named_children():
  702. module_output.add_module(
  703. name, cls.convert_sync_batchnorm(child, process_group)
  704. )
  705. del module
  706. return module_output