# activation.py

import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from .linear import NonDynamicallyQuantizableLinear
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F

__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
           'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
           'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
           'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']

class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['threshold', 'value', 'inplace']
    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'threshold={}, value={}{}'.format(
            self.threshold, self.value, inplace_str
        )

class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)

      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
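
# A minimal sketch (not part of the upstream module) of the CReLU composition referenced
# in the ReLU docstring above: concatenating ReLU applied to the input and to its negation
# doubles the feature dimension. Shapes here are illustrative.
#
#     >>> relu = ReLU()
#     >>> x = torch.randn(8, 16)                          # (batch, features)
#     >>> crelu = torch.cat((relu(x), relu(-x)), dim=1)   # (batch, 2 * features)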

class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:
    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from the uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/RReLU.png

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """

    __constants__ = ['lower', 'upper', 'inplace']
    lower: float
    upper: float
    inplace: bool

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
        inplace: bool = False
    ):
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self):
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)

class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['min_val', 'max_val', 'inplace']
    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super().__init__()
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and renamed to max_val")
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        assert self.max_val > self.min_val

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )

class ReLU6(Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__(0., 6., inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str

class Sigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.sigmoid(input)


class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise.

    Hardsigmoid is defined as:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardsigmoid.png

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)


class Tanh(Module):
    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.

    Tanh is defined as:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return torch.tanh(input)

class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SiLU.png

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str


class Mish(Module):
    r"""Applies the Mish function, element-wise.
    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Mish.png

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str


class Hardswish(Module):
    r"""Applies the Hardswish function, element-wise, as described in the paper:
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    Hardswish is defined as:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardswish.png

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)

class ELU(Module):
    r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
    in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

    ELU is defined as:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)


class CELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, self.alpha, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'alpha={}{}'.format(self.alpha, inplace_str)

class SELU(Module):
    r"""Applies the SELU function element-wise, defined as:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
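
# A minimal sketch (not part of the upstream module) of the initialisation advice in the
# SELU warning above: use ``nonlinearity='linear'`` with Kaiming-normal initialisation so
# the weights get the variance that self-normalizing networks expect. Layer sizes are
# illustrative.
#
#     >>> linear = torch.nn.Linear(128, 128)
#     >>> torch.nn.init.kaiming_normal_(linear.weight, nonlinearity='linear')
#     >>> act = SELU()
#     >>> y = act(linear(torch.randn(32, 128)))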

class GLU(Module):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, self.dim)

    def extra_repr(self) -> str:
        return 'dim={}'.format(self.dim)
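
# A minimal sketch (not part of the upstream module): GLU halves the split dimension, so a
# common pattern is to pair it with a projection that doubles the feature size. The sizes
# below are illustrative.
#
#     >>> proj = torch.nn.Linear(16, 32)   # produce 2 * 16 features
#     >>> glu = GLU(dim=-1)
#     >>> x = torch.randn(4, 16)
#     >>> y = glu(proj(x))                 # shape (4, 16): first half gated by sigmoid of second half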

class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When the approximate argument is 'tanh', GELU is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['approximate']
    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return 'approximate={}'.format(repr(self.approximate))
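
# A minimal sketch (not part of the upstream module) comparing the exact GELU with the
# tanh approximation described above; the absolute tolerance is chosen here purely for
# illustration.
#
#     >>> x = torch.randn(1000)
#     >>> exact = GELU(approximate='none')(x)
#     >>> approx = GELU(approximate='tanh')(x)
#     >>> bool(torch.allclose(exact, approx, atol=1e-3))
#     True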

class Hardshrink(Module):
    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.

    Hardshrink is defined as:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return '{}'.format(self.lambd)

class LeakyReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope (which is used for
            negative input values). Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, self.negative_slope, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)

class LogSigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.logsigmoid(input)

class Softplus(Module):
    r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x))` element-wise.

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['beta', 'threshold']
    beta: int
    threshold: int

    def __init__(self, beta: int = 1, threshold: int = 20) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
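
# A minimal sketch (not part of the upstream module) of the numerical-stability behaviour
# noted above: once ``input * beta`` exceeds ``threshold``, Softplus returns the input
# unchanged (the linear regime). The values below are illustrative.
#
#     >>> sp = Softplus(beta=1, threshold=20)
#     >>> big = torch.tensor([25.0, 100.0])
#     >>> bool(torch.equal(sp(big), big))   # linear regime: output == input
#     True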

class Softshrink(Module):
    r"""Applies the soft shrinkage function elementwise:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        return str(self.lambd)

class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``forward()`` will use the optimized implementation described in
    `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
    conditions are met:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
      restriction will be loosened in the future.)
    - inputs are batched (3D) with ``batch_first==True``
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``batch_first`` is ``True`` and the input is batched
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed
    - autocast is disabled

    If the optimized implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
        https://arxiv.org/abs/2205.14135
    """

    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super().__setstate__(state)
    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
                See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
                See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                Binary and float masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
                broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
                Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
                corresponding position is not allowed to attend. For a float mask, the mask values will be added to
                the attention weight.
                If both attn_mask and key_padding_mask are supplied, their types should match.
            is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
                Default: ``False``.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
                effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        why_not_fast_path = ''
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
                                 is not supported with NestedTensor input"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
                why_not_fast_path = ("grad is enabled and at least one of query or the "
                                     "input/output projection weights or biases requires_grad")
            if not why_not_fast_path:
                merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)

                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    self.in_proj_weight,
                    self.in_proj_bias,
                    self.out_proj.weight,
                    self.out_proj.bias,
                    merged_mask,
                    need_weights,
                    average_attn_weights,
                    mask_type)

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
    def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""
        Determine mask type and combine masks if necessary. If only one mask is provided, that mask
        and the corresponding mask type will be returned. If both masks are provided, they will both be
        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined, and mask type 2
        will be returned.

        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``

        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=query.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            mask_type = 0
            merged_mask = attn_mask
        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask
        if (attn_mask is not None) and (key_padding_mask is not None):
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2
            key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \
                                        .expand(-1, self.num_heads, -1, -1)
            attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
            merged_mask = attn_mask_expanded + key_padding_mask_expanded
        return merged_mask, mask_type
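
# A minimal usage sketch (not part of the upstream module) for batch-first self-attention
# with a boolean key padding mask; all sizes below are illustrative.
#
#     >>> mha = MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
#     >>> x = torch.randn(2, 5, 16)                     # (batch, seq, feature)
#     >>> pad = torch.zeros(2, 5, dtype=torch.bool)     # True marks padded positions
#     >>> pad[1, 3:] = True                             # second sequence has length 3
#     >>> out, weights = mha(x, x, x, key_padding_mask=pad)
#     >>> out.shape, weights.shape                      # (N, L, E) and averaged (N, L, S)
#     (torch.Size([2, 5, 16]), torch.Size([2, 5, 5]))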

class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels at input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_parameters = num_parameters
        super().__init__()
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return 'num_parameters={}'.format(self.num_parameters)
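
# A minimal sketch (not part of the upstream module) of the per-channel form noted in the
# PReLU docstring: passing the channel count gives one learnable slope per channel. The
# channel count below is illustrative.
#
#     >>> m = PReLU(num_parameters=3)       # one slope per channel of an (N, 3, H, W) input
#     >>> x = torch.randn(4, 3, 8, 8)
#     >>> y = m(x)
#     >>> m.weight.shape
#     torch.Size([3])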

class Softsign(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.softsign(input)


class Tanhshrink(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.tanhshrink(input)

class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)

class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor,
    rescaling it so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        assert input.dim() == 4 or input.dim() == 3, 'Softmax2d requires a 3D or 4D tensor as input'
        return F.softmax(input, -3, _stacklevel=5)
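
# A minimal sketch (not part of the upstream module): Softmax2d is equivalent to applying
# Softmax over the channel dimension (``dim=-3``), so the channel scores at each (h, w)
# location sum to 1. Sizes are illustrative.
#
#     >>> x = torch.randn(2, 3, 12, 13)
#     >>> a = Softmax2d()(x)
#     >>> b = Softmax(dim=-3)(x)
#     >>> bool(torch.allclose(a, b)) and bool(torch.allclose(a.sum(dim=1), torch.ones(2, 12, 13)))
#     True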

class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which LogSoftmax will be computed.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)

    Examples::

        >>> m = nn.LogSoftmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self):
        return 'dim={dim}'.format(dim=self.dim)