123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562 |
- import math
- import warnings
- from torch import Tensor
- import torch
# The ``_no_grad_*`` helpers below wrap the parts of the public initializers
# that must run under ``torch.no_grad()``. The JIT does not support context
# managers, so these helpers have to be implemented as builtins; keeping them
# as tiny standalone functions keeps those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
    """Fill ``tensor`` in place from U(a, b) without recording autograd history.

    Returns the value of ``Tensor.uniform_`` (the tensor that was filled).
    """
    with torch.no_grad():
        filled = tensor.uniform_(a, b)
    return filled
def _no_grad_normal_(tensor, mean, std):
    """Fill ``tensor`` in place from N(mean, std**2) without recording autograd history.

    Returns the value of ``Tensor.normal_`` (the tensor that was filled).
    """
    with torch.no_grad():
        filled = tensor.normal_(mean, std)
    return filled
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std**2) samples truncated to [a, b].

    Inverse-transform sampling: values are drawn uniformly between the normal
    CDF values of the two bounds (mapped into the domain of ``erf``), pushed
    back through ``erfinv`` and then rescaled/shifted to ``mean``/``std``.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    A warning is emitted when ``mean`` lies more than two standard deviations
    outside ``[a, b]``, where the resulting distribution may be inaccurate.
    Returns ``tensor``.
    """
    def _std_normal_cdf(z):
        # Phi(z) = (1 + erf(z / sqrt(2))) / 2
        return (1. + math.erf(z / math.sqrt(2.))) / 2.

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        lower_cdf = _std_normal_cdf((a - mean) / std)
        upper_cdf = _std_normal_cdf((b - mean) / std)
        # U(2*lower_cdf - 1, 2*upper_cdf - 1) is [lower_cdf, upper_cdf] mapped
        # into erf's range; erfinv followed by a sqrt(2) factor yields standard
        # normal quantiles, which are then scaled by std and shifted by mean.
        # The final clamp absorbs rounding that could land just outside [a, b].
        (tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
               .erfinv_()
               .mul_(std * math.sqrt(2.))
               .add_(mean)
               .clamp_(min=a, max=b))
    return tensor
def _no_grad_fill_(tensor, val):
    """Set every element of ``tensor`` to ``val`` in place, outside autograd.

    Returns the value of ``Tensor.fill_`` (the tensor that was filled).
    """
    with torch.no_grad():
        filled = tensor.fill_(val)
    return filled
def _no_grad_zero_(tensor):
    """Zero ``tensor`` in place without recording autograd history.

    Returns the value of ``Tensor.zero_`` (the tensor that was zeroed).
    """
    with torch.no_grad():
        zeroed = tensor.zero_()
    return zeroed
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ===================== ====================================================
    nonlinearity          gain
    ===================== ====================================================
    Linear / Identity     :math:`1`
    Conv{1,2,3}D          :math:`1`
    ConvTranspose{1,2,3}D :math:`1`
    Sigmoid               :math:`1`
    Tanh                  :math:`\frac{5}{3}`
    ReLU                  :math:`\sqrt{2}`
    Leaky Relu            :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU                  :math:`\frac{3}{4}`
    ===================== ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (default 0.01)

    Raises:
        ValueError: if ``nonlinearity`` is not recognised, or if ``param`` is
            given for ``'leaky_relu'`` but is not an int/float (bools rejected).

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    # Every linear-like layer (and sigmoid) gets a unit gain.
    unit_gain_fns = ('linear', 'conv1d', 'conv2d', 'conv3d',
                     'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                     'sigmoid')
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # bool is a subclass of int, so True/False must be excluded explicitly
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    if nonlinearity == 'selu':
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    r"""Fill ``tensor`` in place with samples from :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        The filled tensor, or whatever a ``__torch_function__`` override on
        ``tensor``'s type returns for this call.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        # Give Tensor subclasses that define __torch_function__ a chance to
        # intercept the call before any data is touched.
        return torch.overrides.handle_torch_function(
            uniform_, (tensor,), tensor=tensor, a=a, b=b)
    return _no_grad_uniform_(tensor, a, b)
def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fill ``tensor`` in place with samples from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        The filled tensor, or whatever a ``__torch_function__`` override on
        ``tensor``'s type returns for this call.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        # Subclasses defining __torch_function__ may take over the call.
        return torch.overrides.handle_torch_function(
            normal_, (tensor,), tensor=tensor, mean=mean, std=std)
    return _no_grad_normal_(tensor, mean, std)
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    The result is distributed like :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with every value outside :math:`[a, b]` redrawn until it falls inside the
    bounds. The generation method works best when
    :math:`a \leq \text{mean} \leq b`; a warning is issued when ``mean`` is
    more than two standard deviations outside the interval.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill ``tensor`` in place with the scalar :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        The filled tensor, or whatever a ``__torch_function__`` override on
        ``tensor``'s type returns for this call.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        # Subclasses defining __torch_function__ may take over the call.
        return torch.overrides.handle_torch_function(
            constant_, (tensor,), tensor=tensor, val=val)
    return _no_grad_fill_(tensor, val)
def ones_(tensor: Tensor) -> Tensor:
    r"""Set every element of ``tensor`` to the scalar value `1`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    one = 1.0
    return _no_grad_fill_(tensor, one)
def zeros_(tensor: Tensor) -> Tensor:
    r"""Set every element of ``tensor`` to the scalar value `0`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor``, zeroed in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    zeroed = _no_grad_zero_(tensor)
    return zeroed
def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix, in place.

    Preserves the identity of the inputs in `Linear` layers, where as many
    inputs are preserved as possible. For a non-square tensor the ones are
    placed on the main diagonal, as ``torch.eye(rows, cols)`` does.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor``, overwritten in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.dim() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # Write the identity directly into ``tensor`` via ``out=``.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional` layers, where as
    many input channels are preserved as possible. With ``groups > 1`` the
    output channels are split into ``groups`` equal blocks and each block
    preserves identity on its own. Every entry is zero except a single `1`
    at the kernel centre (index ``size // 2`` along each trailing dimension)
    for output channel ``g * (out_channels // groups) + d`` and input
    channel ``d``.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Returns:
        ``tensor``, overwritten in place.

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, or if its
            first dimension is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    ndim = tensor.ndimension()
    if not 3 <= ndim <= 5:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()
    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    n_diagonal = min(out_chans_per_grp, sizes[1])
    # Centre position along every temporal/spatial/volumetric dimension.
    centre = tuple(extent // 2 for extent in sizes[2:])

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            first_out = group * out_chans_per_grp
            for channel in range(n_diagonal):
                tensor[(first_out + channel, channel) + centre] = 1
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor.

    Dimension 1 is taken as the number of input feature maps and dimension 0
    as the number of output feature maps. The product of all remaining
    dimensions (the receptive field size, 1 for a 2-D tensor) multiplies both.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # math.prod is not always available and functools.reduce is not supported
    # by TorchScript, so the product is accumulated by hand. For a 2-D tensor
    # the slice is empty and the size stays 1.
    receptive_field_size = 1
    for kernel_extent in tensor.shape[2:]:
        receptive_field_size *= kernel_extent

    fan_in = tensor.size(1) * receptive_field_size
    fan_out = tensor.size(0) * receptive_field_size
    return fan_in, fan_out
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill ``tensor`` in place using Glorot (Xavier) uniform initialization.

    Follows `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). Values are sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor` with at least 2 dimensions
        gain: an optional scaling factor

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    # Target standard deviation first, then the half-width of a uniform
    # distribution with that std (Var[U(-b, b)] = b**2 / 3).
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -bound, bound)
def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill ``tensor`` in place using Glorot (Xavier) normal initialization.

    Follows `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). Values are sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor` with at least 2 dimensions
        gain: an optional scaling factor

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    filled = _no_grad_normal_(tensor, 0., std)
    return filled
def _calculate_correct_fan(tensor, mode):
    """Return ``tensor``'s fan-in or fan-out as selected by ``mode``.

    ``mode`` is matched case-insensitively against ``'fan_in'`` and
    ``'fan_out'``.

    Raises:
        ValueError: for any other ``mode`` (reported in lower case).
    """
    normalized_mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if normalized_mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(normalized_mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if normalized_mode == 'fan_in':
        return fan_in
    return fan_out
def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        ``tensor`` filled in place, or whatever a ``__torch_function__``
        override on ``tensor``'s type returns. A tensor with a zero-sized
        dimension is returned untouched after a warning.

    Raises:
        ValueError: for an unsupported ``mode`` or ``nonlinearity``, or a
            tensor with fewer than 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)
    if 0 in tensor.shape:
        # stacklevel=2 attributes the warning to the caller's line rather than
        # to this module, matching the other warnings in nn.init.
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        ``tensor`` filled in place, or whatever a ``__torch_function__``
        override on ``tensor``'s type returns. A tensor with a zero-sized
        dimension is returned untouched after a warning.

    Raises:
        ValueError: for an unsupported ``mode`` or ``nonlinearity``, or a
            tensor with fewer than 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    # Same __torch_function__ hook as kaiming_uniform_ / uniform_ / normal_,
    # so Tensor subclasses can intercept both He initializers consistently.
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_normal_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)
    if 0 in tensor.shape:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    The tensor is treated as a ``(rows, numel // rows)`` matrix. The rows of
    that matrix are orthonormal when it is wide and the columns are
    orthonormal when it is tall; the result is then scaled by ``gain``.
    Non-contiguous tensors are supported.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Returns:
        ``tensor``, overwritten in place (returned unchanged if it has no
        elements).

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    # Gaussian matrix with tensor's dtype/device (new_empty replaces the
    # legacy tensor.new constructor; same allocation and RNG usage).
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1)

    if rows < cols:
        # Factor the tall transpose so Q has orthonormal columns; transposing
        # back afterwards gives orthonormal rows.
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        # Reshape q to tensor's shape instead of viewing tensor as q's shape:
        # a non-contiguous tensor with > 2 dims cannot always be viewed as a
        # flat (rows, cols) matrix, while reshaping q always succeeds and
        # places values in the same row-major positions.
        tensor.copy_(q.reshape(tensor.shape))
        tensor.mul_(gain)
    return tensor
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std=0.01`` by default), as
    described in `Deep learning via Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero.
            ``ceil(sparsity * rows)`` randomly chosen entries are zeroed per
            column (every entry when ``sparsity >= 1``).
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Returns:
        ``tensor``, overwritten in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional or ``sparsity`` is
            negative.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")
    if sparsity < 0:
        # A negative count would turn row_indices[:num_zeros] into "all but
        # the last |num_zeros| rows" and silently zero most of each column.
        raise ValueError("sparsity must be non-negative, got {}".format(sparsity))

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            # An independent random subset of rows is zeroed in each column.
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
# for backward compatibility
def _make_deprecate(meth):
    """Build a deprecated alias for the in-place initializer ``meth``.

    The alias is named after ``meth`` with its trailing character (the ``_``)
    dropped, e.g. ``uniform`` for ``uniform_``. Calling it emits a warning
    attributed to the caller (``stacklevel=2``) and then forwards all
    positional and keyword arguments to ``meth``, returning its result.
    """
    new_name = meth.__name__
    old_name = new_name[:-1]
    # The message never changes per call, so build it once up front.
    message = "nn.init.{} is now deprecated in favor of nn.init.{}.".format(old_name, new_name)

    def deprecated_init(*args, **kwargs):
        warnings.warn(message, stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__name__ = old_name
    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    return deprecated_init
# Deprecated spellings without the trailing underscore, kept for backward
# compatibility. Each alias is built by _make_deprecate, which warns on use
# and then forwards to the in-place initializer it wraps.
uniform, normal, constant = (
    _make_deprecate(uniform_),
    _make_deprecate(normal_),
    _make_deprecate(constant_),
)
eye, dirac = _make_deprecate(eye_), _make_deprecate(dirac_)
xavier_uniform, xavier_normal = (
    _make_deprecate(xavier_uniform_),
    _make_deprecate(xavier_normal_),
)
kaiming_uniform, kaiming_normal = (
    _make_deprecate(kaiming_uniform_),
    _make_deprecate(kaiming_normal_),
)
orthogonal, sparse = _make_deprecate(orthogonal_), _make_deprecate(sparse_)
|