# rnn.py
import numbers
import warnings

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Tuple, Optional, List, Union, Dict  # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.ao.nn.quantized.modules.utils import _quantize_weight

__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
           'GRUCell', "apply_permutation"]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead")
    return _apply_permutation(tensor, permutation, dim)

def pack_weight_bias(qweight, bias, dtype):
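    """Pack a weight/bias pair for the dynamic quantized RNN ops.

    For ``dtype == torch.qint8`` the weight is expected to already be a
    quantized tensor and is packed with ``torch.ops.quantized.linear_prepack``;
    any other dtype falls through to the fp16 prepack op.
    """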
    if dtype == torch.qint8:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        # w_ih, w_hh
        packed_weight = \
            torch.ops.quantized.linear_prepack(qweight, bias)

        return packed_weight
    else:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        # packed_ih, packed_hh, b_ih, b_hh
        packed_weight = torch.ops.quantized.linear_prepack_fp16(
            qweight, bias)

        return packed_weight


class PackedParameter(torch.nn.Module):
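    """Thin module wrapper around a packed cell-parameter object.

    Wrapping the packed params in a module lets ``RNNBase`` keep them in an
    ``nn.ModuleList`` and round-trip them through the state-dict hooks below.
    """
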
    def __init__(self, param):
        super().__init__()
        self.param = param

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'param'] = self.param

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.param = state_dict[prefix + 'param']
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)


class RNNBase(torch.nn.Module):
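    """Common base class for the dynamic quantized LSTM and GRU modules below.

    Weights are stored pre-packed, one ``PackedParameter`` per layer and
    direction, in ``self._all_weight_values``.
    """
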
    _FLOAT_MODULE = nn.RNNBase

    _version = 2

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False, dtype=torch.qint8):
        super().__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if not isinstance(dropout, numbers.Number) \
                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore[operator]
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        _all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(w_ih, scale=0.1, zero_point=0, dtype=torch.qint8)
                    w_hh = torch.quantize_per_tensor(w_hh, scale=0.1, zero_point=0, dtype=torch.qint8)
                    packed_ih = \
                        torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = \
                        torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh, True)
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    def _get_name(self):
        return 'DynamicQuantizedRNN'

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __repr__(self):
        # We don't want to show `ModuleList` children, hence custom
        # `__repr__`. This is the same as nn.Module.__repr__, except the check
        # for the `PackedParameter` and `nn.ModuleList`.
        # You should still override `extra_repr` to add more info.
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            if isinstance(module, (PackedParameter, nn.ModuleList)):
                continue
            mod_str = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(
        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
        msg: str = 'Expected hidden size {}, got {}'
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(
                expected_hidden_size, list(hx.size())))

    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(hidden, expected_hidden_size,
                               msg='Expected hidden size {}, got {}')

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        self.version = version
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    def set_weight_bias(self, weight_bias_dict):

        def weight_bias_name(ihhh, layer, suffix):
            weight_name = "weight_{}_l{}{}".format(ihhh, layer, suffix)
            bias_name = "bias_{}_l{}{}".format(ihhh, layer, suffix)
            return weight_name, bias_name

        num_directions = 2 if self.bidirectional else 1
        # TODO: dedup with __init__ of RNNBase
        _all_weight_values = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
                w_ih = weight_bias_dict[w_ih_name]
                b_ih = weight_bias_dict[b_ih_name]
                w_hh = weight_bias_dict[w_hh_name]
                b_hh = weight_bias_dict[b_hh_name]
                if w_ih.dtype == torch.qint8:
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh, True)
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    @classmethod
    def from_float(cls, mod):
        assert type(mod) in {torch.nn.LSTM,
                             torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
        assert hasattr(
            mod,
            'qconfig'
        ), 'Input float module must have qconfig defined'

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
        # RNNBase can be either LSTM or GRU
        qRNNBase: Union[LSTM, GRU]
        if mod.mode == 'LSTM':
            qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
        elif mod.mode == 'GRU':
            qRNNBase = GRU(mod.input_size, mod.hidden_size, mod.num_layers,
                           mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
        else:
            raise NotImplementedError('Only LSTM/GRU is supported for QuantizedRNN for now')

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        _all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''

                def retrieve_weight_bias(ihhh):
                    weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                    bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)
                    return weight, bias

                weight_ih, bias_ih = retrieve_weight_bias('ih')
                weight_hh, bias_hh = retrieve_weight_bias('hh')

                if dtype == torch.qint8:
                    def quantize_and_pack(w, b):
                        weight_observer = weight_observer_method()
                        weight_observer(w)
                        qweight = _quantize_weight(w.float(), weight_observer)
                        packed_weight = \
                            torch.ops.quantized.linear_prepack(qweight, b)
                        return packed_weight

                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
                    if qRNNBase.version is None or qRNNBase.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, bias_ih, bias_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, bias_ih, bias_hh, True)

                elif dtype == torch.float16:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)
                else:
                    raise RuntimeError('Unsupported dtype specified for dynamic quantized LSTM!')

                _all_weight_values.append(PackedParameter(cell_params))
        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)

        return qRNNBase

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
        count = 0
        num_directions = 2 if self.bidirectional else 1
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                key_name1 = 'weight_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'weight_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                # packed weights are part of torchbind class, CellParamsSerializationType
                # Within the packed weight class, the weight and bias are accessible as Tensors
                packed_weight_bias = self._all_weight_values[count].param.__getstate__()[0][4]
                weight_bias_dict['weight'][key_name1] = packed_weight_bias[0].__getstate__()[0][0]
                weight_bias_dict['weight'][key_name2] = packed_weight_bias[1].__getstate__()[0][0]
                key_name1 = 'bias_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'bias_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                weight_bias_dict['bias'][key_name1] = packed_weight_bias[0].__getstate__()[0][1]
                weight_bias_dict['bias'][key_name2] = packed_weight_bias[1].__getstate__()[0][1]
                count = count + 1
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()['weight']

    def get_bias(self):
        return self._weight_bias()['bias']


class LSTM(RNNBase):
    r"""
    A dynamic quantized LSTM module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.LSTM`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
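        >>> # A DynamicQuantizedLSTM is typically produced by dynamically quantizing
        >>> # a float module that contains an nn.LSTM (a sketch; the Sequential
        >>> # wrapper here is only illustrative):
        >>> float_model = nn.Sequential(nn.LSTM(10, 20, 2))
        >>> qmodel = torch.ao.quantization.quantize_dynamic(
        ...     float_model, {nn.LSTM}, dtype=torch.qint8)
        >>> output, (hn, cn) = qmodel(input)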
    """
    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super().__init__('LSTM', *args, **kwargs)

    def _get_name(self):
        return 'DynamicQuantizedLSTM'

    def forward_impl(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor], max_batch_size: int,
        sorted_indices: Optional[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = ([m.param for m in self._all_weight_values])
        if batch_sizes is None:
            result = torch.quantized_lstm(input, hx, _all_params, self.bias, self.num_layers,
                                          float(self.dropout), self.training, self.bidirectional,
                                          self.batch_first, dtype=self.dtype, use_dynamic=True)
        else:
            result = torch.quantized_lstm(input, batch_sizes, hx, _all_params, self.bias,
                                          self.num_layers, float(self.dropout), self.training,
                                          self.bidirectional, dtype=self.dtype, use_dynamic=True)
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output_, batch_sizes,
                                sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # "type: ignore" is required due to issue #43072
    def permute_hidden(  # type: ignore[override]
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)

    # "type: ignore" is required due to issue #43072
    def check_forward_args(  # type: ignore[override]
        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod):
        return super(LSTM, cls).from_float(mod)

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in LSTM, may need to relax "
            "the assumption to support the use case")
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features h_t from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.
          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor containing the initial hidden state
          for each element in the batch, where :math:`S=\text{num\_layers} * \text{num\_directions}`
          and :math:`H_{out}=\text{hidden\_size}`. Defaults to zero if not provided.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
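        >>> # A DynamicQuantizedGRU is likewise obtained by dynamically quantizing a
        >>> # float module holding an nn.GRU (a sketch; the wrapper is illustrative):
        >>> qmodel = torch.ao.quantization.quantize_dynamic(
        ...     nn.Sequential(nn.GRU(10, 20, 2)), {nn.GRU}, dtype=torch.qint8)
        >>> output, hn = qmodel(input)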
    """
    _FLOAT_MODULE = nn.GRU

    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super().__init__('GRU', *args, **kwargs)

    def _get_name(self):
        return 'DynamicQuantizedGRU'

    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size,
                               'Expected hidden size {}, got {}')

    def forward_impl(
        self, input: Tensor, hx: Optional[Tensor],
        batch_sizes: Optional[Tensor], max_batch_size: int,
        sorted_indices: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = ([m.param for m in self._all_weight_values])
        if batch_sizes is None:
            result = torch.quantized_gru(input,
                                         hx,
                                         _all_params,
                                         self.bias,
                                         self.num_layers,
                                         self.dropout,
                                         self.training,
                                         self.bidirectional,
                                         self.batch_first)
        else:
            result = torch.quantized_gru(input,
                                         batch_sizes,
                                         hx,
                                         _all_params,
                                         self.bias,
                                         self.num_layers,
                                         self.dropout,
                                         self.training,
                                         self.bidirectional)
        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)
        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output_, batch_sizes,
                                sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(
        self, hx: Tensor, permutation: Optional[Tensor]
    ) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod):
        return super(GRU, cls).from_float(mod)

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in GRU, may need to relax "
            "the assumption to support the use case")
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


class RNNCellBase(torch.nn.Module):
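    """Common base class for the dynamic quantized RNNCell/LSTMCell/GRUCell modules below.

    The input-hidden and hidden-hidden weights are stored pre-packed in
    ``self._packed_weight_ih`` and ``self._packed_weight_hh``.
    """
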
    # _FLOAT_MODULE = nn.CellRNNBase
    __constants__ = ['input_size', 'hidden_size', 'bias']

    def __init__(self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_dtype = dtype
        if bias:
            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
        if dtype == torch.qint8:
            weight_ih = torch.quantize_per_tensor(weight_ih, scale=1, zero_point=0, dtype=torch.qint8)
            weight_hh = torch.quantize_per_tensor(weight_hh, scale=1, zero_point=0, dtype=torch.qint8)

        if dtype == torch.qint8:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            # w_ih, w_hh
            packed_weight_ih = \
                torch.ops.quantized.linear_prepack(weight_ih, self.bias_ih)
            packed_weight_hh = \
                torch.ops.quantized.linear_prepack(weight_hh, self.bias_hh)
        else:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            # packed_ih, packed_hh, b_ih, b_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
                weight_ih, self.bias_ih)
            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
                weight_hh, self.bias_hh)

        self._packed_weight_ih = packed_weight_ih
        self._packed_weight_hh = packed_weight_hh

    def _get_name(self):
        return 'DynamicQuantizedRNNBase'

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if 'bias' in self.__dict__ and self.bias is not True:
            s += ', bias={bias}'
        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
            s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    @classmethod
    def from_float(cls, mod):
        assert type(mod) in {torch.nn.LSTMCell,
                             torch.nn.GRUCell,
                             torch.nn.RNNCell}, \
            'nn.quantized.dynamic.RNNCellBase.from_float only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

        if type(mod) == torch.nn.LSTMCell:
            qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
        elif type(mod) == torch.nn.GRUCell:
            qRNNCellBase = GRUCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
        elif type(mod) == torch.nn.RNNCell:
            qRNNCellBase = RNNCell(mod.input_size, mod.hidden_size, bias=mod.bias, nonlinearity=mod.nonlinearity, dtype=dtype)
        else:
            raise NotImplementedError('Only LSTMCell, GRUCell and RNNCell '
                                      'are supported for QuantizedRNN for now')

        assert mod.bias

        def _observe_and_quantize_weight(weight):
            if dtype == torch.qint8:
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                return qweight
            else:
                return weight.float()

        qRNNCellBase._packed_weight_ih = pack_weight_bias(_observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype)
        qRNNCellBase._packed_weight_hh = pack_weight_bias(_observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype)
        return qRNNCellBase

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_dtype"), (
            "We are assuming weight_ih exists in the reference module, may need "
            "to relax the assumption to support the use case")
        if hasattr(ref_mod, "nonlinearity"):
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                ref_mod.nonlinearity,
                dtype=ref_mod.weight_ih_dtype
            )
        else:
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                dtype=ref_mod.weight_ih_dtype
            )
        weight_bias_dict = {
            "weight": {
                "weight_ih": ref_mod.get_quantized_weight_ih(),
                "weight_hh": ref_mod.get_quantized_weight_hh(),
            },
            "bias": {
                "bias_ih": ref_mod.bias_ih,
                "bias_hh": ref_mod.bias_hh,
            }
        }
        qmod.set_weight_bias(weight_bias_dict)
        return qmod

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
        w1, b1 = self._packed_weight_ih.__getstate__()[0]
        w2, b2 = self._packed_weight_hh.__getstate__()[0]
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        weight_bias_dict['weight']['weight_ih'] = w1
        weight_bias_dict['weight']['weight_hh'] = w2
        weight_bias_dict['bias']['bias_ih'] = b1
        weight_bias_dict['bias']['bias_hh'] = b2
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()['weight']

    def get_bias(self):
        return self._weight_bias()['bias']

    def set_weight_bias(self, weight_bias_dict):
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        self._packed_weight_ih = pack_weight_bias(
            weight_bias_dict["weight"]["weight_ih"],
            weight_bias_dict["bias"]["bias_ih"],
            self.weight_dtype)
        self._packed_weight_hh = pack_weight_bias(
            weight_bias_dict["weight"]["weight_hh"],
            weight_bias_dict["bias"]["bias_hh"],
            self.weight_dtype)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + '_packed_weight_ih'] = self._packed_weight_ih
        destination[prefix + '_packed_weight_hh'] = self._packed_weight_hh

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self._packed_weight_ih = state_dict.pop(prefix + '_packed_weight_ih')
        self._packed_weight_hh = state_dict.pop(prefix + '_packed_weight_hh')
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)


class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    A dynamic quantized RNNCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """
    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8):
        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return 'DynamicQuantizedRNNCell'

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        if self.nonlinearity == "tanh":
            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
                input, hx,
                self._packed_weight_ih, self._packed_weight_hh,
                self.bias_ih, self.bias_hh)
        elif self.nonlinearity == "relu":
            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
                input, hx,
                self._packed_weight_ih, self._packed_weight_hh,
                self.bias_ih, self.bias_hh)
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        return ret

    @classmethod
    def from_float(cls, mod):
        return super(RNNCell, cls).from_float(mod)


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]

    def _get_name(self):
        return 'DynamicQuantizedLSTMCell'

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return torch.ops.quantized.quantized_lstm_cell_dynamic(
            input, hx,
            self._packed_weight_ih, self._packed_weight_hh,
            self.bias_ih, self.bias_hh)

    @classmethod
    def from_float(cls, mod):
        return super(LSTMCell, cls).from_float(mod)


class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    A dynamic quantized GRUCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)

    def _get_name(self):
        return 'DynamicQuantizedGRUCell'

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        return torch.ops.quantized.quantized_gru_cell_dynamic(
            input, hx,
            self._packed_weight_ih, self._packed_weight_hh,
            self.bias_ih, self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod):
        return super(GRUCell, cls).from_float(mod)