import torch
import numbers
from torch.nn.parameter import Parameter
from .module import Module
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
from .. import functional as F
from .. import init

from torch import Tensor, Size
from typing import Union, List, Tuple

__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm']


class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal composed
    of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)
    """

    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                     self.k)

    def extra_repr(self):
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
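

# A minimal reference sketch (an editorial aid, not part of the public API) that
# transcribes the formula from the LocalResponseNorm docstring with an explicit
# channel loop. F.local_response_norm is the optimised implementation and its
# boundary handling may differ slightly, so treat this as illustrative only;
# `_naive_local_response_norm` is a hypothetical helper name.
def _naive_local_response_norm(a: Tensor, size: int, alpha: float = 1e-4,
                               beta: float = 0.75, k: float = 1.) -> Tensor:
    num_channels = a.size(1)
    out = torch.empty_like(a)
    for c in range(num_channels):
        lo = max(0, c - size // 2)
        hi = min(num_channels - 1, c + size // 2)
        # k + (alpha / n) * sum of squared activations over the neighbouring channels
        denom = k + (alpha / size) * a[:, lo:hi + 1].pow(2).sum(dim=1)
        out[:, c] = a[:, c] * denom.pow(-beta)
    return out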


class CrossMapLRN2d(Module):
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta,
                                      self.k)

    def extra_repr(self) -> str:
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
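

# Usage sketch: CrossMapLRN2d takes the same (size, alpha, beta, k) arguments as
# LocalResponseNorm but dispatches to the CrossMapLRN2d autograd Function instead
# of F.local_response_norm, and is typically applied to 4-d (N, C, H, W) inputs.
# The shapes below are arbitrary example values and `_cross_map_lrn2d_example`
# is a hypothetical helper, not part of this module's API.
def _cross_map_lrn2d_example() -> Tensor:
    lrn = CrossMapLRN2d(size=5)
    x = torch.randn(8, 16, 32, 32)  # (N, C, H, W)
    return lrn(x)                   # same shape as the input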


_shape_t = Union[int, List[int], Size]


class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias: the learnable bias of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
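

# A minimal sketch of the math described in the LayerNorm docstring: normalize
# over the trailing `normalized_shape` dimensions with the biased variance, then
# apply the per-element weight and bias. F.layer_norm is the fused native
# implementation; `_manual_layer_norm` is a hypothetical helper shown only to
# make the semantics concrete.
def _manual_layer_norm(x: Tensor, normalized_shape: Tuple[int, ...],
                       weight: Tensor, bias: Tensor, eps: float = 1e-5) -> Tensor:
    dims = tuple(range(-len(normalized_shape), 0))        # the last D dimensions
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)   # biased estimator
    return (x - mean) / torch.sqrt(var + eps) * weight + bias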


class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
    :attr:`num_groups`. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent to LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)
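

# A minimal sketch of the math described in the GroupNorm docstring (an
# illustrative assumption about the semantics, not the actual kernel): reshape
# the channels into `num_groups` groups, normalize each group with the biased
# variance, then apply the per-channel affine parameters. `_manual_group_norm`
# is a hypothetical helper name.
def _manual_group_norm(x: Tensor, num_groups: int, weight: Tensor, bias: Tensor,
                       eps: float = 1e-5) -> Tensor:
    n, c = x.shape[0], x.shape[1]
    grouped = x.reshape(n, num_groups, -1)                # flatten each group
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)
    normed = ((grouped - mean) / torch.sqrt(var + eps)).reshape_as(x)
    affine_shape = (1, c) + (1,) * (x.dim() - 2)          # broadcast per channel
    return normed * weight.reshape(affine_shape) + bias.reshape(affine_shape)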


# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d