# instancenorm.py
from torch import Tensor

from .batchnorm import _LazyNormBase, _NormBase
from .. import functional as F

__all__ = ['InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LazyInstanceNorm1d',
           'LazyInstanceNorm2d', 'LazyInstanceNorm3d']


class _InstanceNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def _get_no_batch_dim(self):
        raise NotImplementedError

    def _handle_no_batch_input(self, input):
        # Temporarily add a batch dimension so an unbatched input goes through
        # the same code path, then remove it again.
        return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)

    def _apply_instance_norm(self, input):
        # Instance statistics are used whenever the module is training or
        # running stats are not tracked; otherwise the running estimates apply.
        return F.instance_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats, self.momentum, self.eps)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        # At version 1: running_mean and running_var were removed when
        # track_running_stats=False (the default). Old checkpoints that still
        # carry these buffers are rejected with an explanatory error.
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ('running_mean', 'running_var'):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    'Unexpected running stats buffer(s) {names} for {klass} '
                    'with track_running_stats=False. If state_dict is a '
                    'checkpoint saved before 0.4.0, this may be expected '
                    'because {klass} does not track running stats by default '
                    'since 0.4.0. Please remove these keys from state_dict. If '
                    'the running stats are actually needed, instead set '
                    'track_running_stats=True in {klass} to enable them. See '
                    'the documentation of {klass} for details.'
                    .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
                            klass=self.__class__.__name__))
                for key in running_stats_keys:
                    state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        if input.dim() == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

        return self._apply_instance_norm(input)


class InstanceNorm1d(_InstanceNorm):
    r"""Applies Instance Normalization over a 2D (unbatched) or 3D (batched) input
    as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and from the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm1d` is applied
        on each channel of channeled data like multidimensional time series, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm1d` usually does not apply an affine
        transform.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm1d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)
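
    An unbatched :math:`(C, L)` input is also accepted, per the Shape section
    above; a minimal sketch, reusing the module ``m`` from the example::

        >>> unbatched = torch.randn(100, 40)
        >>> out = m(unbatched)  # normalized over L per channel; shape stays (100, 40)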
  117. """
  118. def _get_no_batch_dim(self):
  119. return 2
  120. def _check_input_dim(self, input):
  121. if input.dim() not in (2, 3):
  122. raise ValueError('expected 2D or 3D input (got {}D input)'
  123. .format(input.dim()))


class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of
    the ``num_features`` argument of the :class:`InstanceNorm1d`, which is
    inferred from ``input.size(1)``.

    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`(C, L)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L)` or :math:`(C, L)`
        - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
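
    A minimal usage sketch: ``num_features`` is inferred on the first forward
    pass, after which the module behaves like a regular :class:`InstanceNorm1d`::

        >>> m = nn.LazyInstanceNorm1d(affine=True)
        >>> input = torch.randn(20, 100, 40)
        >>> output = m(input)  # num_features is inferred as input.size(1) == 100
        >>> m.num_features
        100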
  147. """
  148. cls_to_become = InstanceNorm1d # type: ignore[assignment]
  149. def _get_no_batch_dim(self):
  150. return 2
  151. def _check_input_dim(self, input):
  152. if input.dim() not in (2, 3):
  153. raise ValueError('expected 2D or 3D input (got {}D input)'
  154. .format(input.dim()))


class InstanceNorm2d(_InstanceNorm):
    r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and from the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm2d` is applied
        on each channel of channeled data like RGB images, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm2d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm2d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm2d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
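
    A sketch of the normalization itself: with ``affine=False``, each channel
    of each sample ends up with (approximately) zero mean and unit variance
    over its spatial dimensions, matching the biased-estimator note above::

        >>> m = nn.InstanceNorm2d(100)
        >>> output = m(torch.randn(20, 100, 35, 45))
        >>> per_channel_mean = output.mean(dim=(2, 3))                # ~0 everywhere
        >>> per_channel_var = output.var(dim=(2, 3), unbiased=False)  # ~1 everywhere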
  210. """
  211. def _get_no_batch_dim(self):
  212. return 3
  213. def _check_input_dim(self, input):
  214. if input.dim() not in (3, 4):
  215. raise ValueError('expected 3D or 4D input (got {}D input)'
  216. .format(input.dim()))


class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of
    the ``num_features`` argument of the :class:`InstanceNorm2d`, which is
    inferred from ``input.size(1)``.

    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)` or :math:`(C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
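
    A minimal sketch, analogous to :class:`LazyInstanceNorm1d`: the channel
    dimension of the first 4D input sets ``num_features``::

        >>> m = nn.LazyInstanceNorm2d()
        >>> output = m(torch.randn(20, 100, 35, 45))
        >>> m.num_features
        100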
  240. """
  241. cls_to_become = InstanceNorm2d # type: ignore[assignment]
  242. def _get_no_batch_dim(self):
  243. return 3
  244. def _check_input_dim(self, input):
  245. if input.dim() not in (3, 4):
  246. raise ValueError('expected 3D or 4D input (got {}D input)'
  247. .format(input.dim()))


class InstanceNorm3d(_InstanceNorm):
    r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`__.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension separately
    for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input) if :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    By default, this layer uses instance statistics computed from input data in
    both training and evaluation modes.

    If :attr:`track_running_stats` is set to ``True``, during training this
    layer keeps running estimates of its computed mean and variance, which are
    then used for normalization during evaluation. The running estimates are
    kept with a default :attr:`momentum` of 0.1.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and from the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    .. note::
        :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
        have some subtle differences. :class:`InstanceNorm3d` is applied
        on each channel of channeled data like 3D models with RGB color, but
        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
        transform, while :class:`InstanceNorm3d` usually does not apply an affine
        transform.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)

    Examples::

        >>> # Without Learnable Parameters
        >>> m = nn.InstanceNorm3d(100)
        >>> # With Learnable Parameters
        >>> m = nn.InstanceNorm3d(100, affine=True)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
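
    A sketch of the running-stats behaviour described in the note above,
    assuming ``track_running_stats=True``: training updates the running
    estimates, and evaluation then normalizes with them::

        >>> m = nn.InstanceNorm3d(100, track_running_stats=True)
        >>> output = m(torch.randn(20, 100, 35, 45, 10))  # updates running_mean/var
        >>> m = m.eval()  # running estimates are now used for normalization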
  303. """
  304. def _get_no_batch_dim(self):
  305. return 4
  306. def _check_input_dim(self, input):
  307. if input.dim() not in (4, 5):
  308. raise ValueError('expected 4D or 5D input (got {}D input)'
  309. .format(input.dim()))


class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
    r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of
    the ``num_features`` argument of the :class:`InstanceNorm3d`, which is
    inferred from ``input.size(1)``.

    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters, initialized the same way as done for batch normalization.
            Default: ``False``.
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
  333. """
  334. cls_to_become = InstanceNorm3d # type: ignore[assignment]
  335. def _get_no_batch_dim(self):
  336. return 4
  337. def _check_input_dim(self, input):
  338. if input.dim() not in (4, 5):
  339. raise ValueError('expected 4D or 5D input (got {}D input)'
  340. .format(input.dim()))