- """
- Spectral Normalization from https://arxiv.org/abs/1802.05957
- """
- import torch
- from torch.nn.functional import normalize
- from typing import Any, Optional, TypeVar
- from ..modules import Module
- __all__ = ['SpectralNorm', 'SpectralNormLoadStateDictPreHook', 'SpectralNormStateDictHook',
- 'spectral_norm', 'remove_spectral_norm']


class SpectralNorm:
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version: int = 1
    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
    name: str
    dim: int
    n_power_iterations: int
    eps: float

    def __init__(self, name: str = 'weight', n_power_iterations: int = 1, dim: int = 0, eps: float = 1e-12) -> None:
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)
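
    # A minimal sketch of what the reshape does (shapes here are illustrative,
    # not taken from this file): with the default ``dim=0``, a Conv2d weight of
    # shape (out_channels=16, in_channels=3, kH=5, kW=5) is flattened into a
    # 16 x 75 matrix, so power iteration runs on a 2D view of the kernel:
    #
    #   w = torch.randn(16, 3, 5, 5)
    #   SpectralNorm(dim=0).reshape_weight_to_matrix(w).shape  # torch.Size([16, 75])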

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        #   1. `DataParallel` doesn't clone storage if the broadcast tensor
        #      is already on the correct device; and it makes sure that the
        #      parallelized module is already on `device[0]`.
        #   2. If the out tensor in the `out=` kwarg has the correct shape, it
        #      will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern
        # in GAN training: loss = D(real) - D(fake). Otherwise, the engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of the weight equals `u^T W v`, where `u`
                    # and `v` are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight
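
    # Rough sanity check (a sketch assuming power iteration has converged;
    # `fn` and `module` are illustrative names, not part of this file): after
    # the division above, the largest singular value of the reshaped weight is
    # approximately 1:
    #
    #   w_sn = fn.compute_weight(module, do_power_iteration=False)
    #   torch.linalg.matrix_norm(fn.reshape_weight_to_matrix(w_sn), ord=2)  # ~1.0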

    def remove(self, module: Module) -> None:
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        v = torch.linalg.multi_dot([weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
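
    # The expression above is a least-squares solve (a sketch of the math, not
    # extra API): `pinv(W^T W) @ W^T @ u` is the `x` minimizing ||W @ x - u||_2,
    # and the in-place rescale then makes `dot(u, W @ v)` equal `target_sigma`
    # exactly, so the class invariant `u = normalize(W @ v)` is restored up to
    # the accuracy of the solve.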

    @staticmethod
    def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: float) -> 'SpectralNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(f'`SpectralNorm` cannot be applied as parameter `{name}` is None')
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying spectral normalization')

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn
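
    # Summary of what `apply` leaves on the module for the default
    # ``name='weight'`` (an illustrative sketch, not additional API):
    #   - `weight_orig`: the original trainable nn.Parameter
    #   - `weight_u`, `weight_v`: buffers holding the singular-vector estimates
    #   - `weight`: a plain tensor attribute recomputed by the forward pre-hook
    # For example:
    #
    #   m = spectral_norm(torch.nn.Linear(20, 40))
    #   sorted(dict(m.named_buffers()))               # ['weight_u', 'weight_v']
    #   'weight_orig' in dict(m.named_parameters())   # True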


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    # For a state_dict with version None (assuming it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W_orig / W)
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs) -> None:
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
                    weight_key not in state_dict:
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ('_orig', '', '_u'):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + '_orig']
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
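                # Since the stored `W` equals `W_orig / sigma` elementwise, every
                # entry of this ratio is the same scalar (up to numerical noise),
                # so `.mean()` simply recovers `sigma` robustly.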
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook:
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version


T_module = TypeVar('T_module', bound=Module)


def spectral_norm(module: T_module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> T_module:
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration method to get
    the spectral norm. This is implemented via a hook that calculates the
    spectral norm and rescales the weight before every
    :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
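

# For ConvTranspose layers the output-channel dimension of the weight is dim 1,
# which is why `dim` defaults to 1 there. A brief illustration (shapes are
# hypothetical):
#
#   spectral_norm(torch.nn.ConvTranspose2d(3, 16, 4)).weight_u.size()
#   # torch.Size([16])  -- one entry of `u` per output channel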


def remove_spectral_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module
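

# A minimal end-to-end sketch of typical usage (the layer and shapes are
# illustrative, not part of this file):
#
#   import torch.nn as nn
#   layer = spectral_norm(nn.Conv2d(3, 16, 3))
#   out = layer(torch.randn(1, 3, 8, 8))    # pre-hook rescales `weight` first
#   layer = remove_spectral_norm(layer)     # `weight` is a plain Parameter again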