  1. """
  2. Spectral Normalization from https://arxiv.org/abs/1802.05957
  3. """
  4. import torch
  5. from torch.nn.functional import normalize
  6. from typing import Any, Optional, TypeVar
  7. from ..modules import Module
  8. __all__ = ['SpectralNorm', 'SpectralNormLoadStateDictPreHook', 'SpectralNormStateDictHook',
  9. 'spectral_norm', 'remove_spectral_norm']
  10. class SpectralNorm:
  11. # Invariant before and after each forward call:
  12. # u = normalize(W @ v)
  13. # NB: At initialization, this invariant is not enforced
  14. _version: int = 1
  15. # At version 1:
  16. # made `W` not a buffer,
  17. # added `v` as a buffer, and
  18. # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
  19. name: str
  20. dim: int
  21. n_power_iterations: int
  22. eps: float
  23. def __init__(self, name: str = 'weight', n_power_iterations: int = 1, dim: int = 0, eps: float = 1e-12) -> None:
  24. self.name = name
  25. self.dim = dim
  26. if n_power_iterations <= 0:
  27. raise ValueError('Expected n_power_iterations to be positive, but '
  28. 'got n_power_iterations={}'.format(n_power_iterations))
  29. self.n_power_iterations = n_power_iterations
  30. self.eps = eps
  31. def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
  32. weight_mat = weight
  33. if self.dim != 0:
  34. # permute dim to front
  35. weight_mat = weight_mat.permute(self.dim,
  36. *[d for d in range(weight_mat.dim()) if d != self.dim])
  37. height = weight_mat.size(0)
  38. return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        #   1. `DataParallel` doesn't clone storage if the broadcast tensor
        #      is already on the correct device; and it makes sure that the
        #      parallelized module is already on `device[0]`.
        #   2. If the out tensor in the `out=` kwarg has the correct shape, it
        #      will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the `_u` vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, the engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.
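        #
        # As an illustrative sketch of that GAN pattern (`disc`, `real`, and
        # `fake` are hypothetical names, not part of this module):
        #
        #     d_real = disc(real)   # first forward reads u, v
        #     d_fake = disc(fake)   # second forward overwrites u, v in-place
        #     (d_real.mean() - d_fake.mean()).backward()
        #
        # The backward pass of the first forward still needs the `u`/`v`
        # values it saw, so the normalization below must use the cloned
        # copies rather than the buffers that were just overwritten.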
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case `W^T W` is not invertible.
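        # (Added note) `(W^T W)^+ W^T u` is the least-squares solution `x` of
        # `W x ≈ u`; the subsequent `mul_` rescales it so that `u @ W @ v`
        # equals `target_sigma`.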
        v = torch.linalg.multi_dot([weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: float) -> 'SpectralNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(f'`SpectralNorm` cannot be applied as parameter `{name}` is None')
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying spectral normalization')

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as `fn.name` because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it: it could be an nn.Parameter and
        # would then get added as a parameter. Instead, we register `weight.data`
        # as a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner classes or
# instancemethods.
class SpectralNormLoadStateDictPreHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs) -> None:
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
                    weight_key not in state_dict:
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ('_orig', '', '_u'):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + '_orig']
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner classes or
# instancemethods.
class SpectralNormStateDictHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version


T_module = TypeVar('T_module', bound=Module)


def spectral_norm(module: T_module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> T_module:
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated using
    the power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in the power iteration method to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module
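

# Usage sketch (illustrative only, kept as a comment so nothing runs on
# import; it mirrors the docstring examples above, and the `disc` name and
# tensor shapes are assumptions for the example, not part of this module):
#
#     >>> import torch
#     >>> import torch.nn as nn
#     >>> from torch.nn.utils import spectral_norm, remove_spectral_norm
#     >>> disc = nn.Sequential(spectral_norm(nn.Linear(20, 40)), nn.ReLU(), nn.Linear(40, 1))
#     >>> _ = disc(torch.randn(8, 20))  # each training-mode forward runs one power iteration
#     >>> torch.linalg.matrix_norm(disc[0].weight.detach(), ord=2)  # approaches 1 as the iteration converges
#     >>> _ = remove_spectral_norm(disc[0])  # bakes the normalized weight back into a plain Parameter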