rnn.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. from collections import namedtuple
  2. import warnings
  3. import torch
  4. from torch import Tensor
  5. from ... import _VF
  6. from ..._jit_internal import Optional
  7. from typing import List, Tuple, Union, Iterable
  8. __all__ = ['PackedSequence', 'invert_permutation', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence',
  9. 'unpad_sequence', 'pack_sequence', 'unpack_sequence']
  10. PackedSequence_ = namedtuple('PackedSequence_',
  11. ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
  12. # type annotation for PackedSequence_ to make it compatible with TorchScript
  13. PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
  14. 'sorted_indices': Optional[torch.Tensor],
  15. 'unsorted_indices': Optional[torch.Tensor]}
  16. def bind(optional, fn):
  17. if optional is None:
  18. return None
  19. return fn(optional)
  20. class PackedSequence(PackedSequence_):
  21. r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.
  22. All RNN modules accept packed sequences as inputs.
  23. Note:
  24. Instances of this class should never be created manually. They are meant
  25. to be instantiated by functions like :func:`pack_padded_sequence`.
  26. Batch sizes represent the number elements at each sequence step in
  27. the batch, not the varying sequence lengths passed to
  28. :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
  29. the :class:`PackedSequence` would contain data ``axbc`` with
  30. ``batch_sizes=[2,1,1]``.
  31. Attributes:
  32. data (Tensor): Tensor containing packed sequence
  33. batch_sizes (Tensor): Tensor of integers holding
  34. information about the batch size at each sequence step
  35. sorted_indices (Tensor, optional): Tensor of integers holding how this
  36. :class:`PackedSequence` is constructed from sequences.
  37. unsorted_indices (Tensor, optional): Tensor of integers holding how this
  38. to recover the original sequences with correct order.
  39. .. note::
  40. :attr:`data` can be on arbitrary device and of arbitrary dtype.
  41. :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
  42. tensors on the same device as :attr:`data`.
  43. However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
  44. This invariant is maintained throughout :class:`PackedSequence` class,
  45. and all functions that construct a `:class:PackedSequence` in PyTorch
  46. (i.e., they only pass in tensors conforming to this constraint).
  47. """
  48. def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
  49. return super(PackedSequence, cls).__new__(
  50. cls,
  51. *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
  52. unsorted_indices))
  53. # NOTE [ device and dtype of a PackedSequence ]
  54. #
  55. # See the note above in doc string (starting with ":attr:`data` can be on
  56. # arbitrary device...").
  57. def pin_memory(self):
  58. # Why not convert `batch_sizes`?
  59. # See NOTE [ device and dtype of a PackedSequence ]
  60. return type(self)(self.data.pin_memory(), self.batch_sizes,
  61. bind(self.sorted_indices, lambda t: t.pin_memory()),
  62. bind(self.unsorted_indices, lambda t: t.pin_memory()))
  63. def cuda(self, *args, **kwargs):
  64. # Tests to see if 'cuda' should be added to kwargs
  65. ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
  66. if ex.is_cuda:
  67. return self.to(*args, **kwargs)
  68. return self.to(*args, device='cuda', **kwargs)
  69. def cpu(self, *args, **kwargs):
  70. ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
  71. if ex.device.type == 'cpu':
  72. return self.to(*args, **kwargs)
  73. return self.to(*args, device='cpu', **kwargs)
  74. def double(self):
  75. return self.to(dtype=torch.double)
  76. def float(self):
  77. return self.to(dtype=torch.float)
  78. def half(self):
  79. return self.to(dtype=torch.half)
  80. def long(self):
  81. return self.to(dtype=torch.long)
  82. def int(self):
  83. return self.to(dtype=torch.int)
  84. def short(self):
  85. return self.to(dtype=torch.short)
  86. def char(self):
  87. return self.to(dtype=torch.int8)
  88. def byte(self):
  89. return self.to(dtype=torch.uint8)
  90. def to(self, *args, **kwargs):
  91. r"""Performs dtype and/or device conversion on `self.data`.
  92. It has similar signature as :meth:`torch.Tensor.to`, except optional
  93. arguments like `non_blocking` and `copy` should be passed as kwargs,
  94. not args, or they will not apply to the index tensors.
  95. .. note::
  96. If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
  97. and :class:`torch.device`, then ``self`` is returned.
  98. Otherwise, returns a copy with the desired configuration.
  99. """
  100. # Why not convert `batch_sizes`?
  101. # See NOTE [ device and dtype of a PackedSequence ]
  102. data = self.data.to(*args, **kwargs)
  103. if data is self.data:
  104. return self
  105. else:
  106. # Does not forward device or dtype arg/kwargs, device is set from data.device
  107. kwargs = {k : v for k, v in filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())}
  108. sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
  109. unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
  110. return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
  111. @property
  112. def is_cuda(self):
  113. r"""Returns true if `self.data` stored on a gpu"""
  114. return self.data.is_cuda
  115. def is_pinned(self):
  116. r"""Returns true if `self.data` stored on in pinned memory"""
  117. return self.data.is_pinned()
  118. # TorchScript doesn't support constructors on named tuples, so we use this helper
  119. # method to construct PackedSequence
  120. def _packed_sequence_init_args(
  121. data: Tensor,
  122. batch_sizes: Optional[Tensor] = None,
  123. sorted_indices: Optional[Tensor] = None,
  124. unsorted_indices: Optional[Tensor] = None,
  125. ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
  126. # NB: if unsorted_indices is provided, it should be the inverse permutation
  127. # to sorted_indices. Don't assert it here because the PackedSequence ctor
  128. # should only be used internally.
  129. if unsorted_indices is None:
  130. unsorted_indices = invert_permutation(sorted_indices)
  131. # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
  132. if batch_sizes is not None:
  133. # TODO: Re-enable this check (.type isn't supported in TorchScript)
  134. if batch_sizes.device.type != 'cpu':
  135. raise ValueError(
  136. "batch_sizes should always be on CPU. "
  137. "Instances of PackedSequence should never be created manually. "
  138. "They should be instantiated by functions like pack_sequence "
  139. "and pack_padded_sequences in nn.utils.rnn. "
  140. "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence")
  141. return data, batch_sizes, sorted_indices, unsorted_indices
  142. # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
  143. else:
  144. assert isinstance(data, (list, tuple)) and len(data) == 2
  145. return data[0], data[1], sorted_indices, unsorted_indices
  146. def _packed_sequence_init(
  147. data: Tensor,
  148. batch_sizes: Optional[Tensor] = None,
  149. sorted_indices: Optional[Tensor] = None,
  150. unsorted_indices: Optional[Tensor] = None,
  151. ) -> PackedSequence:
  152. data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
  153. data, batch_sizes, sorted_indices, unsorted_indices)
  154. return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
  155. def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
  156. if permutation is None:
  157. return None
  158. output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
  159. output.scatter_(0, permutation,
  160. torch.arange(0, permutation.numel(), device=permutation.device))
  161. return output
  162. def pack_padded_sequence(
  163. input: Tensor,
  164. lengths: Tensor,
  165. batch_first: bool = False,
  166. enforce_sorted: bool = True,
  167. ) -> PackedSequence:
  168. r"""Packs a Tensor containing padded sequences of variable length.
  169. :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
  170. longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
  171. ``*`` is any number of dimensions (including 0). If ``batch_first`` is
  172. ``True``, ``B x T x *`` :attr:`input` is expected.
  173. For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
  174. ``True``, the sequences should be sorted by length in a decreasing order, i.e.
  175. ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
  176. one. `enforce_sorted = True` is only necessary for ONNX export.
  177. Note:
  178. This function accepts any input that has at least two dimensions. You
  179. can apply it to pack the labels, and use the output of the RNN with
  180. them to compute the loss directly. A Tensor can be retrieved from
  181. a :class:`PackedSequence` object by accessing its ``.data`` attribute.
  182. Args:
  183. input (Tensor): padded batch of variable length sequences.
  184. lengths (Tensor or list(int)): list of sequence lengths of each batch
  185. element (must be on the CPU if provided as a tensor).
  186. batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
  187. format.
  188. enforce_sorted (bool, optional): if ``True``, the input is expected to
  189. contain sequences sorted by length in a decreasing order. If
  190. ``False``, the input will get sorted unconditionally. Default: ``True``.
  191. Returns:
  192. a :class:`PackedSequence` object
  193. """
  194. if torch._C._get_tracing_state() and not isinstance(lengths, torch.Tensor):
  195. warnings.warn('pack_padded_sequence has been called with a Python list of '
  196. 'sequence lengths. The tracer cannot track the data flow of Python '
  197. 'values, and it will treat them as constants, likely rendering '
  198. 'the trace incorrect for any other combination of lengths.',
  199. stacklevel=2)
  200. lengths = torch.as_tensor(lengths, dtype=torch.int64)
  201. if enforce_sorted:
  202. sorted_indices = None
  203. else:
  204. lengths, sorted_indices = torch.sort(lengths, descending=True)
  205. sorted_indices = sorted_indices.to(input.device)
  206. batch_dim = 0 if batch_first else 1
  207. input = input.index_select(batch_dim, sorted_indices)
  208. data, batch_sizes = \
  209. _VF._pack_padded_sequence(input, lengths, batch_first)
  210. return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
  211. def pad_packed_sequence(
  212. sequence: PackedSequence,
  213. batch_first: bool = False,
  214. padding_value: float = 0.0,
  215. total_length: Optional[int] = None,
  216. ) -> Tuple[Tensor, Tensor]:
  217. r"""Pads a packed batch of variable length sequences.
  218. It is an inverse operation to :func:`pack_padded_sequence`.
  219. The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
  220. of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
  221. the data will be transposed into ``B x T x *`` format.
  222. Example:
  223. >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
  224. >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
  225. >>> lens = [2, 1, 3]
  226. >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
  227. >>> packed
  228. PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
  229. sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
  230. >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
  231. >>> seq_unpacked
  232. tensor([[1, 2, 0],
  233. [3, 0, 0],
  234. [4, 5, 6]])
  235. >>> lens_unpacked
  236. tensor([2, 1, 3])
  237. .. note::
  238. :attr:`total_length` is useful to implement the
  239. ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
  240. :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
  241. See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
  242. details.
  243. Args:
  244. sequence (PackedSequence): batch to pad
  245. batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
  246. format.
  247. padding_value (float, optional): values for padded elements.
  248. total_length (int, optional): if not ``None``, the output will be padded to
  249. have length :attr:`total_length`. This method will throw :class:`ValueError`
  250. if :attr:`total_length` is less than the max sequence length in
  251. :attr:`sequence`.
  252. Returns:
  253. Tuple of Tensor containing the padded sequence, and a Tensor
  254. containing the list of lengths of each sequence in the batch.
  255. Batch elements will be re-ordered as they were ordered originally when
  256. the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
  257. """
  258. max_seq_length = sequence.batch_sizes.size(0)
  259. if total_length is not None:
  260. if total_length < max_seq_length:
  261. raise ValueError("Expected total_length to be at least the length "
  262. "of the longest sequence in input, but got "
  263. "total_length={} and max sequence length being {}"
  264. .format(total_length, max_seq_length))
  265. max_seq_length = total_length
  266. padded_output, lengths = _VF._pad_packed_sequence(
  267. sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
  268. unsorted_indices = sequence.unsorted_indices
  269. if unsorted_indices is not None:
  270. batch_dim = 0 if batch_first else 1
  271. return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices.cpu()]
  272. return padded_output, lengths
  273. def pad_sequence(
  274. sequences: Union[Tensor, List[Tensor]],
  275. batch_first: bool = False,
  276. padding_value: float = 0.0,
  277. ) -> Tensor:
  278. r"""Pad a list of variable length Tensors with ``padding_value``
  279. ``pad_sequence`` stacks a list of Tensors along a new dimension,
  280. and pads them to equal length. For example, if the input is list of
  281. sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
  282. otherwise.
  283. `B` is batch size. It is equal to the number of elements in ``sequences``.
  284. `T` is length of the longest sequence.
  285. `L` is length of the sequence.
  286. `*` is any number of trailing dimensions, including none.
  287. Example:
  288. >>> from torch.nn.utils.rnn import pad_sequence
  289. >>> a = torch.ones(25, 300)
  290. >>> b = torch.ones(22, 300)
  291. >>> c = torch.ones(15, 300)
  292. >>> pad_sequence([a, b, c]).size()
  293. torch.Size([25, 3, 300])
  294. Note:
  295. This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
  296. where `T` is the length of the longest sequence. This function assumes
  297. trailing dimensions and type of all the Tensors in sequences are same.
  298. Args:
  299. sequences (list[Tensor]): list of variable length sequences.
  300. batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
  301. ``T x B x *`` otherwise. Default: False.
  302. padding_value (float, optional): value for padded elements. Default: 0.
  303. Returns:
  304. Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
  305. Tensor of size ``B x T x *`` otherwise
  306. """
  307. if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
  308. # JIT doesn't support `Iterable`
  309. if not isinstance(sequences, Iterable):
  310. msg = ('pad_sequence: Expected iterable for input sequences, but got arg of type: '
  311. f'{type(sequences)}')
  312. raise RuntimeError(msg)
  313. # In JIT context this leads to,
  314. # RuntimeError: cannot statically infer the expected size of a list in this context
  315. sequences = tuple(sequences)
  316. else:
  317. # For JIT, we only support Union[Tensor, Tuple[Tensor]]
  318. if isinstance(sequences, torch.Tensor):
  319. sequences = sequences.unbind(0)
  320. # assuming trailing dimensions and type of all the Tensors
  321. # in sequences are same and fetching those from sequences[0]
  322. return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
  323. def unpad_sequence(
  324. padded_sequences: Tensor,
  325. lengths: Tensor,
  326. batch_first: bool = False,
  327. ) -> List[Tensor]:
  328. r"""Unpad padded Tensor into a list of variable length Tensors
  329. ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
  330. Example:
  331. >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
  332. >>> a = torch.ones(25, 300)
  333. >>> b = torch.ones(22, 300)
  334. >>> c = torch.ones(15, 300)
  335. >>> sequences = [a, b, c]
  336. >>> padded_sequences = pad_sequence(sequences)
  337. >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
  338. >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
  339. >>> torch.allclose(sequences[0], unpadded_sequences[0])
  340. True
  341. >>> torch.allclose(sequences[1], unpadded_sequences[1])
  342. True
  343. >>> torch.allclose(sequences[2], unpadded_sequences[2])
  344. True
  345. Args:
  346. padded_sequences (Tensor): padded sequences.
  347. lengths (Tensor): length of original (unpadded) sequences.
  348. batch_first (bool, optional): whether batch dimension first or not. Default: False.
  349. Returns:
  350. a list of :class:`Tensor` objects
  351. """
  352. unpadded_sequences = []
  353. if not batch_first:
  354. padded_sequences.transpose_(0, 1)
  355. max_length = padded_sequences.shape[1]
  356. idx = torch.arange(max_length)
  357. for seq, length in zip(padded_sequences, lengths):
  358. mask = idx < length
  359. unpacked_seq = seq[mask]
  360. unpadded_sequences.append(unpacked_seq)
  361. return unpadded_sequences
  362. def pack_sequence(sequences: List[Tensor], enforce_sorted: bool = True) -> PackedSequence:
  363. r"""Packs a list of variable length Tensors
  364. Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``.
  365. ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
  366. the length of a sequence and `*` is any number of trailing dimensions,
  367. including zero.
  368. For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
  369. is ``True``, the sequences should be sorted in the order of decreasing length.
  370. ``enforce_sorted = True`` is only necessary for ONNX export.
  371. Example:
  372. >>> from torch.nn.utils.rnn import pack_sequence
  373. >>> a = torch.tensor([1, 2, 3])
  374. >>> b = torch.tensor([4, 5])
  375. >>> c = torch.tensor([6])
  376. >>> pack_sequence([a, b, c])
  377. PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
  378. Args:
  379. sequences (list[Tensor]): A list of sequences of decreasing length.
  380. enforce_sorted (bool, optional): if ``True``, checks that the input
  381. contains sequences sorted by length in a decreasing order. If
  382. ``False``, this condition is not checked. Default: ``True``.
  383. Returns:
  384. a :class:`PackedSequence` object
  385. """
  386. lengths = torch.as_tensor([v.size(0) for v in sequences])
  387. return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
  388. def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
  389. r"""Unpacks PackedSequence into a list of variable length Tensors
  390. ``packed_sequences`` should be a PackedSequence object.
  391. Example:
  392. >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
  393. >>> a = torch.tensor([1, 2, 3])
  394. >>> b = torch.tensor([4, 5])
  395. >>> c = torch.tensor([6])
  396. >>> sequences = [a, b, c]
  397. >>> print(sequences)
  398. [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
  399. >>> packed_sequences = pack_sequence(sequences)
  400. >>> print(packed_sequences)
  401. PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
  402. >>> unpacked_sequences = unpack_sequence(packed_sequences)
  403. >>> print(unpacked_sequences)
  404. [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
  405. Args:
  406. packed_sequences (PackedSequence): A PackedSequence object.
  407. Returns:
  408. a list of :class:`Tensor` objects
  409. """
  410. padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
  411. unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
  412. return unpacked_sequences