- r"""
- Weight Normalization from https://arxiv.org/abs/1602.07868
- """
- from torch.nn.parameter import Parameter, UninitializedParameter
- from torch import _weight_norm, norm_except_dim
- from typing import Any, TypeVar
- from ..modules import Module
- __all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']


class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            # dim=None is canonicalized to -1, i.e. the norm is taken over the
            # entire tensor rather than per-slice along one dimension.
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        # Reconstruct w = g * v / ||v|| from the stored magnitude (g) and
        # direction (v) tensors.
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        # Recompute the weight one last time, drop the g/v parameters and
        # re-register the plain weight as a single Parameter.
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        # Forward pre-hook: refresh the weight from g and v before each forward().
        setattr(module, self.name, self.compute_weight(module))
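
# Illustrative sketch (not part of this module) of how ``WeightNorm.apply``
# changes a module's parameters, assuming a standard ``torch.nn`` environment
# and a hypothetical module ``lin``:
#
#     lin = torch.nn.Linear(20, 40)
#     WeightNorm.apply(lin, 'weight', dim=0)
#     sorted(name for name, _ in lin.named_parameters())
#     # ['bias', 'weight_g', 'weight_v'] -- 'weight' itself is no longer a
#     # Parameter; it is recomputed from g and v by the forward pre-hook on
#     # every call.
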
T_module = TypeVar('T_module', bound=Module)


def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
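
# A minimal sketch of the identity maintained by the hook registered above
# (illustrative only; assumes the default ``dim=0`` and a standard ``torch``
# environment, with ``m`` a hypothetical module):
#
#     m = weight_norm(torch.nn.Linear(20, 40))
#     # weight_g has shape [40, 1] and weight_v has shape [40, 20]; the hook
#     # recomputes weight = g * v / ||v|| with the norm taken per output row.
#     expected = m.weight_g * m.weight_v / m.weight_v.norm(dim=1, keepdim=True)
#     torch.allclose(m.weight, expected)  # True
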

def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
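
# Illustrative sketch of the round trip (assumes a standard ``torch``
# environment; ``m`` is a hypothetical module):
#
#     m = weight_norm(torch.nn.Linear(20, 40))
#     'weight_g' in dict(m.named_parameters())   # True
#     remove_weight_norm(m)
#     'weight_g' in dict(m.named_parameters())   # False
#     isinstance(m.weight, Parameter)            # True: g and v were folded
#                                                # back into a single weight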