import math
import warnings
import numbers
import weakref
from typing import List, Tuple, Optional, overload

import torch
from torch import Tensor
from .module import Module
from ..parameter import Parameter
from ..utils.rnn import PackedSequence
from .. import init
from ... import _VF

__all__ = ['RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell']

_rnn_impls = {
    'RNN_TANH': _VF.rnn_tanh,
    'RNN_RELU': _VF.rnn_relu,
}


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
    return _apply_permutation(tensor, permutation, dim)


class RNNBase(Module):
    __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
                     'batch_first', 'dropout', 'bidirectional', 'proj_size']
    __jit_unused_properties__ = ['all_weights']

    mode: str
    input_size: int
    hidden_size: int
    num_layers: int
    bias: bool
    batch_first: bool
    dropout: float
    bidirectional: bool
    proj_size: int

    def __init__(self, mode: str, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
                 dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.proj_size = proj_size
        self._flat_weight_refs: List[Optional[weakref.ReferenceType["Parameter"]]] = []
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if proj_size < 0:
            raise ValueError("proj_size should be a positive integer or zero to disable projections")
        if proj_size >= hidden_size:
            raise ValueError("proj_size has to be smaller than hidden_size")

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        elif mode == 'RNN_TANH':
            gate_size = hidden_size
        elif mode == 'RNN_RELU':
            gate_size = hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                real_hidden_size = proj_size if proj_size > 0 else hidden_size
                layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions

                w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs))
                w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs))
                b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
                # Second bias vector included for CuDNN compatibility. Only one
                # bias vector is needed in standard definition.
                b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
                layer_params: Tuple[Tensor, ...] = ()
                if self.proj_size == 0:
                    if bias:
                        layer_params = (w_ih, w_hh, b_ih, b_hh)
                    else:
                        layer_params = (w_ih, w_hh)
                else:
                    w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs))
                    if bias:
                        layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
                    else:
                        layer_params = (w_ih, w_hh, w_hr)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                if self.proj_size > 0:
                    param_names += ['weight_hr_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._flat_weights_names.extend(param_names)
                self._all_weights.append(param_names)

        self._init_flat_weights()
        self.reset_parameters()

    def _init_flat_weights(self):
        self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None
                              for wn in self._flat_weights_names]
        self._flat_weight_refs = [weakref.ref(w) if w is not None else None
                                  for w in self._flat_weights]
        self.flatten_parameters()

    def __setattr__(self, attr, value):
        if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
            # keep self._flat_weights up to date if you do self.weight = ...
            idx = self._flat_weights_names.index(attr)
            self._flat_weights[idx] = value
        super().__setattr__(attr, value)

    def flatten_parameters(self) -> None:
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
        # Short-circuits if _flat_weights is only partially instantiated
        if len(self._flat_weights) != len(self._flat_weights_names):
            return

        for w in self._flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes
        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or
                    not fw.data.is_cuda or
                    not torch.backends.cudnn.is_acceptable(fw.data)):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    num_weights = 4 if self.bias else 2
                    if self.proj_size > 0:
                        num_weights += 1
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights, num_weights,
                        self.input_size, rnn.get_cudnn_mode(self.mode),
                        self.hidden_size, self.proj_size, self.num_layers,
                        self.batch_first, bool(self.bidirectional))

    def _apply(self, fn):
        ret = super()._apply(fn)

        # Resets _flat_weights
        # Note: be v. careful before removing this, as 3rd party device types
        # likely rely on this behavior to properly .to() modules like LSTM.
        self._init_flat_weights()

        return ret

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        if self.proj_size > 0:
            expected_hidden_size = (self.num_layers * num_directions,
                                    mini_batch, self.proj_size)
        else:
            expected_hidden_size = (self.num_layers * num_directions,
                                    mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
                          msg: str = 'Expected hidden size {}, got {}') -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def _weights_have_changed(self):
        # Returns True if the weight tensors have changed since the last forward pass.
        # This is the case when used with torch.func.functional_call(), for example.
        weights_changed = False
        for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
            weight = getattr(self, name) if hasattr(self, name) else None
            if weight is not None and ref is not None and ref() is not weight:
                weights_changed = True
                break
        return weights_changed

    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]):
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def extra_repr(self) -> str:
        s = '{input_size}, {hidden_size}'
        if self.proj_size != 0:
            s += ', proj_size={proj_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __getstate__(self):
        # Don't serialize the weight references.
        state = self.__dict__.copy()
        del state['_flat_weight_refs']
        return state

    def __setstate__(self, d):
        super().__setstate__(d)
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # In PyTorch 1.8 we added a proj_size member variable to LSTM.
        # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
        # don't have it, so to preserve compatibility we set proj_size here.
        if 'proj_size' not in d:
            self.proj_size = 0

        if not isinstance(self._all_weights[0][0], str):
            num_layers = self.num_layers
            num_directions = 2 if self.bidirectional else 1
            self._flat_weights_names = []
            self._all_weights = []
            for layer in range(num_layers):
                for direction in range(num_directions):
                    suffix = '_reverse' if direction == 1 else ''
                    weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}',
                               'bias_hh_l{}{}', 'weight_hr_l{}{}']
                    weights = [x.format(layer, suffix) for x in weights]
                    if self.bias:
                        if self.proj_size > 0:
                            self._all_weights += [weights]
                            self._flat_weights_names.extend(weights)
                        else:
                            self._all_weights += [weights[:4]]
                            self._flat_weights_names.extend(weights[:4])
                    else:
                        if self.proj_size > 0:
                            self._all_weights += [weights[:2]] + [weights[-1:]]
                            self._flat_weights_names.extend(weights[:2] + [weights[-1:]])
                        else:
                            self._all_weights += [weights[:2]]
                            self._flat_weights_names.extend(weights[:2])

        self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None
                              for wn in self._flat_weights_names]
        self._flat_weight_refs = [weakref.ref(w) if w is not None else None
                                  for w in self._flat_weights]

    @property
    def all_weights(self) -> List[List[Parameter]]:
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]

    def _replicate_for_data_parallel(self):
        replica = super()._replicate_for_data_parallel()
        # Need to copy these caches, otherwise the replica will share the same
        # flat weights list.
        replica._flat_weights = replica._flat_weights[:]
        replica._flat_weights_names = replica._flat_weights_names[:]
        return replica


class RNN(RNNBase):
    r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
    input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    previous layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence. The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
          state for the input sequence batch. Defaults to zeros if not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the RNN, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for each element in the batch.

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
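        >>> # A minimal sketch of feeding a packed variable-length batch; the
        >>> # sequence lengths (5 and 3) and variable names are illustrative only.
        >>> seqs = [torch.randn(5, 10), torch.randn(3, 10)]
        >>> packed = nn.utils.rnn.pack_sequence(seqs)
        >>> packed_output, hn = rnn(packed)
        >>> padded_output, lengths = nn.utils.rnn.pad_packed_sequence(packed_output)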
  375. """
  376. def __init__(self, *args, **kwargs):
  377. if 'proj_size' in kwargs:
  378. raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
  379. self.nonlinearity = kwargs.pop('nonlinearity', 'tanh')
  380. if self.nonlinearity == 'tanh':
  381. mode = 'RNN_TANH'
  382. elif self.nonlinearity == 'relu':
  383. mode = 'RNN_RELU'
  384. else:
  385. raise ValueError("Unknown nonlinearity '{}'".format(self.nonlinearity))
  386. super().__init__(mode, *args, **kwargs)
  387. @overload
  388. @torch._jit_internal._overload_method # noqa: F811
  389. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  390. pass
  391. @overload
  392. @torch._jit_internal._overload_method # noqa: F811
  393. def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
  394. pass
  395. def forward(self, input, hx=None): # noqa: F811
  396. if not torch.jit.is_scripting():
  397. if self._weights_have_changed():
  398. self._init_flat_weights()
  399. orig_input = input
  400. if isinstance(orig_input, PackedSequence):
  401. input, batch_sizes, sorted_indices, unsorted_indices = input
  402. max_batch_size = int(batch_sizes[0])
  403. else:
  404. batch_sizes = None
  405. assert (input.dim() in (2, 3)), f"RNN: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
  406. is_batched = input.dim() == 3
  407. batch_dim = 0 if self.batch_first else 1
  408. if not is_batched:
  409. input = input.unsqueeze(batch_dim)
  410. if hx is not None:
  411. if hx.dim() != 2:
  412. raise RuntimeError(
  413. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
  414. hx = hx.unsqueeze(1)
  415. else:
  416. if hx is not None and hx.dim() != 3:
  417. raise RuntimeError(
  418. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
  419. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  420. sorted_indices = None
  421. unsorted_indices = None
  422. if hx is None:
  423. num_directions = 2 if self.bidirectional else 1
  424. hx = torch.zeros(self.num_layers * num_directions,
  425. max_batch_size, self.hidden_size,
  426. dtype=input.dtype, device=input.device)
  427. else:
  428. # Each batch of the hidden state should match the input sequence that
  429. # the user believes he/she is passing in.
  430. hx = self.permute_hidden(hx, sorted_indices)
  431. assert hx is not None
  432. self.check_forward_args(input, hx, batch_sizes)
  433. assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU'
  434. if batch_sizes is None:
  435. if self.mode == 'RNN_TANH':
  436. result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
  437. self.dropout, self.training, self.bidirectional,
  438. self.batch_first)
  439. else:
  440. result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers,
  441. self.dropout, self.training, self.bidirectional,
  442. self.batch_first)
  443. else:
  444. if self.mode == 'RNN_TANH':
  445. result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias,
  446. self.num_layers, self.dropout, self.training,
  447. self.bidirectional)
  448. else:
  449. result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias,
  450. self.num_layers, self.dropout, self.training,
  451. self.bidirectional)
  452. output = result[0]
  453. hidden = result[1]
  454. if isinstance(orig_input, PackedSequence):
  455. output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  456. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  457. if not is_batched:
  458. output = output.squeeze(batch_dim)
  459. hidden = hidden.squeeze(1)
  460. return output, self.permute_hidden(hidden, unsorted_indices)
  461. # XXX: LSTM and GRU implementation is different from RNNBase, this is because:
  462. # 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in
  463. # its current state could not support the python Union Type or Any Type
  464. # 2. TorchScript static typing does not allow a Function or Callable type in
  465. # Dict values, so we have to separately call _VF instead of using _rnn_impls
  466. # 3. This is temporary only and in the transition state that we want to make it
  467. # on time for the release
  468. #
  469. # More discussion details in https://github.com/pytorch/pytorch/pull/23266
  470. #
  471. # TODO: remove the overriding implementations for LSTM and GRU when TorchScript
  472. # support expressing these two modules generally.
  473. class LSTM(RNNBase):
  474. r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
  475. sequence.
  476. For each element in the input sequence, each layer computes the following
  477. function:
  478. .. math::
  479. \begin{array}{ll} \\
  480. i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
  481. f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
  482. g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
  483. o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
  484. c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
  485. h_t = o_t \odot \tanh(c_t) \\
  486. \end{array}
  487. where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
  488. state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
  489. is the hidden state of the layer at time `t-1` or the initial hidden
  490. state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
  491. :math:`o_t` are the input, forget, cell, and output gates, respectively.
  492. :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
  493. In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  494. (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  495. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  496. variable which is :math:`0` with probability :attr:`dropout`.
  497. If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
  498. the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
  499. ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
  500. Second, the output hidden state of each layer will be multiplied by a learnable projection
  501. matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
  502. of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact
  503. dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.
  504. Args:
  505. input_size: The number of expected features in the input `x`
  506. hidden_size: The number of features in the hidden state `h`
  507. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  508. would mean stacking two LSTMs together to form a `stacked LSTM`,
  509. with the second LSTM taking in outputs of the first LSTM and
  510. computing the final results. Default: 1
  511. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  512. Default: ``True``
  513. batch_first: If ``True``, then the input and output tensors are provided
  514. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  515. Note that this does not apply to hidden or cell states. See the
  516. Inputs/Outputs sections below for details. Default: ``False``
  517. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  518. LSTM layer except the last layer, with dropout probability equal to
  519. :attr:`dropout`. Default: 0
  520. bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
  521. proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
  522. Inputs: input, (h_0, c_0)
  523. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  524. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  525. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  526. the input sequence. The input can also be a packed variable length sequence.
  527. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  528. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  529. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  530. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  531. initial hidden state for each element in the input sequence.
  532. Defaults to zeros if (h_0, c_0) is not provided.
  533. * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  534. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  535. initial cell state for each element in the input sequence.
  536. Defaults to zeros if (h_0, c_0) is not provided.
  537. where:
  538. .. math::
  539. \begin{aligned}
  540. N ={} & \text{batch size} \\
  541. L ={} & \text{sequence length} \\
  542. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  543. H_{in} ={} & \text{input\_size} \\
  544. H_{cell} ={} & \text{hidden\_size} \\
  545. H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
  546. \end{aligned}
  547. Outputs: output, (h_n, c_n)
  548. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  549. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  550. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  551. `(h_t)` from the last layer of the LSTM, for each `t`. If a
  552. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  553. will also be a packed sequence. When ``bidirectional=True``, `output` will contain
  554. a concatenation of the forward and reverse hidden states at each time step in the sequence.
  555. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
  556. :math:`(D * \text{num\_layers}, N, H_{out})` containing the
  557. final hidden state for each element in the sequence. When ``bidirectional=True``,
  558. `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
  559. * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
  560. :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
  561. final cell state for each element in the sequence. When ``bidirectional=True``,
  562. `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.
  563. Attributes:
  564. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  565. `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
  566. Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
  567. ``proj_size > 0`` was specified, the shape will be
  568. `(4*hidden_size, num_directions * proj_size)` for `k > 0`
  569. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  570. `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
  571. was specified, the shape will be `(4*hidden_size, proj_size)`.
  572. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  573. `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
  574. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  575. `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
  576. weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
  577. of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
  578. specified.
  579. weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
  580. Only present when ``bidirectional=True``.
  581. weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
  582. Only present when ``bidirectional=True``.
  583. bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
  584. Only present when ``bidirectional=True``.
  585. bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
  586. Only present when ``bidirectional=True``.
  587. weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
  588. Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.
  589. .. note::
  590. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  591. where :math:`k = \frac{1}{\text{hidden\_size}}`
  592. .. note::
  593. For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
  594. Example of splitting the output layers when ``batch_first=False``:
  595. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  596. .. note::
  597. For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
  598. former contains the final forward and reverse hidden states, while the latter contains the
  599. final forward hidden state and the initial reverse hidden state.
  600. .. note::
  601. ``batch_first`` argument is ignored for unbatched inputs.
  602. .. include:: ../cudnn_rnn_determinism.rst
  603. .. include:: ../cudnn_persistent_rnn.rst
  604. Examples::
  605. >>> rnn = nn.LSTM(10, 20, 2)
  606. >>> input = torch.randn(5, 3, 10)
  607. >>> h0 = torch.randn(2, 3, 20)
  608. >>> c0 = torch.randn(2, 3, 20)
  609. >>> output, (hn, cn) = rnn(input, (h0, c0))
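        >>> # A minimal sketch of an LSTM with projections; proj_size=15 is an
        >>> # arbitrary illustrative value. h_0 and output use the projected size,
        >>> # while c_0 keeps hidden_size.
        >>> proj_lstm = nn.LSTM(10, 20, 2, proj_size=15)
        >>> h0 = torch.randn(2, 3, 15)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = proj_lstm(input, (h0, c0))
        >>> output.shape
        torch.Size([5, 3, 15])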
  610. """
  611. def __init__(self, *args, **kwargs):
  612. super().__init__('LSTM', *args, **kwargs)
  613. def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
  614. if batch_sizes is not None:
  615. mini_batch = int(batch_sizes[0])
  616. else:
  617. mini_batch = input.size(0) if self.batch_first else input.size(1)
  618. num_directions = 2 if self.bidirectional else 1
  619. expected_hidden_size = (self.num_layers * num_directions,
  620. mini_batch, self.hidden_size)
  621. return expected_hidden_size
  622. # In the future, we should prevent mypy from applying contravariance rules here.
  623. # See torch/nn/modules/module.py::_forward_unimplemented
  624. def check_forward_args(self, # type: ignore[override]
  625. input: Tensor,
  626. hidden: Tuple[Tensor, Tensor],
  627. batch_sizes: Optional[Tensor],
  628. ):
  629. self.check_input(input, batch_sizes)
  630. self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
  631. 'Expected hidden[0] size {}, got {}')
  632. self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
  633. 'Expected hidden[1] size {}, got {}')
  634. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  635. def permute_hidden(self, # type: ignore[override]
  636. hx: Tuple[Tensor, Tensor],
  637. permutation: Optional[Tensor]
  638. ) -> Tuple[Tensor, Tensor]:
  639. if permutation is None:
  640. return hx
  641. return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)
  642. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  643. @overload # type: ignore[override]
  644. @torch._jit_internal._overload_method # noqa: F811
  645. def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
  646. ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811
  647. pass
  648. # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
  649. @overload
  650. @torch._jit_internal._overload_method # noqa: F811
  651. def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
  652. ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811
  653. pass
  654. def forward(self, input, hx=None): # noqa: F811
  655. if not torch.jit.is_scripting():
  656. if self._weights_have_changed():
  657. self._init_flat_weights()
  658. orig_input = input
  659. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  660. batch_sizes = None
  661. if isinstance(orig_input, PackedSequence):
  662. input, batch_sizes, sorted_indices, unsorted_indices = input
  663. max_batch_size = batch_sizes[0]
  664. max_batch_size = int(max_batch_size)
  665. else:
  666. batch_sizes = None
  667. assert (input.dim() in (2, 3)), f"LSTM: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
  668. is_batched = input.dim() == 3
  669. batch_dim = 0 if self.batch_first else 1
  670. if not is_batched:
  671. input = input.unsqueeze(batch_dim)
  672. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  673. sorted_indices = None
  674. unsorted_indices = None
  675. if hx is None:
  676. num_directions = 2 if self.bidirectional else 1
  677. real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
  678. h_zeros = torch.zeros(self.num_layers * num_directions,
  679. max_batch_size, real_hidden_size,
  680. dtype=input.dtype, device=input.device)
  681. c_zeros = torch.zeros(self.num_layers * num_directions,
  682. max_batch_size, self.hidden_size,
  683. dtype=input.dtype, device=input.device)
  684. hx = (h_zeros, c_zeros)
  685. else:
  686. if batch_sizes is None: # If not PackedSequence input.
  687. if is_batched:
  688. if (hx[0].dim() != 3 or hx[1].dim() != 3):
  689. msg = ("For batched 3-D input, hx and cx should "
  690. f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
  691. raise RuntimeError(msg)
  692. else:
  693. if hx[0].dim() != 2 or hx[1].dim() != 2:
  694. msg = ("For unbatched 2-D input, hx and cx should "
  695. f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
  696. raise RuntimeError(msg)
  697. hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
  698. # Each batch of the hidden state should match the input sequence that
  699. # the user believes he/she is passing in.
  700. hx = self.permute_hidden(hx, sorted_indices)
  701. self.check_forward_args(input, hx, batch_sizes)
  702. if batch_sizes is None:
  703. result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  704. self.dropout, self.training, self.bidirectional, self.batch_first)
  705. else:
  706. result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
  707. self.num_layers, self.dropout, self.training, self.bidirectional)
  708. output = result[0]
  709. hidden = result[1:]
  710. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  711. if isinstance(orig_input, PackedSequence):
  712. output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  713. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  714. else:
  715. if not is_batched:
  716. output = output.squeeze(batch_dim)
  717. hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
  718. return output, self.permute_hidden(hidden, unsorted_indices)
  719. class GRU(RNNBase):
  720. r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
  721. For each element in the input sequence, each layer computes the following
  722. function:
  723. .. math::
  724. \begin{array}{ll}
  725. r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
  726. z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
  727. n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
  728. h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
  729. \end{array}
  730. where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
  731. at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
  732. at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
  733. :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
  734. :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  735. In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
  736. (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  737. dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  738. variable which is :math:`0` with probability :attr:`dropout`.
  739. Args:
  740. input_size: The number of expected features in the input `x`
  741. hidden_size: The number of features in the hidden state `h`
  742. num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
  743. would mean stacking two GRUs together to form a `stacked GRU`,
  744. with the second GRU taking in outputs of the first GRU and
  745. computing the final results. Default: 1
  746. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  747. Default: ``True``
  748. batch_first: If ``True``, then the input and output tensors are provided
  749. as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
  750. Note that this does not apply to hidden or cell states. See the
  751. Inputs/Outputs sections below for details. Default: ``False``
  752. dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
  753. GRU layer except the last layer, with dropout probability equal to
  754. :attr:`dropout`. Default: 0
  755. bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
  756. Inputs: input, h_0
  757. * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
  758. :math:`(L, N, H_{in})` when ``batch_first=False`` or
  759. :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
  760. the input sequence. The input can also be a packed variable length sequence.
  761. See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
  762. :func:`torch.nn.utils.rnn.pack_sequence` for details.
  763. * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  764. :math:`(D * \text{num\_layers}, N, H_{out})`
  765. containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
  766. where:
  767. .. math::
  768. \begin{aligned}
  769. N ={} & \text{batch size} \\
  770. L ={} & \text{sequence length} \\
  771. D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
  772. H_{in} ={} & \text{input\_size} \\
  773. H_{out} ={} & \text{hidden\_size}
  774. \end{aligned}
  775. Outputs: output, h_n
  776. * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
  777. :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
  778. :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
  779. `(h_t)` from the last layer of the GRU, for each `t`. If a
  780. :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
  781. will also be a packed sequence.
  782. * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
  783. :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
  784. for the input sequence.
  785. Attributes:
  786. weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
  787. (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
  788. Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
  789. weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
  790. (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
  791. bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
  792. (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
  793. bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
  794. (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
  795. .. note::
  796. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  797. where :math:`k = \frac{1}{\text{hidden\_size}}`
  798. .. note::
  799. For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
  800. Example of splitting the output layers when ``batch_first=False``:
  801. ``output.view(seq_len, batch, num_directions, hidden_size)``.
  802. .. note::
  803. ``batch_first`` argument is ignored for unbatched inputs.
  804. .. include:: ../cudnn_persistent_rnn.rst
  805. Examples::
  806. >>> rnn = nn.GRU(10, 20, 2)
  807. >>> input = torch.randn(5, 3, 10)
  808. >>> h0 = torch.randn(2, 3, 20)
  809. >>> output, hn = rnn(input, h0)
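        >>> # A minimal sketch of the bidirectional case (shapes are illustrative):
        >>> # the leading dimension of h_0 gains a factor of D = 2 and the output
        >>> # feature dimension doubles.
        >>> birnn = nn.GRU(10, 20, 2, bidirectional=True)
        >>> h0 = torch.randn(4, 3, 20)
        >>> output, hn = birnn(input, h0)
        >>> output.shape
        torch.Size([5, 3, 40])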
  810. """
  811. def __init__(self, *args, **kwargs):
  812. if 'proj_size' in kwargs:
  813. raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
  814. super().__init__('GRU', *args, **kwargs)
  815. @overload # type: ignore[override]
  816. @torch._jit_internal._overload_method # noqa: F811
  817. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811
  818. pass
  819. @overload
  820. @torch._jit_internal._overload_method # noqa: F811
  821. def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: # noqa: F811
  822. pass
  823. def forward(self, input, hx=None): # noqa: F811
  824. if not torch.jit.is_scripting():
  825. if self._weights_have_changed():
  826. self._init_flat_weights()
  827. orig_input = input
  828. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  829. if isinstance(orig_input, PackedSequence):
  830. input, batch_sizes, sorted_indices, unsorted_indices = input
  831. max_batch_size = batch_sizes[0]
  832. max_batch_size = int(max_batch_size)
  833. else:
  834. batch_sizes = None
  835. assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
  836. is_batched = input.dim() == 3
  837. batch_dim = 0 if self.batch_first else 1
  838. if not is_batched:
  839. input = input.unsqueeze(batch_dim)
  840. if hx is not None:
  841. if hx.dim() != 2:
  842. raise RuntimeError(
  843. f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
  844. hx = hx.unsqueeze(1)
  845. else:
  846. if hx is not None and hx.dim() != 3:
  847. raise RuntimeError(
  848. f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
  849. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  850. sorted_indices = None
  851. unsorted_indices = None
  852. if hx is None:
  853. num_directions = 2 if self.bidirectional else 1
  854. hx = torch.zeros(self.num_layers * num_directions,
  855. max_batch_size, self.hidden_size,
  856. dtype=input.dtype, device=input.device)
  857. else:
  858. # Each batch of the hidden state should match the input sequence that
  859. # the user believes he/she is passing in.
  860. hx = self.permute_hidden(hx, sorted_indices)
  861. self.check_forward_args(input, hx, batch_sizes)
  862. if batch_sizes is None:
  863. result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
  864. self.dropout, self.training, self.bidirectional, self.batch_first)
  865. else:
  866. result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
  867. self.num_layers, self.dropout, self.training, self.bidirectional)
  868. output = result[0]
  869. hidden = result[1]
  870. # xxx: isinstance check needs to be in conditional for TorchScript to compile
  871. if isinstance(orig_input, PackedSequence):
  872. output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  873. return output_packed, self.permute_hidden(hidden, unsorted_indices)
  874. else:
  875. if not is_batched:
  876. output = output.squeeze(batch_dim)
  877. hidden = hidden.squeeze(1)
  878. return output, self.permute_hidden(hidden, unsorted_indices)
  879. class RNNCellBase(Module):
  880. __constants__ = ['input_size', 'hidden_size', 'bias']
  881. input_size: int
  882. hidden_size: int
  883. bias: bool
  884. weight_ih: Tensor
  885. weight_hh: Tensor
  886. # WARNING: bias_ih and bias_hh purposely not defined here.
  887. # See https://github.com/pytorch/pytorch/issues/39670
  888. def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
  889. device=None, dtype=None) -> None:
  890. factory_kwargs = {'device': device, 'dtype': dtype}
  891. super().__init__()
  892. self.input_size = input_size
  893. self.hidden_size = hidden_size
  894. self.bias = bias
  895. self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs))
  896. self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs))
  897. if bias:
  898. self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
  899. self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs))
  900. else:
  901. self.register_parameter('bias_ih', None)
  902. self.register_parameter('bias_hh', None)
  903. self.reset_parameters()
  904. def extra_repr(self) -> str:
  905. s = '{input_size}, {hidden_size}'
  906. if 'bias' in self.__dict__ and self.bias is not True:
  907. s += ', bias={bias}'
  908. if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
  909. s += ', nonlinearity={nonlinearity}'
  910. return s.format(**self.__dict__)
  911. def reset_parameters(self) -> None:
  912. stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
  913. for weight in self.parameters():
  914. init.uniform_(weight, -stdv, stdv)
  915. class RNNCell(RNNCellBase):
  916. r"""An Elman RNN cell with tanh or ReLU non-linearity.
  917. .. math::
  918. h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
  919. If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
  920. Args:
  921. input_size: The number of expected features in the input `x`
  922. hidden_size: The number of features in the hidden state `h`
  923. bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
  924. Default: ``True``
  925. nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
  926. Inputs: input, hidden
  927. - **input**: tensor containing input features
  928. - **hidden**: tensor containing the initial hidden state
  929. Defaults to zero if not provided.
  930. Outputs: h'
  931. - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
  932. for each element in the batch
  933. Shape:
  934. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  935. :math:`H_{in}` = `input_size`.
  936. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  937. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  938. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  939. Attributes:
  940. weight_ih: the learnable input-hidden weights, of shape
  941. `(hidden_size, input_size)`
  942. weight_hh: the learnable hidden-hidden weights, of shape
  943. `(hidden_size, hidden_size)`
  944. bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
  945. bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
  946. .. note::
  947. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  948. where :math:`k = \frac{1}{\text{hidden\_size}}`
  949. Examples::
  950. >>> rnn = nn.RNNCell(10, 20)
  951. >>> input = torch.randn(6, 3, 10)
  952. >>> hx = torch.randn(3, 20)
  953. >>> output = []
  954. >>> for i in range(6):
  955. ... hx = rnn(input[i], hx)
  956. ... output.append(hx)
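        >>> # A minimal sketch of the 'relu' nonlinearity (sizes are illustrative).
        >>> relu_cell = nn.RNNCell(10, 20, nonlinearity='relu')
        >>> hx = relu_cell(torch.randn(3, 10), torch.randn(3, 20))
        >>> hx.shape
        torch.Size([3, 20])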
  957. """
  958. __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
  959. nonlinearity: str
  960. def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
  961. device=None, dtype=None) -> None:
  962. factory_kwargs = {'device': device, 'dtype': dtype}
  963. super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
  964. self.nonlinearity = nonlinearity
  965. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  966. assert input.dim() in (1, 2), \
  967. f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  968. is_batched = input.dim() == 2
  969. if not is_batched:
  970. input = input.unsqueeze(0)
  971. if hx is None:
  972. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  973. else:
  974. hx = hx.unsqueeze(0) if not is_batched else hx
  975. if self.nonlinearity == "tanh":
  976. ret = _VF.rnn_tanh_cell(
  977. input, hx,
  978. self.weight_ih, self.weight_hh,
  979. self.bias_ih, self.bias_hh,
  980. )
  981. elif self.nonlinearity == "relu":
  982. ret = _VF.rnn_relu_cell(
  983. input, hx,
  984. self.weight_ih, self.weight_hh,
  985. self.bias_ih, self.bias_hh,
  986. )
  987. else:
  988. ret = input # TODO: remove when jit supports exception flow
  989. raise RuntimeError(
  990. "Unknown nonlinearity: {}".format(self.nonlinearity))
  991. if not is_batched:
  992. ret = ret.squeeze(0)
  993. return ret
  994. class LSTMCell(RNNCellBase):
  995. r"""A long short-term memory (LSTM) cell.
  996. .. math::
  997. \begin{array}{ll}
  998. i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
  999. f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
  1000. g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
  1001. o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
  1002. c' = f * c + i * g \\
  1003. h' = o * \tanh(c') \\
  1004. \end{array}
  1005. where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  1006. Args:
  1007. input_size: The number of expected features in the input `x`
  1008. hidden_size: The number of features in the hidden state `h`
  1009. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1010. `b_hh`. Default: ``True``
  1011. Inputs: input, (h_0, c_0)
  1012. - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
  1013. - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
  1014. - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
  1015. If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
  1016. Outputs: (h_1, c_1)
  1017. - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
  1018. - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
  1019. Attributes:
  1020. weight_ih: the learnable input-hidden weights, of shape
  1021. `(4*hidden_size, input_size)`
  1022. weight_hh: the learnable hidden-hidden weights, of shape
  1023. `(4*hidden_size, hidden_size)`
  1024. bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
  1025. bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
  1026. .. note::
  1027. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1028. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1029. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1030. Examples::
  1031. >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
  1032. >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
  1033. >>> hx = torch.randn(3, 20) # (batch, hidden_size)
  1034. >>> cx = torch.randn(3, 20)
  1035. >>> output = []
  1036. >>> for i in range(input.size()[0]):
  1037. ... hx, cx = rnn(input[i], (hx, cx))
  1038. ... output.append(hx)
  1039. >>> output = torch.stack(output, dim=0)
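        >>> # A minimal sketch of the unbatched case: a single time step with 1-D
        >>> # input and 1-D states (sizes are illustrative).
        >>> cell = nn.LSTMCell(10, 20)
        >>> h_t, c_t = cell(torch.randn(10), (torch.randn(20), torch.randn(20)))
        >>> h_t.shape
        torch.Size([20])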
  1040. """
  1041. def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
  1042. device=None, dtype=None) -> None:
  1043. factory_kwargs = {'device': device, 'dtype': dtype}
  1044. super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
  1045. def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
  1046. assert input.dim() in (1, 2), \
  1047. f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  1048. is_batched = input.dim() == 2
  1049. if not is_batched:
  1050. input = input.unsqueeze(0)
  1051. if hx is None:
  1052. zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  1053. hx = (zeros, zeros)
  1054. else:
  1055. hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
  1056. ret = _VF.lstm_cell(
  1057. input, hx,
  1058. self.weight_ih, self.weight_hh,
  1059. self.bias_ih, self.bias_hh,
  1060. )
  1061. if not is_batched:
  1062. ret = (ret[0].squeeze(0), ret[1].squeeze(0))
  1063. return ret
  1064. class GRUCell(RNNCellBase):
  1065. r"""A gated recurrent unit (GRU) cell
  1066. .. math::
  1067. \begin{array}{ll}
  1068. r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
  1069. z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
  1070. n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
  1071. h' = (1 - z) * n + z * h
  1072. \end{array}
  1073. where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
  1074. Args:
  1075. input_size: The number of expected features in the input `x`
  1076. hidden_size: The number of features in the hidden state `h`
  1077. bias: If ``False``, then the layer does not use bias weights `b_ih` and
  1078. `b_hh`. Default: ``True``
  1079. Inputs: input, hidden
  1080. - **input** : tensor containing input features
  1081. - **hidden** : tensor containing the initial hidden
  1082. state for each element in the batch.
  1083. Defaults to zero if not provided.
  1084. Outputs: h'
  1085. - **h'** : tensor containing the next hidden state
  1086. for each element in the batch
  1087. Shape:
  1088. - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
  1089. :math:`H_{in}` = `input_size`.
  1090. - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
  1091. state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
  1092. - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
  1093. Attributes:
  1094. weight_ih: the learnable input-hidden weights, of shape
  1095. `(3*hidden_size, input_size)`
  1096. weight_hh: the learnable hidden-hidden weights, of shape
  1097. `(3*hidden_size, hidden_size)`
  1098. bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
  1099. bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
  1100. .. note::
  1101. All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
  1102. where :math:`k = \frac{1}{\text{hidden\_size}}`
  1103. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  1104. Examples::
  1105. >>> rnn = nn.GRUCell(10, 20)
  1106. >>> input = torch.randn(6, 3, 10)
  1107. >>> hx = torch.randn(3, 20)
  1108. >>> output = []
  1109. >>> for i in range(6):
  1110. ... hx = rnn(input[i], hx)
  1111. ... output.append(hx)
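        >>> # A minimal sketch of the unbatched case with 1-D input and hidden state
        >>> # (sizes are illustrative).
        >>> cell = nn.GRUCell(10, 20)
        >>> h = cell(torch.randn(10), torch.randn(20))
        >>> h.shape
        torch.Size([20])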
  1112. """
  1113. def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
  1114. device=None, dtype=None) -> None:
  1115. factory_kwargs = {'device': device, 'dtype': dtype}
  1116. super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
  1117. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  1118. assert input.dim() in (1, 2), \
  1119. f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
  1120. is_batched = input.dim() == 2
  1121. if not is_batched:
  1122. input = input.unsqueeze(0)
  1123. if hx is None:
  1124. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  1125. else:
  1126. hx = hx.unsqueeze(0) if not is_batched else hx
  1127. ret = _VF.gru_cell(
  1128. input, hx,
  1129. self.weight_ih, self.weight_hh,
  1130. self.bias_ih, self.bias_hh,
  1131. )
  1132. if not is_batched:
  1133. ret = ret.squeeze(0)
  1134. return ret