weight_norm.py

r"""
Weight Normalization from https://arxiv.org/abs/1602.07868
"""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar

from ..modules import Module

__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']

class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        # rebuild w = g * v / ||v|| from the stored magnitude and direction
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v;
        # norm_except_dim takes the 2-norm over every dim except `dim`
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))
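
# A WeightNorm instance is registered as a forward pre-hook: before every
# forward() it recomputes `name` from `name_g` and `name_v`, so gradients
# accumulate in those two parameters rather than in the recomputed tensor.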

T_module = TypeVar('T_module', bound=Module)

def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the
    magnitude (e.g. ``'weight_g'``) and one specifying the direction
    (e.g. ``'weight_v'``). Weight normalization is implemented via a hook that
    recomputes the weight tensor from the magnitude and direction before every
    :meth:`~Module.forward` call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module

def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
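

# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a quick numerical
# check of the reparameterization on nn.Linear, assuming torch and torch.nn
# are installed. It is kept as comments because the relative import above
# prevents running this file directly as a script.
#
# if __name__ == "__main__":
#     import torch
#     import torch.nn as nn
#
#     m = weight_norm(nn.Linear(20, 40), name='weight', dim=0)
#     v, g = m.weight_v, m.weight_g            # direction [40, 20], magnitude [40, 1]
#     expected = g * v / v.norm(2, dim=1, keepdim=True)
#     assert torch.allclose(m.weight, expected, atol=1e-6)
#
#     remove_weight_norm(m)                    # folds g and v back into a single Parameter
#     assert isinstance(m.weight, nn.Parameter)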