123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616 |
- import torch
- import torch.nn as nn
- from torch import Tensor
- from .utils import _quantize_and_dequantize_weight
- from .utils import _quantize_weight
- from typing import Optional, Dict, Any, Tuple
- from torch import _VF
- from torch.nn.utils.rnn import PackedSequence
- __all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Return ``tensor`` with its entries along ``dim`` reordered by ``permutation``."""
    permuted = tensor.index_select(dim, permutation)
    return permuted
def _get_weight_and_quantization_params(module, wn):
    """Collect ``[weight, qscheme, dtype, scale, zero_point, axis_int]`` for the
    weight named ``wn`` on ``module``.

    Quantization-parameter attributes that are absent on ``module`` are
    reported as ``None``.
    """
    suffixes = ("_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int")
    return [getattr(module, wn)] + [getattr(module, wn + s, None) for s in suffixes]
def get_quantized_weight(module, wn):
    """Return the quantized form of weight ``wn`` stored on ``module``.

    Returns ``None`` when ``module`` has no attribute named ``wn``.
    """
    if not hasattr(module, wn):
        return None
    return _quantize_weight(*_get_weight_and_quantization_params(module, wn))
def _get_quantize_and_dequantized_weight(module, wn):
    """Quantize then immediately dequantize weight ``wn`` on ``module``,
    producing a float tensor that carries the quantization error.

    Returns ``None`` when ``module`` has no attribute named ``wn``.
    """
    if hasattr(module, wn):
        return _quantize_and_dequantize_weight(
            *_get_weight_and_quantization_params(module, wn))
    return None
class RNNCellBase(nn.RNNCellBase):
    """Base class for reference quantized RNN/LSTM/GRU cells.

    Keeps the floating point weights inherited from ``nn.RNNCellBase``
    together with per-weight quantization parameters, stored as
    attributes/buffers named ``<weight_name>_qscheme``, ``<weight_name>_scale``,
    etc., so the weights can be quantized (or quantize-dequantized) on demand.
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
                 device=None, dtype=None, weight_qparams_dict=None) -> None:
        super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype)
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            # default: per-tensor affine quint8 quantization for both weights
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0
            }
            weight_qparams_dict = {
                "weight_ih": weight_qparams,
                "weight_hh": weight_qparams,
                "is_decomposed": False,
            }
        # exactly weight_ih, weight_hh and the is_decomposed flag are expected
        assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        # Register qscheme/dtype as plain attributes and scale/zero_point/axis
        # as buffers for every weight entry in weight_qparams_dict.
        assert weight_qparams_dict is not None
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            # TODO: refactor the duplicated code to utils.py
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
            if weight_qscheme is not None:
                # scale/zero_point/axis may be python numbers or Tensors;
                # Tensors are cloned+detached rather than re-wrapped
                scale = weight_qparams["scale"]
                scale_tensor = scale.clone().detach() \
                    if isinstance(scale, torch.Tensor) else \
                    torch.tensor(scale, dtype=torch.float, device=device)
                self.register_buffer(key + "_scale", scale_tensor)
                zp = weight_qparams["zero_point"]
                zp_tensor = zp.clone().detach() \
                    if isinstance(zp, torch.Tensor) else \
                    torch.tensor(zp, dtype=torch.int, device=device)
                self.register_buffer(key + "_zero_point", zp_tensor)
                if weight_qscheme == torch.per_channel_affine:
                    axis = weight_qparams["axis"]
                    axis_tensor = axis.clone().detach() \
                        if isinstance(axis, torch.Tensor) else \
                        torch.tensor(axis, dtype=torch.int, device=device)
                    self.register_buffer(key + "_axis", axis_tensor)
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())

    def _get_name(self):
        return "QuantizedRNNCellBase(Reference)"

    def get_quantized_weight_ih(self):
        # quantized (not dequantized) input-hidden weight
        return get_quantized_weight(self, "weight_ih")

    def get_quantized_weight_hh(self):
        # quantized (not dequantized) hidden-hidden weight
        return get_quantized_weight(self, "weight_hh")

    def get_weight_ih(self):
        # quantize-dequantized (float, carrying quantization error) weight_ih
        return _get_quantize_and_dequantized_weight(self, "weight_ih")

    def get_weight_hh(self):
        # quantize-dequantized (float, carrying quantization error) weight_hh
        return _get_quantize_and_dequantized_weight(self, "weight_hh")
class RNNCell(RNNCellBase):
    """
    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
    we need to pass in a `weight_qparams_dict` that maps from weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "QuantizedRNNCell(Reference)"

    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
    # and remove duplicated code, same for the other two Cell modules
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        # Accepts unbatched (1-D) or batched (2-D) input; an unbatched input
        # is temporarily given a batch dim so _VF can process it uniformly.
        assert input.dim() in (1, 2), \
            f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            # default hidden state: zeros matching the (batched) input
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx
        if self.nonlinearity == "tanh":
            # weights go through a quantize-dequantize round trip
            # (get_weight_ih/get_weight_hh) before the fused cell kernel
            ret = _VF.rnn_tanh_cell(
                input, hx,
                self.get_weight_ih(), self.get_weight_hh(),
                self.bias_ih, self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input, hx,
                self.get_weight_ih(), self.get_weight_hh(),
                self.bias_ih, self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        if not is_batched:
            # strip the temporary batch dim added above
            ret = ret.squeeze(0)
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        """Build a reference quantized RNNCell from a float ``nn.RNNCell``,
        sharing (not copying) the float module's parameters."""
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.nonlinearity,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod
class LSTMCell(RNNCellBase):
    """
    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
    we need to pass in a `weight_qparams_dict` that maps from weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        # num_chunks=4: LSTM weights hold the 4 gates stacked along dim 0
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def _get_name(self):
        return "QuantizedLSTMCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        # Accepts unbatched (1-D) or batched (2-D) input; an unbatched input
        # (and its (h, c) state) is temporarily given a batch dim.
        assert input.dim() in (1, 2), \
            f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            # default (h, c): zeros matching the (batched) input
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
        # weights go through a quantize-dequantize round trip before the kernel
        ret = _VF.lstm_cell(
            input, hx,
            self.get_weight_ih(), self.get_weight_hh(),
            self.bias_ih, self.bias_hh,
        )
        if not is_batched:
            # strip the temporary batch dim from both h and c
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        """Build a reference quantized LSTMCell from a float ``nn.LSTMCell``,
        sharing (not copying) the float module's parameters."""
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod
class GRUCell(RNNCellBase):
    """
    We'll store weight_qparams for all the weights (weight_ih and weight_hh),
    we need to pass in a `weight_qparams_dict` that maps from weight name,
    e.g. weight_ih, to the weight_qparams for that weight
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
        # num_chunks=3: GRU weights hold the 3 gates stacked along dim 0
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def _get_name(self):
        return "QuantizedGRUCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        # Accepts unbatched (1-D) or batched (2-D) input; an unbatched input
        # is temporarily given a batch dim so _VF can process it uniformly.
        assert input.dim() in (1, 2), \
            f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            # default hidden state: zeros matching the (batched) input
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx
        # weights go through a quantize-dequantize round trip before the kernel
        ret = _VF.gru_cell(
            input, hx,
            self.get_weight_ih(), self.get_weight_hh(),
            self.bias_ih, self.bias_hh,
        )
        if not is_batched:
            # strip the temporary batch dim added above
            ret = ret.squeeze(0)
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        """Build a reference quantized GRUCell from a float ``nn.GRUCell``,
        sharing (not copying) the float module's parameters."""
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict)
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod
class RNNBase(nn.RNNBase):
    """Reference quantized base class for multi-layer LSTM/GRU modules.

    Stores the floating point flat weights inherited from ``nn.RNNBase``
    together with per-weight quantization parameters, registered as
    attributes/buffers named ``<flat_weight_name>_qscheme``,
    ``<flat_weight_name>_scale``, etc. (mirroring ``RNNCellBase``).
    """
    def __init__(self, mode: str, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
                 dropout: float = 0., bidirectional: bool = False, proj_size: int = 0,
                 device=None, dtype=None,
                 weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(
            mode, input_size, hidden_size, num_layers, bias, batch_first, dropout,
            bidirectional, proj_size, device, dtype
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            # default: per-tensor affine quint8 for every flat weight
            weight_qparams = {
                'qscheme': torch.per_tensor_affine,
                'dtype': torch.quint8,
                'scale': 1.0,
                'zero_point': 0
            }
            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
            for wn in self._flat_weights_names:
                if wn.startswith("weight"):
                    weight_qparams_dict[wn] = weight_qparams
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        """Register the quantization parameters for every weight entry in
        ``weight_qparams_dict`` as attributes/buffers on ``self``.

        Accepts python numbers or Tensors for scale/zero_point/axis. Tensor
        values are cloned+detached (consistent with
        ``RNNCellBase._init_weight_qparams_dict``) instead of being re-wrapped
        with ``torch.tensor(...)``, which copies and warns when handed a
        Tensor.
        """
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \
                Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}")
            if weight_qscheme is not None:
                scale = weight_qparams["scale"]
                scale_tensor = scale.clone().detach() \
                    if isinstance(scale, torch.Tensor) else \
                    torch.tensor(scale, dtype=torch.float, device=device)
                self.register_buffer(key + "_scale", scale_tensor)
                zp = weight_qparams["zero_point"]
                zp_tensor = zp.clone().detach() \
                    if isinstance(zp, torch.Tensor) else \
                    torch.tensor(zp, dtype=torch.int, device=device)
                self.register_buffer(key + "_zero_point", zp_tensor)
                if weight_qscheme == torch.per_channel_affine:
                    axis = weight_qparams["axis"]
                    axis_tensor = axis.clone().detach() \
                        if isinstance(axis, torch.Tensor) else \
                        torch.tensor(axis, dtype=torch.int, device=device)
                    self.register_buffer(key + "_axis", axis_tensor)
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device))
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
class LSTM(RNNBase):
    """ Reference Quantized LSTM Module
    We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
    a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight
    """
    def __init__(self, *args, **kwargs):
        super().__init__('LSTM', *args, **kwargs)

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(self,  # type: ignore[override]
                       hx: Tuple[Tensor, Tensor],
                       permutation: Optional[Tensor]
                       ) -> Tuple[Tensor, Tensor]:
        # Reorder both h and c along the batch dim for PackedSequence
        # sorted/unsorted index handling; no-op when permutation is None.
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)

    def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        # Expected shape of the cell state c:
        # (num_layers * num_directions, mini_batch, hidden_size).
        # Unlike the hidden state h, c is never proj_size-sized.
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(self,  # type: ignore[override]
                           input: Tensor,
                           hidden: Tuple[Tensor, Tensor],
                           batch_sizes: Optional[Tensor],
                           ):
        # Validate input plus both halves of the (h, c) state.
        self.check_input(input, batch_sizes)
        self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
                               'Expected hidden[1] size {}, got {}')

    def get_quantized_weight_bias_dict(self):
        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
        e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                # e.g. bias entries when bias=False
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        # Like nn.RNNBase._flat_weights, but each weight is passed through a
        # quantize-dequantize round trip first; biases are left untouched.
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Structurally a copy of nn.LSTM.forward with self._flat_weights
        # replaced by self.get_flat_weights(); keep the control flow intact
        # for TorchScript compatibility.
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # promote unbatched input to batch size 1
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            # h uses proj_size when projections are enabled; c never does
            real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
            h_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, real_hidden_size,
                                  dtype=input.dtype, device=input.device)
            c_zeros = torch.zeros(self.num_layers * num_directions,
                                  max_batch_size, self.hidden_size,
                                  dtype=input.dtype, device=input.device)
            hx = (h_zeros, c_zeros)
        else:
            if batch_sizes is None:  # If not PackedSequence input.
                if is_batched:
                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
                        msg = ("For batched 3-D input, hx and cx should "
                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = ("For unbatched 2-D input, hx and cx should "
                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)
        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                # strip the temporary batch dim added above
                output = output.squeeze(batch_dim)
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedLSTM(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        """Build a reference quantized LSTM from a float ``nn.LSTM``, sharing
        (not copying) its flat weights and biases."""
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict)
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod
class GRU(RNNBase):
    """ Reference Quantized GRU Module
    We'll store weight_qparams for all the weights in _flat_weights, we need to pass in
    a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight
    """
    def __init__(self, *args, **kwargs):
        # proj_size is an LSTM-only feature; reject it up front
        if 'proj_size' in kwargs:
            raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
        super().__init__('GRU', *args, **kwargs)

    def get_quantized_weight_bias_dict(self):
        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
        e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                # e.g. bias entries when bias=False
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        # Like nn.RNNBase._flat_weights, but each weight is passed through a
        # quantize-dequantize round trip first; biases are left untouched.
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
        # only changed self._flat_weights to self.get_flat_weights()
        # TODO: maybe we can try inheriting from that class and define get_flat_weights
        # as a @property? this might interfere with TorchScript, if we remove that
        # requirement in the future we should be able to do this
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                # promote unbatched input (and hx, if given) to batch size 1
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)
        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                # strip the temporary batch dim added above
                output = output.squeeze(batch_dim)
                hidden = hidden.squeeze(1)
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedGRU(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        """Build a reference quantized GRU from a float ``nn.GRU``, sharing
        (not copying) its flat weights and biases."""
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict)
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod
|