rnn.py

import torch
import torch.nn as nn
from torch import Tensor
from .utils import _quantize_and_dequantize_weight
from .utils import _quantize_weight
from typing import Optional, Dict, Any, Tuple
from torch import _VF
from torch.nn.utils.rnn import PackedSequence

__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)

def _get_weight_and_quantization_params(module, wn):
    weight = getattr(module, wn)
    params = [weight]
    for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]:
        if hasattr(module, param_name):
            param = getattr(module, param_name)
        else:
            param = None
        params.append(param)
    return params

def get_quantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_weight(*params)
    return weight


def _get_quantize_and_dequantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_and_dequantize_weight(*params)
    return weight

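# Usage sketch (not part of the original module): for a module that carries a
# float `weight_ih` plus the `weight_ih_qscheme`/`_scale`/`_zero_point` attributes
# registered by `_init_weight_qparams_dict` below, the two helpers give the
# quantized and the fake-quantized (quantize-then-dequantize) views of it:
#
#     cell = RNNCell(4, 8)  # defaults to per-tensor affine qparams
#     q_w = get_quantized_weight(cell, "weight_ih")                   # quint8
#     fq_w = _get_quantize_and_dequantized_weight(cell, "weight_ih")  # float
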
class RNNCellBase(nn.RNNCellBase):
    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
                 device=None, dtype=None, weight_qparams_dict=None) -> None:
        super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype)
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0
            }
            weight_qparams_dict = {
                "weight_ih": weight_qparams,
                "weight_hh": weight_qparams,
                "is_decomposed": False,
            }
        assert len(weight_qparams_dict) == 3, \
            "Expected weight_qparams_dict to have 3 entries for QuantizedRNNCellBase(Reference)"
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        assert weight_qparams_dict is not None
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            # TODO: refactor the duplicated code to utils.py
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            if weight_qscheme is not None:
                scale = weight_qparams["scale"]
                scale_tensor = scale.clone().detach() \
                    if isinstance(scale, torch.Tensor) else \
                    torch.tensor(scale, dtype=torch.float, device=device)
                self.register_buffer(key + "_scale", scale_tensor)
                zp = weight_qparams["zero_point"]
                zp_tensor = zp.clone().detach() \
                    if isinstance(zp, torch.Tensor) else \
                    torch.tensor(zp, dtype=torch.int, device=device)
                self.register_buffer(key + "_zero_point", zp_tensor)
                if weight_qscheme == torch.per_channel_affine:
                    axis = weight_qparams["axis"]
                    axis_tensor = axis.clone().detach() \
                        if isinstance(axis, torch.Tensor) else \
                        torch.tensor(axis, dtype=torch.int, device=device)
                    self.register_buffer(key + "_axis", axis_tensor)
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())

    def _get_name(self):
        return "QuantizedRNNCellBase(Reference)"

    def get_quantized_weight_ih(self):
        return get_quantized_weight(self, "weight_ih")

    def get_quantized_weight_hh(self):
        return get_quantized_weight(self, "weight_hh")

    def get_weight_ih(self):
        return _get_quantize_and_dequantized_weight(self, "weight_ih")

    def get_weight_hh(self):
        return _get_quantize_and_dequantized_weight(self, "weight_hh")

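# The expected `weight_qparams_dict` layout for the cell modules, sketched with
# illustrative (uncalibrated) values; one "is_decomposed" flag plus one entry
# per weight:
#
#     weight_qparams_dict = {
#         "is_decomposed": False,
#         "weight_ih": {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
#                       "scale": 0.02, "zero_point": 128},
#         "weight_hh": {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
#                       "scale": 0.01, "zero_point": 128},
#     }
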
class RNNCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    pass in a `weight_qparams_dict` that maps from a weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "QuantizedRNNCell(Reference)"

    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
    # and remove duplicated code, same for the other two Cell modules
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (1, 2), \
            f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx
        if self.nonlinearity == "tanh":
            ret = _VF.rnn_tanh_cell(
                input, hx,
                self.get_weight_ih(), self.get_weight_hh(),
                self.bias_ih, self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input, hx,
                self.get_weight_ih(), self.get_weight_hh(),
                self.bias_ih, self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                f"Unknown nonlinearity: {self.nonlinearity}")
        if not is_batched:
            ret = ret.squeeze(0)
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.nonlinearity,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

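# Conversion sketch (illustrative): wrap a float nn.RNNCell; passing `None`
# falls back to the default per-tensor affine qparams in RNNCellBase.__init__:
#
#     float_cell = nn.RNNCell(4, 8, nonlinearity="tanh")
#     ref_cell = RNNCell.from_float(float_cell, weight_qparams_dict=None)
#     out = ref_cell(torch.randn(2, 4))  # weights are fake-quantized in forward
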
class LSTMCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    pass in a `weight_qparams_dict` that maps from a weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def _get_name(self):
        return "QuantizedLSTMCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        assert input.dim() in (1, 2), \
            f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
        ret = _VF.lstm_cell(
            input, hx,
            self.get_weight_ih(), self.get_weight_hh(),
            self.bias_ih, self.bias_hh,
        )
        if not is_batched:
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

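# Sketch: the reference LSTMCell keeps the nn.LSTMCell interface and returns
# an (h, c) tuple (shapes shown for illustration):
#
#     cell = LSTMCell(4, 8)
#     h, c = cell(torch.randn(2, 4))  # h.shape == c.shape == (2, 8)
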
class GRUCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    pass in a `weight_qparams_dict` that maps from a weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def _get_name(self):
        return "QuantizedGRUCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (1, 2), \
            f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx
        ret = _VF.gru_cell(
            input, hx,
            self.get_weight_ih(), self.get_weight_hh(),
            self.bias_ih, self.bias_hh,
        )
        if not is_batched:
            ret = ret.squeeze(0)
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

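# Note (sketch): `num_chunks` sets how many gates are packed into each flat
# weight, so a GRUCell's weight_ih has shape (3 * hidden_size, input_size):
#
#     cell = GRUCell(4, 8)
#     assert cell.weight_ih.shape == (3 * 8, 4)
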
class RNNBase(nn.RNNBase):
    def __init__(self, mode: str, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
                 dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
                 device=None, dtype=None,
                 weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(
            mode, input_size, hidden_size, num_layers, bias, batch_first, dropout,
            bidirectional, proj_size, device, dtype
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                'qscheme': torch.per_tensor_affine,
                'dtype': torch.quint8,
                'scale': 1.0,
                'zero_point': 0
            }
            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
            for wn in self._flat_weights_names:
                if wn.startswith("weight"):
                    weight_qparams_dict[wn] = weight_qparams
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            if weight_qscheme is not None:
                self.register_buffer(
                    key + "_scale",
                    torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
                self.register_buffer(
                    key + "_zero_point",
                    torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
                if weight_qscheme == torch.per_channel_affine:
                    self.register_buffer(
                        key + "_axis",
                        torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device))
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())

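# Sketch of a per-channel entry for one flat weight (values illustrative; the
# 4 * hidden_size leading dimension matches LSTM's gate packing):
#
#     weight_qparams_dict = {
#         "is_decomposed": False,
#         "weight_ih_l0": {
#             "qscheme": torch.per_channel_affine,
#             "dtype": torch.quint8,
#             "scale": torch.ones(4 * hidden_size),                        # one per row
#             "zero_point": torch.zeros(4 * hidden_size, dtype=torch.int),
#             "axis": 0,
#         },
#         # ... one entry for each remaining flat weight
#     }
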
class LSTM(RNNBase):
    """ Reference Quantized LSTM Module
    We store weight_qparams for all the weights in _flat_weights; pass in
    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight
    """
    def __init__(self, *args, **kwargs):
        super().__init__('LSTM', *args, **kwargs)

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(self,  # type: ignore[override]
                       hx: Tuple[Tensor, Tensor],
                       permutation: Optional[Tensor]
                       ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)

    def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(self,  # type: ignore[override]
                           input: Tensor,
                           hidden: Tuple[Tensor, Tensor],
                           batch_sizes: Optional[Tensor],
                           ):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
                               'Expected hidden[1] size {}, got {}')

    def get_quantized_weight_bias_dict(self):
        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
        e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
        else:
            if batch_sizes is None:  # If not PackedSequence input.
                if is_batched:
                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
                        msg = ("For batched 3-D input, hx and cx should "
                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = ("For unbatched 2-D input, hx and cx should "
                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)
        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                output = output.squeeze(batch_dim)
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedLSTM(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict)
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod

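# Usage sketch (illustrative): swap a float LSTM for the reference module; the
# output matches nn.LSTM up to the fake-quantization error on the weights:
#
#     float_lstm = nn.LSTM(4, 8, num_layers=1, batch_first=True)
#     ref_lstm = LSTM.from_float(float_lstm, weight_qparams_dict=None)
#     out, (h, c) = ref_lstm(torch.randn(2, 5, 4))  # out: (2, 5, 8)
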
class GRU(RNNBase):
    """ Reference Quantized GRU Module
    We store weight_qparams for all the weights in _flat_weights; pass in
    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight
    """
    def __init__(self, *args, **kwargs):
        if 'proj_size' in kwargs:
            raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
        super().__init__('GRU', *args, **kwargs)

    def get_quantized_weight_bias_dict(self):
        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
        e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
        # only changed self._flat_weights to self.get_flat_weights()
        # TODO: maybe we can try inheriting from that class and define get_flat_weights
        # as a @property? this might interfere with TorchScript, if we remove that
        # requirement in the future we should be able to do this
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            assert (input.dim() in (2, 3)), \
                f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)
        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                output = output.squeeze(batch_dim)
                hidden = hidden.squeeze(1)
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedGRU(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict)
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod

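# Sketch: `get_quantized_weight_bias_dict` on the reference LSTM/GRU collects
# quantized flat weights and float biases, e.g. for a lowering pass:
#
#     ref_gru = GRU(4, 8, batch_first=True)
#     wb = ref_gru.get_quantized_weight_bias_dict()
#     # wb["weight_ih_l0"] -> quint8 tensor, wb["bias_ih_l0"] -> float tensor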