- import math
- import warnings
- import numbers
- import weakref
- from typing import List, Tuple, Optional, overload
- import torch
- from torch import Tensor
- from .module import Module
- from ..parameter import Parameter
- from ..utils.rnn import PackedSequence
- from .. import init
- from ... import _VF
- __all__ = ['RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell']
- _rnn_impls = {
- 'RNN_TANH': _VF.rnn_tanh,
- 'RNN_RELU': _VF.rnn_relu,
- }
- def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
- return tensor.index_select(dim, permutation)
- def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
- warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
- return _apply_permutation(tensor, permutation, dim)
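# A minimal sketch (illustrative only, not part of the original source) of the
# replacement suggested by the deprecation warning above: permuting the batch
# dimension of a hidden state directly with Tensor.index_select.
#
#     >>> h = torch.randn(2, 3, 20)             # (num_layers, batch, hidden_size)
#     >>> perm = torch.tensor([2, 0, 1])        # hypothetical batch permutation
#     >>> h_permuted = h.index_select(1, perm)  # equivalent to _apply_permutation(h, perm)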
- class RNNBase(Module):
- __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
- 'batch_first', 'dropout', 'bidirectional', 'proj_size']
- __jit_unused_properties__ = ['all_weights']
- mode: str
- input_size: int
- hidden_size: int
- num_layers: int
- bias: bool
- batch_first: bool
- dropout: float
- bidirectional: bool
- proj_size: int
- def __init__(self, mode: str, input_size: int, hidden_size: int,
- num_layers: int = 1, bias: bool = True, batch_first: bool = False,
- dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
- device=None, dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__()
- self.mode = mode
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.num_layers = num_layers
- self.bias = bias
- self.batch_first = batch_first
- self.dropout = float(dropout)
- self.bidirectional = bidirectional
- self.proj_size = proj_size
- self._flat_weight_refs: List[Optional[weakref.ReferenceType["Parameter"]]] = []
- num_directions = 2 if bidirectional else 1
- if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
- isinstance(dropout, bool):
- raise ValueError("dropout should be a number in range [0, 1] "
- "representing the probability of an element being "
- "zeroed")
- if dropout > 0 and num_layers == 1:
- warnings.warn("dropout option adds dropout after all but last "
- "recurrent layer, so non-zero dropout expects "
- "num_layers greater than 1, but got dropout={} and "
- "num_layers={}".format(dropout, num_layers))
- if proj_size < 0:
- raise ValueError("proj_size should be a positive integer or zero to disable projections")
- if proj_size >= hidden_size:
- raise ValueError("proj_size has to be smaller than hidden_size")
- if mode == 'LSTM':
- gate_size = 4 * hidden_size
- elif mode == 'GRU':
- gate_size = 3 * hidden_size
- elif mode == 'RNN_TANH':
- gate_size = hidden_size
- elif mode == 'RNN_RELU':
- gate_size = hidden_size
- else:
- raise ValueError("Unrecognized RNN mode: " + mode)
- self._flat_weights_names = []
- self._all_weights = []
- for layer in range(num_layers):
- for direction in range(num_directions):
- real_hidden_size = proj_size if proj_size > 0 else hidden_size
- layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
- w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
- w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs))
- b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
- # Second bias vector included for CuDNN compatibility. Only one
- # bias vector is needed in standard definition.
- b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
- layer_params: Tuple[Tensor, ...] = ()
- if self.proj_size == 0:
- if bias:
- layer_params = (w_ih, w_hh, b_ih, b_hh)
- else:
- layer_params = (w_ih, w_hh)
- else:
- w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs))
- if bias:
- layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
- else:
- layer_params = (w_ih, w_hh, w_hr)
- suffix = '_reverse' if direction == 1 else ''
- param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
- if bias:
- param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
- if self.proj_size > 0:
- param_names += ['weight_hr_l{}{}']
- param_names = [x.format(layer, suffix) for x in param_names]
- for name, param in zip(param_names, layer_params):
- setattr(self, name, param)
- self._flat_weights_names.extend(param_names)
- self._all_weights.append(param_names)
- self._init_flat_weights()
- self.reset_parameters()
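# Illustrative sketch (not part of the module): with the naming scheme built
# above, a hypothetical 2-layer bidirectional LSTM exposes parameters named
# weight_ih_l{layer}{suffix}, weight_hh_l{layer}{suffix}, and so on:
#
#     >>> lstm = LSTM(8, 16, num_layers=2, bidirectional=True)
#     >>> [name for name, _ in lstm.named_parameters()][:4]
#     ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']
#     >>> 'weight_ih_l1_reverse' in dict(lstm.named_parameters())
#     True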
- def _init_flat_weights(self):
- self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None
- for wn in self._flat_weights_names]
- self._flat_weight_refs = [weakref.ref(w) if w is not None else None
- for w in self._flat_weights]
- self.flatten_parameters()
- def __setattr__(self, attr, value):
- if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
- # keep self._flat_weights up to date if you do self.weight = ...
- idx = self._flat_weights_names.index(attr)
- self._flat_weights[idx] = value
- super().__setattr__(attr, value)
- def flatten_parameters(self) -> None:
- """Resets parameter data pointer so that they can use faster code paths.
- Right now, this works only if the module is on the GPU and cuDNN is enabled.
- Otherwise, it's a no-op.
- """
- # Short-circuits if _flat_weights is only partially instantiated
- if len(self._flat_weights) != len(self._flat_weights_names):
- return
- for w in self._flat_weights:
- if not isinstance(w, Tensor):
- return
- # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
- # or the tensors in _flat_weights are of different dtypes
- first_fw = self._flat_weights[0]
- dtype = first_fw.dtype
- for fw in self._flat_weights:
- if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or
- not fw.data.is_cuda or
- not torch.backends.cudnn.is_acceptable(fw.data)):
- return
- # If any parameters alias, we fall back to the slower, copying code path. This is
- # a sufficient check, because overlapping parameter buffers that don't completely
- # alias would break the assumptions of the uniqueness check in
- # Module.named_parameters().
- unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
- if len(unique_data_ptrs) != len(self._flat_weights):
- return
- with torch.cuda.device_of(first_fw):
- import torch.backends.cudnn.rnn as rnn
- # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
- # an inplace operation on self._flat_weights
- with torch.no_grad():
- if torch._use_cudnn_rnn_flatten_weight():
- num_weights = 4 if self.bias else 2
- if self.proj_size > 0:
- num_weights += 1
- torch._cudnn_rnn_flatten_weight(
- self._flat_weights, num_weights,
- self.input_size, rnn.get_cudnn_mode(self.mode),
- self.hidden_size, self.proj_size, self.num_layers,
- self.batch_first, bool(self.bidirectional))
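# Usage sketch (an illustrative assumption, not upstream documentation): when the
# weights of a CUDA module have been replaced or made non-contiguous,
# flatten_parameters() can be called explicitly to recompact them for cuDNN and
# avoid the warning about weights not living in a single contiguous chunk of
# memory; _apply() below already triggers this when moving via .to()/.cuda().
#
#     >>> if torch.cuda.is_available():
#     ...     lstm = LSTM(10, 20, num_layers=2).cuda()
#     ...     lstm.flatten_parameters()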
- def _apply(self, fn):
- ret = super()._apply(fn)
- # Resets _flat_weights
- # Note: be very careful before removing this, as third-party device types
- # likely rely on this behavior to properly .to() modules like LSTM.
- self._init_flat_weights()
- return ret
- def reset_parameters(self) -> None:
- stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
- for weight in self.parameters():
- init.uniform_(weight, -stdv, stdv)
- def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
- expected_input_dim = 2 if batch_sizes is not None else 3
- if input.dim() != expected_input_dim:
- raise RuntimeError(
- 'input must have {} dimensions, got {}'.format(
- expected_input_dim, input.dim()))
- if self.input_size != input.size(-1):
- raise RuntimeError(
- 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
- self.input_size, input.size(-1)))
- def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
- if batch_sizes is not None:
- mini_batch = int(batch_sizes[0])
- else:
- mini_batch = input.size(0) if self.batch_first else input.size(1)
- num_directions = 2 if self.bidirectional else 1
- if self.proj_size > 0:
- expected_hidden_size = (self.num_layers * num_directions,
- mini_batch, self.proj_size)
- else:
- expected_hidden_size = (self.num_layers * num_directions,
- mini_batch, self.hidden_size)
- return expected_hidden_size
- def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
- msg: str = 'Expected hidden size {}, got {}') -> None:
- if hx.size() != expected_hidden_size:
- raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
- def _weights_have_changed(self):
- # Returns True if the weight tensors have changed since the last forward pass.
- # This is the case when used with torch.func.functional_call(), for example.
- weights_changed = False
- for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
- weight = getattr(self, name) if hasattr(self, name) else None
- if weight is not None and ref is not None and ref() is not weight:
- weights_changed = True
- break
- return weights_changed
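# Illustrative sketch (hypothetical usage, not from the original file): the
# weakref comparison above fires when forward() runs with substituted
# parameters, e.g. via torch.func.functional_call, so the flat-weight cache is
# rebuilt before dispatching to the fused kernels.
#
#     >>> lstm = LSTM(4, 8)
#     >>> new_params = {k: v.clone() for k, v in lstm.named_parameters()}
#     >>> out, (h, c) = torch.func.functional_call(lstm, new_params, torch.randn(5, 2, 4))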
- def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]):
- self.check_input(input, batch_sizes)
- expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
- self.check_hidden_size(hidden, expected_hidden_size)
- def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
- if permutation is None:
- return hx
- return _apply_permutation(hx, permutation)
- def extra_repr(self) -> str:
- s = '{input_size}, {hidden_size}'
- if self.proj_size != 0:
- s += ', proj_size={proj_size}'
- if self.num_layers != 1:
- s += ', num_layers={num_layers}'
- if self.bias is not True:
- s += ', bias={bias}'
- if self.batch_first is not False:
- s += ', batch_first={batch_first}'
- if self.dropout != 0:
- s += ', dropout={dropout}'
- if self.bidirectional is not False:
- s += ', bidirectional={bidirectional}'
- return s.format(**self.__dict__)
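# For example (a sketch, relying on Module.__repr__ picking up extra_repr):
#
#     >>> print(RNN(10, 20, num_layers=2))
#     RNN(10, 20, num_layers=2)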
- def __getstate__(self):
- # Don't serialize the weight references.
- state = self.__dict__.copy()
- del state['_flat_weight_refs']
- return state
- def __setstate__(self, d):
- super().__setstate__(d)
- if 'all_weights' in d:
- self._all_weights = d['all_weights']
- # In PyTorch 1.8 we added a proj_size member variable to LSTM.
- # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
- # don't have it, so to preserve compatibility we set proj_size here.
- if 'proj_size' not in d:
- self.proj_size = 0
- if not isinstance(self._all_weights[0][0], str):
- num_layers = self.num_layers
- num_directions = 2 if self.bidirectional else 1
- self._flat_weights_names = []
- self._all_weights = []
- for layer in range(num_layers):
- for direction in range(num_directions):
- suffix = '_reverse' if direction == 1 else ''
- weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}',
- 'bias_hh_l{}{}', 'weight_hr_l{}{}']
- weights = [x.format(layer, suffix) for x in weights]
- if self.bias:
- if self.proj_size > 0:
- self._all_weights += [weights]
- self._flat_weights_names.extend(weights)
- else:
- self._all_weights += [weights[:4]]
- self._flat_weights_names.extend(weights[:4])
- else:
- if self.proj_size > 0:
- self._all_weights += [weights[:2]] + [weights[-1:]]
- self._flat_weights_names.extend(weights[:2] + [weights[-1:]])
- else:
- self._all_weights += [weights[:2]]
- self._flat_weights_names.extend(weights[:2])
- self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None
- for wn in self._flat_weights_names]
- self._flat_weight_refs = [weakref.ref(w) if w is not None else None
- for w in self._flat_weights]
- @property
- def all_weights(self) -> List[List[Parameter]]:
- return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
- def _replicate_for_data_parallel(self):
- replica = super()._replicate_for_data_parallel()
- # Need to copy these caches, otherwise the replica will share the same
- # flat weights list.
- replica._flat_weights = replica._flat_weights[:]
- replica._flat_weights_names = replica._flat_weights_names[:]
- return replica
- class RNN(RNNBase):
- r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
- input sequence.
- For each element in the input sequence, each layer computes the following
- function:
- .. math::
- h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})
- where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
- the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
- layer at time `t-1` or the initial hidden state at time `0`.
- If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
- would mean stacking two RNNs together to form a `stacked RNN`,
- with the second RNN taking in outputs of the first RNN and
- computing the final results. Default: 1
- nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- batch_first: If ``True``, then the input and output tensors are provided
- as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
- Note that this does not apply to hidden or cell states. See the
- Inputs/Outputs sections below for details. Default: ``False``
- dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
- RNN layer except the last layer, with dropout probability equal to
- :attr:`dropout`. Default: 0
- bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
- Inputs: input, h_0
- * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
- :math:`(L, N, H_{in})` when ``batch_first=False`` or
- :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
- the input sequence. The input can also be a packed variable length sequence.
- See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
- :func:`torch.nn.utils.rnn.pack_sequence` for details.
- * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
- state for the input sequence batch. Defaults to zeros if not provided.
- where:
- .. math::
- \begin{aligned}
- N ={} & \text{batch size} \\
- L ={} & \text{sequence length} \\
- D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
- H_{in} ={} & \text{input\_size} \\
- H_{out} ={} & \text{hidden\_size}
- \end{aligned}
- Outputs: output, h_n
- * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
- :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
- :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
- `(h_t)` from the last layer of the RNN, for each `t`. If a
- :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
- will also be a packed sequence.
- * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
- for each element in the batch.
- Attributes:
- weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
- of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
- `(hidden_size, num_directions * hidden_size)`
- weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
- of shape `(hidden_size, hidden_size)`
- bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
- of shape `(hidden_size)`
- bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
- of shape `(hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- .. note::
- For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
- Example of splitting the output layers when ``batch_first=False``:
- ``output.view(seq_len, batch, num_directions, hidden_size)``.
- .. note::
- ``batch_first`` argument is ignored for unbatched inputs.
- .. include:: ../cudnn_rnn_determinism.rst
- .. include:: ../cudnn_persistent_rnn.rst
- Examples::
- >>> rnn = nn.RNN(10, 20, 2)
- >>> input = torch.randn(5, 3, 10)
- >>> h0 = torch.randn(2, 3, 20)
- >>> output, hn = rnn(input, h0)
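    A further sketch (hypothetical sequence lengths) showing a packed
    variable-length batch, as mentioned in the Inputs section above::

        >>> packed = torch.nn.utils.rnn.pack_padded_sequence(input, lengths=[5, 3, 1])
        >>> packed_output, hn = rnn(packed, h0)
        >>> padded_output, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_output)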
- """
- def __init__(self, *args, **kwargs):
- if 'proj_size' in kwargs:
- raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
- self.nonlinearity = kwargs.pop('nonlinearity', 'tanh')
- if self.nonlinearity == 'tanh':
- mode = 'RNN_TANH'
- elif self.nonlinearity == 'relu':
- mode = 'RNN_RELU'
- else:
- raise ValueError("Unknown nonlinearity '{}'".format(self.nonlinearity))
- super().__init__(mode, *args, **kwargs)
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
- pass
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
- pass
- def forward(self, input, hx=None): # noqa: F811
- if not torch.jit.is_scripting():
- if self._weights_have_changed():
- self._init_flat_weights()
- orig_input = input
- if isinstance(orig_input, PackedSequence):
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = int(batch_sizes[0])
- else:
- batch_sizes = None
- assert (input.dim() in (2, 3)), f"RNN: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 3
- batch_dim = 0 if self.batch_first else 1
- if not is_batched:
- input = input.unsqueeze(batch_dim)
- if hx is not None:
- if hx.dim() != 2:
- raise RuntimeError(
- f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
- hx = hx.unsqueeze(1)
- else:
- if hx is not None and hx.dim() != 3:
- raise RuntimeError(
- f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
- if hx is None:
- num_directions = 2 if self.bidirectional else 1
- hx = torch.zeros(self.num_layers * num_directions,
- max_batch_size, self.hidden_size,
- dtype=input.dtype, device=input.device)
- else:
- # Each batch of the hidden state should match the input sequence that
- # the user believes they are passing in.
- hx = self.permute_hidden(hx, sorted_indices)
- assert hx is not None
- self.check_forward_args(input, hx, batch_sizes)
- assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU'
- if batch_sizes is None:
- if self.mode == 'RNN_TANH':
- result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
- self.dropout, self.training, self.bidirectional,
- self.batch_first)
- else:
- result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers,
- self.dropout, self.training, self.bidirectional,
- self.batch_first)
- else:
- if self.mode == 'RNN_TANH':
- result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias,
- self.num_layers, self.dropout, self.training,
- self.bidirectional)
- else:
- result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias,
- self.num_layers, self.dropout, self.training,
- self.bidirectional)
- output = result[0]
- hidden = result[1]
- if isinstance(orig_input, PackedSequence):
- output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
- return output_packed, self.permute_hidden(hidden, unsorted_indices)
- if not is_batched:
- output = output.squeeze(batch_dim)
- hidden = hidden.squeeze(1)
- return output, self.permute_hidden(hidden, unsorted_indices)
- # XXX: The LSTM and GRU implementations differ from RNNBase because:
- # 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
- #    its current state cannot support the Python Union or Any types;
- # 2. TorchScript static typing does not allow a Function or Callable type in
- #    Dict values, so we have to call _VF directly instead of using _rnn_impls;
- # 3. this is a temporary, transitional state so that we could make it in time
- #    for the release.
- #
- # More discussion details in https://github.com/pytorch/pytorch/pull/23266
- #
- # TODO: remove the overriding implementations for LSTM and GRU when TorchScript
- #       supports expressing these two modules generally.
- class LSTM(RNNBase):
- r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
- sequence.
- For each element in the input sequence, each layer computes the following
- function:
- .. math::
- \begin{array}{ll} \\
- i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
- f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
- g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
- o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
- c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
- h_t = o_t \odot \tanh(c_t) \\
- \end{array}
- where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
- state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
- is the hidden state of the layer at time `t-1` or the initial hidden
- state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
- :math:`o_t` are the input, forget, cell, and output gates, respectively.
- :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
- In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
- (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
- dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
- variable which is :math:`0` with probability :attr:`dropout`.
- If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
- the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
- ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
- Second, the output hidden state of each layer will be multiplied by a learnable projection
- matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
- of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
- dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
- would mean stacking two LSTMs together to form a `stacked LSTM`,
- with the second LSTM taking in outputs of the first LSTM and
- computing the final results. Default: 1
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- batch_first: If ``True``, then the input and output tensors are provided
- as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
- Note that this does not apply to hidden or cell states. See the
- Inputs/Outputs sections below for details. Default: ``False``
- dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
- LSTM layer except the last layer, with dropout probability equal to
- :attr:`dropout`. Default: 0
- bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
- proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
- Inputs: input, (h_0, c_0)
- * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
- :math:`(L, N, H_{in})` when ``batch_first=False`` or
- :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
- the input sequence. The input can also be a packed variable length sequence.
- See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
- :func:`torch.nn.utils.rnn.pack_sequence` for details.
- * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the
- initial hidden state for each element in the input sequence.
- Defaults to zeros if (h_0, c_0) is not provided.
- * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
- initial cell state for each element in the input sequence.
- Defaults to zeros if (h_0, c_0) is not provided.
- where:
- .. math::
- \begin{aligned}
- N ={} & \text{batch size} \\
- L ={} & \text{sequence length} \\
- D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
- H_{in} ={} & \text{input\_size} \\
- H_{cell} ={} & \text{hidden\_size} \\
- H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
- \end{aligned}
- Outputs: output, (h_n, c_n)
- * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
- :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
- :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
- `(h_t)` from the last layer of the LSTM, for each `t`. If a
- :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
- will also be a packed sequence. When ``bidirectional=True``, `output` will contain
- a concatenation of the forward and reverse hidden states at each time step in the sequence.
- * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the
- final hidden state for each element in the sequence. When ``bidirectional=True``,
- `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
- * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
- :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
- final cell state for each element in the sequence. When ``bidirectional=True``,
- `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.
- Attributes:
- weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
- `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
- Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
- ``proj_size > 0`` was specified, the shape will be
- `(4*hidden_size, num_directions * proj_size)` for `k > 0`
- weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
- `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
- was specified, the shape will be `(4*hidden_size, proj_size)`.
- bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
- `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
- bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
- `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
- weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
- of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
- specified.
- weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
- Only present when ``bidirectional=True``.
- weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
- Only present when ``bidirectional=True``.
- bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
- Only present when ``bidirectional=True``.
- bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
- Only present when ``bidirectional=True``.
- weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
- Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- .. note::
- For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
- Example of splitting the output layers when ``batch_first=False``:
- ``output.view(seq_len, batch, num_directions, hidden_size)``.
- .. note::
- For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
- former contains the final forward and reverse hidden states, while the latter contains the
- final forward hidden state and the initial reverse hidden state.
- .. note::
- ``batch_first`` argument is ignored for unbatched inputs.
- .. include:: ../cudnn_rnn_determinism.rst
- .. include:: ../cudnn_persistent_rnn.rst
- Examples::
- >>> rnn = nn.LSTM(10, 20, 2)
- >>> input = torch.randn(5, 3, 10)
- >>> h0 = torch.randn(2, 3, 20)
- >>> c0 = torch.randn(2, 3, 20)
- >>> output, (hn, cn) = rnn(input, (h0, c0))
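    A sketch with projections enabled (``proj_size=10`` chosen purely for
    illustration); the hidden state and per-step outputs then have size
    ``proj_size`` rather than ``hidden_size``::

        >>> proj_lstm = nn.LSTM(10, 20, 2, proj_size=10)
        >>> h0 = torch.randn(2, 3, 10)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = proj_lstm(input, (h0, c0))
        >>> output.shape, hn.shape, cn.shape
        (torch.Size([5, 3, 10]), torch.Size([2, 3, 10]), torch.Size([2, 3, 20]))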
- """
- def __init__(self, *args, **kwargs):
- super().__init__('LSTM', *args, **kwargs)
- def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
- if batch_sizes is not None:
- mini_batch = int(batch_sizes[0])
- else:
- mini_batch = input.size(0) if self.batch_first else input.size(1)
- num_directions = 2 if self.bidirectional else 1
- expected_hidden_size = (self.num_layers * num_directions,
- mini_batch, self.hidden_size)
- return expected_hidden_size
- # In the future, we should prevent mypy from applying contravariance rules here.
- # See torch/nn/modules/module.py::_forward_unimplemented
- def check_forward_args(self, # type: ignore[override]
- input: Tensor,
- hidden: Tuple[Tensor, Tensor],
- batch_sizes: Optional[Tensor],
- ):
- self.check_input(input, batch_sizes)
- self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
- 'Expected hidden[0] size {}, got {}')
- self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
- 'Expected hidden[1] size {}, got {}')
- # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
- def permute_hidden(self, # type: ignore[override]
- hx: Tuple[Tensor, Tensor],
- permutation: Optional[Tensor]
- ) -> Tuple[Tensor, Tensor]:
- if permutation is None:
- return hx
- return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
- # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
- @overload # type: ignore[override]
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
- ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811
- pass
- # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
- ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811
- pass
- def forward(self, input, hx=None): # noqa: F811
- if not torch.jit.is_scripting():
- if self._weights_have_changed():
- self._init_flat_weights()
- orig_input = input
- # xxx: isinstance check needs to be in conditional for TorchScript to compile
- batch_sizes = None
- if isinstance(orig_input, PackedSequence):
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = batch_sizes[0]
- max_batch_size = int(max_batch_size)
- else:
- batch_sizes = None
- assert (input.dim() in (2, 3)), f"LSTM: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 3
- batch_dim = 0 if self.batch_first else 1
- if not is_batched:
- input = input.unsqueeze(batch_dim)
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
- if hx is None:
- num_directions = 2 if self.bidirectional else 1
- real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
- h_zeros = torch.zeros(self.num_layers * num_directions,
- max_batch_size, real_hidden_size,
- dtype=input.dtype, device=input.device)
- c_zeros = torch.zeros(self.num_layers * num_directions,
- max_batch_size, self.hidden_size,
- dtype=input.dtype, device=input.device)
- hx = (h_zeros, c_zeros)
- else:
- if batch_sizes is None: # If not PackedSequence input.
- if is_batched:
- if (hx[0].dim() != 3 or hx[1].dim() != 3):
- msg = ("For batched 3-D input, hx and cx should "
- f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
- raise RuntimeError(msg)
- else:
- if hx[0].dim() != 2 or hx[1].dim() != 2:
- msg = ("For unbatched 2-D input, hx and cx should "
- f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
- raise RuntimeError(msg)
- hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
- # Each batch of the hidden state should match the input sequence that
- # the user believes they are passing in.
- hx = self.permute_hidden(hx, sorted_indices)
- self.check_forward_args(input, hx, batch_sizes)
- if batch_sizes is None:
- result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
- self.dropout, self.training, self.bidirectional, self.batch_first)
- else:
- result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
- self.num_layers, self.dropout, self.training, self.bidirectional)
- output = result[0]
- hidden = result[1:]
- # xxx: isinstance check needs to be in conditional for TorchScript to compile
- if isinstance(orig_input, PackedSequence):
- output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
- return output_packed, self.permute_hidden(hidden, unsorted_indices)
- else:
- if not is_batched:
- output = output.squeeze(batch_dim)
- hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
- return output, self.permute_hidden(hidden, unsorted_indices)
- class GRU(RNNBase):
- r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
- For each element in the input sequence, each layer computes the following
- function:
- .. math::
- \begin{array}{ll}
- r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
- z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
- n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
- h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
- \end{array}
- where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
- at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
- at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
- :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
- :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
- In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
- (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
- dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
- variable which is :math:`0` with probability :attr:`dropout`.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
- would mean stacking two GRUs together to form a `stacked GRU`,
- with the second GRU taking in outputs of the first GRU and
- computing the final results. Default: 1
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- batch_first: If ``True``, then the input and output tensors are provided
- as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
- Note that this does not apply to hidden or cell states. See the
- Inputs/Outputs sections below for details. Default: ``False``
- dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
- GRU layer except the last layer, with dropout probability equal to
- :attr:`dropout`. Default: 0
- bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
- Inputs: input, h_0
- * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
- :math:`(L, N, H_{in})` when ``batch_first=False`` or
- :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
- the input sequence. The input can also be a packed variable length sequence.
- See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
- :func:`torch.nn.utils.rnn.pack_sequence` for details.
- * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
- :math:`(D * \text{num\_layers}, N, H_{out})`
- containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
- where:
- .. math::
- \begin{aligned}
- N ={} & \text{batch size} \\
- L ={} & \text{sequence length} \\
- D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
- H_{in} ={} & \text{input\_size} \\
- H_{out} ={} & \text{hidden\_size}
- \end{aligned}
- Outputs: output, h_n
- * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
- :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
- :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
- `(h_t)` from the last layer of the GRU, for each `t`. If a
- :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
- will also be a packed sequence.
- * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
- :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
- for the input sequence.
- Attributes:
- weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
- (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
- Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
- weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
- (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
- bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
- (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
- bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
- (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- .. note::
- For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
- Example of splitting the output layers when ``batch_first=False``:
- ``output.view(seq_len, batch, num_directions, hidden_size)``.
- .. note::
- ``batch_first`` argument is ignored for unbatched inputs.
- .. include:: ../cudnn_persistent_rnn.rst
- Examples::
- >>> rnn = nn.GRU(10, 20, 2)
- >>> input = torch.randn(5, 3, 10)
- >>> h0 = torch.randn(2, 3, 20)
- >>> output, hn = rnn(input, h0)
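    A sketch of separating the two directions of a bidirectional GRU, following
    the note above (the ``num_directions * hidden_size`` output is split with ``view``)::

        >>> birnn = nn.GRU(10, 20, 2, bidirectional=True)
        >>> h0 = torch.randn(4, 3, 20)
        >>> output, hn = birnn(input, h0)
        >>> output.view(5, 3, 2, 20).shape  # (seq_len, batch, num_directions, hidden_size)
        torch.Size([5, 3, 2, 20])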
- """
- def __init__(self, *args, **kwargs):
- if 'proj_size' in kwargs:
- raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
- super().__init__('GRU', *args, **kwargs)
- @overload # type: ignore[override]
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811
- pass
- @overload
- @torch._jit_internal._overload_method # noqa: F811
- def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: # noqa: F811
- pass
- def forward(self, input, hx=None): # noqa: F811
- if not torch.jit.is_scripting():
- if self._weights_have_changed():
- self._init_flat_weights()
- orig_input = input
- # xxx: isinstance check needs to be in conditional for TorchScript to compile
- if isinstance(orig_input, PackedSequence):
- input, batch_sizes, sorted_indices, unsorted_indices = input
- max_batch_size = batch_sizes[0]
- max_batch_size = int(max_batch_size)
- else:
- batch_sizes = None
- assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 3
- batch_dim = 0 if self.batch_first else 1
- if not is_batched:
- input = input.unsqueeze(batch_dim)
- if hx is not None:
- if hx.dim() != 2:
- raise RuntimeError(
- f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
- hx = hx.unsqueeze(1)
- else:
- if hx is not None and hx.dim() != 3:
- raise RuntimeError(
- f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
- max_batch_size = input.size(0) if self.batch_first else input.size(1)
- sorted_indices = None
- unsorted_indices = None
- if hx is None:
- num_directions = 2 if self.bidirectional else 1
- hx = torch.zeros(self.num_layers * num_directions,
- max_batch_size, self.hidden_size,
- dtype=input.dtype, device=input.device)
- else:
- # Each batch of the hidden state should match the input sequence that
- # the user believes they are passing in.
- hx = self.permute_hidden(hx, sorted_indices)
- self.check_forward_args(input, hx, batch_sizes)
- if batch_sizes is None:
- result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
- self.dropout, self.training, self.bidirectional, self.batch_first)
- else:
- result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
- self.num_layers, self.dropout, self.training, self.bidirectional)
- output = result[0]
- hidden = result[1]
- # xxx: isinstance check needs to be in conditional for TorchScript to compile
- if isinstance(orig_input, PackedSequence):
- output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
- return output_packed, self.permute_hidden(hidden, unsorted_indices)
- else:
- if not is_batched:
- output = output.squeeze(batch_dim)
- hidden = hidden.squeeze(1)
- return output, self.permute_hidden(hidden, unsorted_indices)
- class RNNCellBase(Module):
- __constants__ = ['input_size', 'hidden_size', 'bias']
- input_size: int
- hidden_size: int
- bias: bool
- weight_ih: Tensor
- weight_hh: Tensor
- # WARNING: bias_ih and bias_hh purposely not defined here.
- # See https://github.com/pytorch/pytorch/issues/39670
- def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
- device=None, dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__()
- self.input_size = input_size
- self.hidden_size = hidden_size
- self.bias = bias
- self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs))
- self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs))
- if bias:
- self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
- self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
- else:
- self.register_parameter('bias_ih', None)
- self.register_parameter('bias_hh', None)
- self.reset_parameters()
- def extra_repr(self) -> str:
- s = '{input_size}, {hidden_size}'
- if 'bias' in self.__dict__ and self.bias is not True:
- s += ', bias={bias}'
- if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
- s += ', nonlinearity={nonlinearity}'
- return s.format(**self.__dict__)
- def reset_parameters(self) -> None:
- stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
- for weight in self.parameters():
- init.uniform_(weight, -stdv, stdv)
- class RNNCell(RNNCellBase):
- r"""An Elman RNN cell with tanh or ReLU non-linearity.
- .. math::
- h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
- If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
- Default: ``True``
- nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
- Inputs: input, hidden
- - **input**: tensor containing input features
- - **hidden**: tensor containing the initial hidden state
- Defaults to zero if not provided.
- Outputs: h'
- - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
- for each element in the batch
- Shape:
- - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
- :math:`H_{in}` = `input_size`.
- - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
- state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
- - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- Examples::
- >>> rnn = nn.RNNCell(10, 20)
- >>> input = torch.randn(6, 3, 10)
- >>> hx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(6):
- ... hx = rnn(input[i], hx)
- ... output.append(hx)
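    An unbatched sketch (1-D input and hidden state, matching the Shape section above)::

        >>> cell = nn.RNNCell(10, 20)
        >>> hx = cell(torch.randn(10), torch.randn(20))
        >>> hx.shape
        torch.Size([20])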
- """
- __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
- nonlinearity: str
- def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
- device=None, dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
- self.nonlinearity = nonlinearity
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
- assert input.dim() in (1, 2), \
- f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
- else:
- hx = hx.unsqueeze(0) if not is_batched else hx
- if self.nonlinearity == "tanh":
- ret = _VF.rnn_tanh_cell(
- input, hx,
- self.weight_ih, self.weight_hh,
- self.bias_ih, self.bias_hh,
- )
- elif self.nonlinearity == "relu":
- ret = _VF.rnn_relu_cell(
- input, hx,
- self.weight_ih, self.weight_hh,
- self.bias_ih, self.bias_hh,
- )
- else:
- ret = input # TODO: remove when jit supports exception flow
- raise RuntimeError(
- "Unknown nonlinearity: {}".format(self.nonlinearity))
- if not is_batched:
- ret = ret.squeeze(0)
- return ret
- class LSTMCell(RNNCellBase):
- r"""A long short-term memory (LSTM) cell.
- .. math::
- \begin{array}{ll}
- i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
- f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
- g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
- o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
- c' = f * c + i * g \\
- h' = o * \tanh(c') \\
- \end{array}
- where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and
- `b_hh`. Default: ``True``
- Inputs: input, (h_0, c_0)
- - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
- - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
- - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
- If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
- Outputs: (h_1, c_1)
- - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
- - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(4*hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(4*hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
- Examples::
- >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
- >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
- >>> hx = torch.randn(3, 20) # (batch, hidden_size)
- >>> cx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(input.size()[0]):
- ... hx, cx = rnn(input[i], (hx, cx))
- ... output.append(hx)
- >>> output = torch.stack(output, dim=0)
- """
- def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
- device=None, dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
- def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
- assert input.dim() in (1, 2), \
- f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
- hx = (zeros, zeros)
- else:
- hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
- ret = _VF.lstm_cell(
- input, hx,
- self.weight_ih, self.weight_hh,
- self.bias_ih, self.bias_hh,
- )
- if not is_batched:
- ret = (ret[0].squeeze(0), ret[1].squeeze(0))
- return ret
- class GRUCell(RNNCellBase):
- r"""A gated recurrent unit (GRU) cell
- .. math::
- \begin{array}{ll}
- r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
- z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
- n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
- h' = (1 - z) * n + z * h
- \end{array}
- where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
- Args:
- input_size: The number of expected features in the input `x`
- hidden_size: The number of features in the hidden state `h`
- bias: If ``False``, then the layer does not use bias weights `b_ih` and
- `b_hh`. Default: ``True``
- Inputs: input, hidden
- - **input** : tensor containing input features
- - **hidden** : tensor containing the initial hidden
- state for each element in the batch.
- Defaults to zero if not provided.
- Outputs: h'
- - **h'** : tensor containing the next hidden state
- for each element in the batch
- Shape:
- - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
- :math:`H_{in}` = `input_size`.
- - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
- state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
- - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
- Attributes:
- weight_ih: the learnable input-hidden weights, of shape
- `(3*hidden_size, input_size)`
- weight_hh: the learnable hidden-hidden weights, of shape
- `(3*hidden_size, hidden_size)`
- bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
- bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
- .. note::
- All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
- where :math:`k = \frac{1}{\text{hidden\_size}}`
- On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
- Examples::
- >>> rnn = nn.GRUCell(10, 20)
- >>> input = torch.randn(6, 3, 10)
- >>> hx = torch.randn(3, 20)
- >>> output = []
- >>> for i in range(6):
- ... hx = rnn(input[i], hx)
- ... output.append(hx)
- """
- def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
- device=None, dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
- def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
- assert input.dim() in (1, 2), \
- f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
- is_batched = input.dim() == 2
- if not is_batched:
- input = input.unsqueeze(0)
- if hx is None:
- hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
- else:
- hx = hx.unsqueeze(0) if not is_batched else hx
- ret = _VF.gru_cell(
- input, hx,
- self.weight_ih, self.weight_hh,
- self.bias_ih, self.bias_hh,
- )
- if not is_batched:
- ret = ret.squeeze(0)
- return ret