12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533 |
- import warnings
- from typing import Optional, Tuple
- import torch
- from torch import Tensor
- from .linear import NonDynamicallyQuantizableLinear
- from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
- from torch.nn.parameter import Parameter
- from .module import Module
- from .. import functional as F
# Public API of this module, in declaration order.
__all__ = [
    'Threshold',
    'ReLU',
    'RReLU',
    'Hardtanh',
    'ReLU6',
    'Sigmoid',
    'Hardsigmoid',
    'Tanh',
    'SiLU',
    'Mish',
    'Hardswish',
    'ELU',
    'CELU',
    'SELU',
    'GLU',
    'GELU',
    'Hardshrink',
    'LeakyReLU',
    'LogSigmoid',
    'Softplus',
    'Softshrink',
    'MultiheadAttention',
    'PReLU',
    'Softsign',
    'Tanhshrink',
    'Softmin',
    'Softmax',
    'Softmax2d',
    'LogSoftmax',
]
class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: decide whether ``value <= threshold`` should be enforced when
        # ``inplace`` is True; no such check is performed here.

    def forward(self, input: Tensor) -> Tensor:
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self) -> str:
        # ``inplace`` is only reported when it differs from the default.
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'threshold={}, value={}{}'.format(
            self.threshold, self.value, inplace_str
        )
class ReLU(Module):
    r"""Rectified linear unit, applied to every element independently:
    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Args:
        inplace: when ``True``, the input tensor is overwritten instead of
            producing a new one. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)

      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input), m(-input)))
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        # Nothing is printed for the default (out-of-place) configuration.
        if self.inplace:
            return 'inplace=True'
        return ''
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`. The module's ``training``
    flag is forwarded to :func:`torch.nn.functional.rrelu`, so sampling follows
    train/eval mode.

    See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/RReLU.png

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """

    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
        inplace: bool = False
    ) -> None:
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Raises:
        AssertionError: if ``max_val`` is not strictly greater than ``min_val``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super().__init__()
        # The deprecated spellings override the new ones when given.
        # stacklevel=2 attributes the warning to the code constructing the module.
        if min_value is not None:
            warnings.warn("keyword argument min_value is deprecated and rename to min_val", stacklevel=2)
            min_val = min_value
        if max_value is not None:
            warnings.warn("keyword argument max_value is deprecated and rename to max_val", stacklevel=2)
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        # An explicit check instead of ``assert`` so the validation survives
        # ``python -O``; AssertionError is kept so existing handlers still match.
        if not self.max_val > self.min_val:
            raise AssertionError("max_val must be greater than min_val")

    def forward(self, input: Tensor) -> Tensor:
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )
class ReLU6(Hardtanh):
    r"""ReLU clipped from above at 6, applied element-wise:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    This is a :class:`Hardtanh` whose linear region is fixed to ``[0, 6]``.

    Args:
        inplace: when ``True``, the input tensor is overwritten. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__(min_val=0., max_val=6., inplace=inplace)

    def extra_repr(self) -> str:
        # The clipping bounds are fixed, so only ``inplace`` is worth showing.
        return 'inplace=True' if self.inplace else ''
class Sigmoid(Module):
    r"""Logistic sigmoid, computed independently for every element:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless: delegate directly to the tensor-level op.
        result = torch.sigmoid(input)
        return result
class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise.

    Hardsigmoid is defined as:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardsigmoid.png

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardsigmoid(input, self.inplace)

    def extra_repr(self) -> str:
        # Report ``inplace`` like the other in-place-capable activations do;
        # the default configuration still prints as ``Hardsigmoid()``.
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class Tanh(Module):
    r"""Hyperbolic tangent, applied to every element of the input.

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # No parameters or state: this is a thin wrapper over ``torch.tanh``.
        result = torch.tanh(input)
        return result
class SiLU(Module):
    r"""Sigmoid Linear Unit (SiLU), also called *swish*, applied element-wise:

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        The name SiLU was coined in `Gaussian Error Linear Units (GELUs)
        <https://arxiv.org/abs/1606.08415>`_; the function was later studied in
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SiLU.png

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
class Mish(Module):
    r"""Mish, a self-regularized non-monotonic activation, applied element-wise:

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Mish.png

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
class Hardswish(Module):
    r"""Applies the Hardswish function, element-wise, as described in the paper:
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    Hardswish is defined as:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardswish.png

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.hardswish(input, self.inplace)

    def extra_repr(self) -> str:
        # Report ``inplace`` like the other in-place-capable activations do;
        # the default configuration still prints as ``Hardswish()``.
        inplace_str = 'inplace=True' if self.inplace else ''
        return inplace_str
class ELU(Module):
    r"""Exponential Linear Unit, applied element-wise, from `Fast and Accurate
    Deep Network Learning by Exponential Linear Units (ELUs)
    <https://arxiv.org/abs/1511.07289>`__.

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` scale of the negative branch. Default: 1.0
        inplace: when ``True``, the input tensor is overwritten. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.elu(input, alpha=self.alpha, inplace=self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class CELU(Module):
    r"""Continuously differentiable ELU, applied element-wise:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    See the paper `Continuously Differentiable Exponential Linear Units`_ for details.

    Args:
        alpha: the :math:`\alpha` parameter of the formula above. Default: 1.0
        inplace: when ``True``, the input tensor is overwritten. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """

    __constants__ = ['alpha', 'inplace']
    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.celu(input, alpha=self.alpha, inplace=self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class SELU(Module):
    r"""Scaled Exponential Linear Unit, applied element-wise:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with fixed constants :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        To obtain `Self-Normalizing Neural Networks`_ when initialising with
        ``kaiming_normal`` or ``kaiming_normal_``, pass ``nonlinearity='linear'``
        rather than ``nonlinearity='selu'``. See
        :func:`torch.nn.init.calculate_gain` for more information.

    Background is given in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): when ``True``, the input tensor is overwritten. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """

    __constants__ = ['inplace']
    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.selu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        if self.inplace:
            return 'inplace=True'
        return ''
class GLU(Module):
    r"""Gated linear unit :math:`{GLU}(a, b)= a \otimes \sigma(b)`.

    The input is split into two equal halves along ``dim``; :math:`a` is the
    first half and :math:`b` the second half.

    Args:
        dim (int): the dimension along which the input is split. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where :math:`*` means any number of
          additional dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    __constants__ = ['dim']
    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        return F.glu(input, dim=self.dim)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When the approximate argument is ``'tanh'``, GELU is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['approximate']

    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        # Passed straight through to ``F.gelu``, which interprets it at call time.
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        return 'approximate={}'.format(repr(self.approximate))
class Hardshrink(Module):
    r"""Hard shrinkage, applied element-wise: values inside
    :math:`[-\lambda, \lambda]` are zeroed, everything else passes through.

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` threshold. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.hardshrink(input, lambd=self.lambd)

    def extra_repr(self) -> str:
        # The bare threshold value, e.g. ``Hardshrink(0.5)``.
        return f'{self.lambd}'
class LeakyReLU(Module):
    r"""Leaky rectified linear unit, applied element-wise:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or, equivalently,

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: slope applied to negative inputs. Default: 1e-2
        inplace: when ``True``, the input tensor is overwritten. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['inplace', 'negative_slope']
    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        return F.leaky_relu(input, negative_slope=self.negative_slope, inplace=self.inplace)

    def extra_repr(self) -> str:
        suffix = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{suffix}'
class LogSigmoid(Module):
    r"""Logarithm of the logistic sigmoid, applied element-wise:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless wrapper around the functional form.
        result = F.logsigmoid(input)
        return result
class Softplus(Module):
    r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x))` element-wise.

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`\text{input} \times \beta > \text{threshold}`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['beta', 'threshold']

    # Annotated as ``float``: both values are forwarded to ``F.softplus`` and
    # fractional settings such as ``beta=0.5`` are valid. The defaults remain
    # the ints ``1`` / ``20`` so existing ``repr`` output is unchanged.
    beta: float
    threshold: float

    def __init__(self, beta: float = 1, threshold: float = 20) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        return 'beta={}, threshold={}'.format(self.beta, self.threshold)
class Softshrink(Module):
    r"""Soft shrinkage, applied element-wise: values are pulled towards zero
    by :math:`\lambda`, and those within :math:`[-\lambda, \lambda]` become zero.

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` shrinkage amount (must be no less than zero). Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    __constants__ = ['lambd']
    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        return F.softshrink(input, lambd=self.lambd)

    def extra_repr(self) -> str:
        # The bare shrinkage value, e.g. ``Softshrink(0.5)``.
        return f'{self.lambd}'
- class MultiheadAttention(Module):
- r"""Allows the model to jointly attend to information
- from different representation subspaces as described in the paper:
- `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
- Multi-Head Attention is defined as:
- .. math::
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
- where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
- ``forward()`` will use the optimized implementation described in
- `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
- conditions are met:
- - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
- restriction will be loosened in the future.)
- - inputs are batched (3D) with ``batch_first==True``
- - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- - training is disabled (using ``.eval()``)
- - ``add_bias_kv`` is ``False``
- - ``add_zero_attn`` is ``False``
- - ``batch_first`` is ``True`` and the input is batched
- - ``kdim`` and ``vdim`` are equal to ``embed_dim``
- - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
- nor ``attn_mask`` is passed
- - autocast is disabled
- If the optimized implementation is in use, a
- `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
- ``query``/``key``/``value`` to represent padding more efficiently than using a
- padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
- will be returned, and an additional speedup proportional to the fraction of the input
- that is padding can be expected.
- Args:
- embed_dim: Total dimension of the model.
- num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
- across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
- dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
- bias: If specified, adds bias to input / output projection layers. Default: ``True``.
- add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
- add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
- Default: ``False``.
- kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
- vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
- batch_first: If ``True``, then the input and output tensors are provided
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
- Examples::
- >>> # xdoctest: +SKIP
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
- >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
- .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
- https://arxiv.org/abs/2205.14135
- """
- __constants__ = ['batch_first']
- bias_k: Optional[torch.Tensor]
- bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
             kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
    """Create the projection parameters for multi-head attention.

    Args:
        embed_dim: total model dimension; split evenly across ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability, stored as ``self.dropout``.
        bias: when ``True``, create ``in_proj_bias`` and give ``out_proj`` a bias.
        add_bias_kv: when ``True``, create learnable ``bias_k`` / ``bias_v`` of shape
            ``(1, 1, embed_dim)``; otherwise both are ``None``.
        add_zero_attn: stored as ``self.add_zero_attn``.
        kdim: key feature size; ``None`` means ``embed_dim``.
        vdim: value feature size; ``None`` means ``embed_dim``.
        batch_first: stored as ``self.batch_first``.
        device: forwarded to every created parameter.
        dtype: forwarded to every created parameter.

    Raises:
        ValueError: if ``embed_dim`` or ``num_heads`` is not positive.
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    # Validate up front: ``num_heads == 0`` would otherwise surface as a bare
    # ZeroDivisionError, and a negative ``num_heads`` can satisfy the
    # divisibility check below and build a silently broken module.
    if embed_dim <= 0 or num_heads <= 0:
        raise ValueError(
            f"embed_dim and num_heads must be greater than 0,"
            f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
        )
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()
    self.embed_dim = embed_dim
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    # Identical q/k/v widths -> one packed (3 * embed_dim, embed_dim) in-projection;
    # otherwise three separate projection matrices are created below.
    self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

    self.num_heads = num_heads
    self.dropout = dropout
    self.batch_first = batch_first
    self.head_dim = embed_dim // num_heads
    # Explicit raise (same exception type and message as the former ``assert``)
    # so the check is not stripped under ``python -O``.
    if self.head_dim * num_heads != self.embed_dim:
        raise AssertionError("embed_dim must be divisible by num_heads")

    if not self._qkv_same_embed_dim:
        self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
        self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
        self.register_parameter('in_proj_weight', None)
    else:
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
        self.register_parameter('q_proj_weight', None)
        self.register_parameter('k_proj_weight', None)
        self.register_parameter('v_proj_weight', None)

    if bias:
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
    else:
        self.register_parameter('in_proj_bias', None)
    self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    if add_bias_kv:
        self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
    else:
        self.bias_k = self.bias_v = None

    self.add_zero_attn = add_zero_attn

    # Parameters above were allocated with torch.empty; initialize them now.
    self._reset_parameters()
- def _reset_parameters(self):
- if self._qkv_same_embed_dim:
- xavier_uniform_(self.in_proj_weight)
- else:
- xavier_uniform_(self.q_proj_weight)
- xavier_uniform_(self.k_proj_weight)
- xavier_uniform_(self.v_proj_weight)
- if self.in_proj_bias is not None:
- constant_(self.in_proj_bias, 0.)
- constant_(self.out_proj.bias, 0.)
- if self.bias_k is not None:
- xavier_normal_(self.bias_k)
- if self.bias_v is not None:
- xavier_normal_(self.bias_v)
def __setstate__(self, state):
    """Restore pickled state, back-filling the flag missing from old checkpoints.

    Checkpoints produced by v1.1.0 have no ``_qkv_same_embed_dim`` entry; such
    state is loaded with the flag set to ``True`` (packed in-projection).
    """
    state.setdefault('_qkv_same_embed_dim', True)
    super().__setstate__(state)
def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when
            ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target
            sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension
            ``embed_dim``. Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when
            ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should
            be :math:`(S)`. Binary and float masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
            If both attn_mask and key_padding_mask are supplied, their types should match.
        is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing
            attn_mask. Default: ``False``.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or
          :math:`(N, \text{num\_heads}, L, S)`.

    Raises:
        AssertionError: if both ``attn_mask`` and ``is_causal`` are given, or if a NestedTensor input
            cannot take the fast path.

    .. note::
        `batch_first` argument is ignored for unbatched inputs.
    """
    if attn_mask is not None and is_causal:
        raise AssertionError("Only allow causal mask or attn_mask")

    is_batched = query.dim() == 3

    # Canonicalize key_padding_mask against query's dtype (presumably turning a
    # boolean mask into an additive float mask -- see F._canonical_mask).
    key_padding_mask = F._canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=F._none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype
    )

    # Each branch records the first reason the fused native kernel cannot be used;
    # an empty string means every precondition so far holds.
    why_not_fast_path = ''
    if not is_batched:
        why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
    elif query is not key or key is not value:
        # When lifting this restriction, don't forget to either
        # enforce that the dtypes all match or test cases where
        # they don't!
        why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
    elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
    elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
        # this case will fail anyway, but at least they'll get a useful error message.
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
    elif self.training:
        why_not_fast_path = "training is enabled"
    elif not self.batch_first:
        why_not_fast_path = "batch_first was not True"
    elif self.bias_k is not None:
        why_not_fast_path = "self.bias_k was not None"
    elif self.bias_v is not None:
        why_not_fast_path = "self.bias_v was not None"
    elif self.add_zero_attn:
        why_not_fast_path = "add_zero_attn was enabled"
    elif not self._qkv_same_embed_dim:
        why_not_fast_path = "_qkv_same_embed_dim was not True"
    elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
        # Supplying EITHER mask disqualifies nested input. Implicit string
        # concatenation keeps source indentation out of the message.
        why_not_fast_path = ("supplying key_padding_mask or attn_mask "
                             "is not supported with NestedTensor input")
    elif torch.is_autocast_enabled():
        why_not_fast_path = "autocast is enabled"

    if not why_not_fast_path:
        tensor_args = (
            query,
            key,
            value,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj.weight,
            self.out_proj.bias,
        )
        # We have to use list comprehensions below because TorchScript does not support
        # generator expressions.
        if torch.overrides.has_torch_function(tensor_args):
            why_not_fast_path = "some Tensor argument has_torch_function"
        elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
            why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
        elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
            why_not_fast_path = ("grad is enabled and at least one of query or the "
                                 "input/output projection weights or biases requires_grad")
        if not why_not_fast_path:
            # NOTE(review): ``is_causal`` is not passed to the native kernel below, so with
            # ``attn_mask=None`` this path runs without any causal masking -- confirm whether
            # causal calls should be excluded from the fast path.
            merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)

            return torch._native_multi_head_attention(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
                merged_mask,
                need_weights,
                average_attn_weights,
                mask_type)

    any_nested = query.is_nested or key.is_nested or value.is_nested
    assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                            f"The fast path was not hit because {why_not_fast_path}")

    if self.batch_first and is_batched:
        # make sure that the transpose op does not affect the "is" property
        if key is value:
            if query is key:
                query = key = value = query.transpose(1, 0)
            else:
                query, key = [x.transpose(1, 0) for x in (query, key)]
                value = key
        else:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

    if not self._qkv_same_embed_dim:
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
            average_attn_weights=average_attn_weights,
            is_causal=is_causal)
    else:
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            is_causal=is_causal)
    # Undo the batch_first transpose applied above.
    if self.batch_first and is_batched:
        return attn_output.transpose(1, 0), attn_output_weights
    else:
        return attn_output, attn_output_weights
def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
    r"""
    Determine mask type and combine masks if necessary.

    If only one mask is provided, that mask and the corresponding mask type are returned.
    If both masks are provided, they are both expanded to shape
    ``(batch_size, num_heads, seq_len, seq_len)`` and added together, and mask type 2 is
    returned. ``attn_mask`` is canonicalized to ``query.dtype`` here; ``key_padding_mask``
    is expected to be canonicalized by the caller already, so the sum acts as a combined
    additive mask.

    Args:
        attn_mask: attention mask of shape ``(seq_len, seq_len)``, ``(1, seq_len, seq_len)`` or
            ``(batch_size * num_heads, seq_len, seq_len)``, mask type 0
        key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
        query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``

    Returns:
        merged_mask: merged mask
        mask_type: merged mask type (0, 1, or 2)
    """
    mask_type: Optional[int] = None
    merged_mask: Optional[Tensor] = None

    attn_mask = F._canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=F._none_or_dtype(key_padding_mask),
        other_name="key_padding_mask",
        target_type=query.dtype,
        check_other=False,
    )

    if attn_mask is not None:
        mask_type = 0
        merged_mask = attn_mask
    if key_padding_mask is not None:
        mask_type = 1
        merged_mask = key_padding_mask
    if (attn_mask is not None) and (key_padding_mask is not None):
        # In this branch query can't be a nested tensor, so it has a shape
        batch_size, seq_len, _ = query.shape
        mask_type = 2
        key_padding_mask_expanded = (
            key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
        )
        if attn_mask.dim() == 3 and attn_mask.size(0) != 1:
            # Per-(batch, head) mask of shape (batch_size * num_heads, L, L); assumed
            # batch-major like F.multi_head_attention_forward -- TODO confirm.
            # Unflatten it instead of broadcasting a single (L, L) mask, which would fail.
            attn_mask_expanded = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
        else:
            # (L, L) or (1, L, L): one mask shared by every batch entry and head.
            attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
                batch_size, self.num_heads, -1, -1)
        merged_mask = attn_mask_expanded + key_padding_mask_expanded
    return merged_mask, mask_type
class PReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
    parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
    a separate :math:`a` is used for each input channel.

    .. note::
        weight decay should not be used when learning :math:`a` for good performance.

    .. note::
        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
        no channel dim and the number of channels = 1.

    Args:
        num_parameters (int): number of :math:`a` to learn.
            Although it takes an int as input, only two values are legitimate:
            1, or the number of channels at input. Default: 1
        init (float): the initial value of :math:`a`. Default: 0.25

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Attributes:
        weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).

    .. image:: ../scripts/activation_images/PReLU.png

    Examples::

        >>> m = nn.PReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['num_parameters']
    num_parameters: int

    def __init__(self, num_parameters: int = 1, init: float = 0.25,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Initialize Module internals before assigning attributes on the instance.
        super().__init__()
        self.num_parameters = num_parameters
        # One learnable slope per parameter, every entry starting at ``init``.
        self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs).fill_(init))

    def forward(self, input: Tensor) -> Tensor:
        return F.prelu(input, self.weight)

    def extra_repr(self) -> str:
        return f'num_parameters={self.num_parameters}'
class Softsign(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softsign.png

    Examples::

        >>> m = nn.Softsign()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: the functional form does all the work.
        activated = F.softsign(input)
        return activated
class Tanhshrink(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanhshrink.png

    Examples::

        >>> m = nn.Tanhshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Stateless module: the functional form does all the work.
        shrunk = F.tanhshrink(input)
        return shrunk
class Softmin(Module):
    r"""Applies the Softmin function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range `[0, 1]` and sum to 1.

    Softmin is defined as:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmin will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input, with
        values in the range [0, 1]

    Examples::

        >>> m = nn.Softmin(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Pickled state that predates the ``dim`` attribute restores without it.
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmin(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecified
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older pickles may lack ``dim``; fall back to None, keep it otherwise.
        self.dim = getattr(self, 'dim', None)

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, dim=self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
class Softmax2d(Module):
    r"""Applies SoftMax over features to each spatial location.

    When given an image of ``Channels x Height x Width``, it will
    apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
        - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Raises:
        AssertionError: if the input is neither 3D nor 4D.

    Examples::

        >>> m = nn.Softmax2d()
        >>> # you softmax over the 2nd dimension
        >>> input = torch.randn(2, 3, 12, 13)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Explicit raise instead of ``assert`` so the shape check is not stripped
        # under ``python -O``; the exception type and message are unchanged.
        if input.dim() != 4 and input.dim() != 3:
            raise AssertionError('Softmax2d requires a 3D or 4D tensor as input')
        # Dim -3 is the channel axis for both (N, C, H, W) and (C, H, W).
        return F.softmax(input, -3, _stacklevel=5)
class LogSoftmax(Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
    input Tensor. The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Args:
        dim (int): A dimension along which LogSoftmax will be computed.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0] (0 is reached, e.g. when ``dim`` has size 1)

    Examples::

        >>> m = nn.LogSoftmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        super().__setstate__(state)
        # Pickled state that predates the ``dim`` attribute restores without it.
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.log_softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return f'dim={self.dim}'
|