# init.py

import math
import warnings

from torch import Tensor
import torch


# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
    with torch.no_grad():
        return tensor.uniform_(a, b)


def _no_grad_normal_(tensor, mean, std):
    with torch.no_grad():
        return tensor.normal_(mean, std)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)


def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
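
# Illustrative sketch (worked example, not used by the module itself): for
# nonlinearity='leaky_relu' with negative_slope=0.2 the recommended gain is
# sqrt(2 / (1 + 0.2 ** 2)) ~= 1.3868, which the function reproduces:
#
#     gain = calculate_gain('leaky_relu', 0.2)
#     assert abs(gain - math.sqrt(2.0 / 1.04)) < 1e-12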


def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
    return _no_grad_uniform_(tensor, a, b)


def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
    return _no_grad_normal_(tensor, mean, std)


def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
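
# Illustrative sketch: with the defaults (mean=0, std=1, a=-2, b=2) every
# sample ends up inside the +/- 2 standard deviation window, since the final
# step clamps to [a, b]:
#
#     w = trunc_normal_(torch.empty(3, 5))
#     assert float(w.min()) >= -2. and float(w.max()) <= 2.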


def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
    return _no_grad_fill_(tensor, val)


def ones_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.)


def zeros_(tensor: Tensor) -> Tensor:
    r"""Fills the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)


def eye_(tensor):
    r"""Fills the 2-dimensional input `Tensor` with the identity
    matrix. Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor


def dirac_(tensor, groups=1):
    r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
    delta function. Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups > 1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2] = 1
                else:  # Volumetric convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor
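
# Illustrative sketch (assumes an odd kernel size with matching padding, e.g.
# kernel_size=3 and padding=1): a Conv2d whose weight is dirac_-initialized
# acts as the identity on its first min(out_channels, in_channels) channels:
#
#     conv = torch.nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
#     dirac_(conv.weight)
#     x = torch.randn(1, 16, 8, 8)
#     assert torch.allclose(conv(x), x)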


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
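
# Worked example (illustrative): for a Conv2d weight of shape (64, 32, 3, 3)
# the receptive field size is 3 * 3 = 9, so fan_in = 32 * 9 = 288 and
# fan_out = 64 * 9 = 576:
#
#     fan_in, fan_out = _calculate_fan_in_and_fan_out(torch.empty(64, 32, 3, 3))
#     assert (fan_in, fan_out) == (288, 576)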


def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)
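
# Worked example (illustrative): for a (3, 5) weight, fan_in = 5 and
# fan_out = 3, so with gain = 1 the bound is
# sqrt(3) * sqrt(2 / 8) = sqrt(6 / 8) ~= 0.866:
#
#     w = xavier_uniform_(torch.empty(3, 5))
#     assert float(w.abs().max()) <= math.sqrt(6.0 / 8.0) + 1e-12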


def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    return _no_grad_normal_(tensor, 0., std)


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
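
# Worked example (illustrative): for a (3, 5) weight with mode='fan_in' and
# nonlinearity='relu', fan = 5 and gain = sqrt(2), so
# bound = sqrt(3) * sqrt(2) / sqrt(5) = sqrt(6 / 5) ~= 1.095:
#
#     w = kaiming_uniform_(torch.empty(3, 5), mode='fan_in', nonlinearity='relu')
#     assert float(w.abs().max()) <= math.sqrt(6.0 / 5.0) + 1e-6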


def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)


def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
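
# Illustrative sketch: the rows of a wide orthogonal_-initialized matrix are
# orthonormal, so w @ w.T is (approximately) the identity:
#
#     w = orthogonal_(torch.empty(3, 5))
#     assert torch.allclose(w @ w.t(), torch.eye(3), atol=1e-5)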


def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
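
# Illustrative sketch: the sparsity is applied per column, zeroing
# ceil(sparsity * rows) entries in each one. For a (10, 4) tensor with
# sparsity=0.3 that is ceil(3.0) = 3 zeros per column:
#
#     w = sparse_(torch.empty(10, 4), sparsity=0.3)
#     assert all(int((w[:, c] == 0).sum()) >= 3 for c in range(4))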


# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
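
# Illustrative sketch: each deprecated alias forwards to its trailing-underscore
# counterpart and emits a warning first, e.g. calling `uniform(w)` warns
# "nn.init.uniform is now deprecated in favor of nn.init.uniform_." and then
# fills `w` exactly as `uniform_(w)` would:
#
#     with warnings.catch_warnings(record=True) as caught:
#         warnings.simplefilter("always")
#         uniform(torch.empty(2, 2))
#     assert "deprecated" in str(caught[0].message)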