import torch import torch.nn as nn from torch import Tensor from .utils import _quantize_and_dequantize_weight from .utils import _quantize_weight from typing import Optional, Dict, Any, Tuple from torch import _VF from torch.nn.utils.rnn import PackedSequence __all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight'] def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: return tensor.index_select(dim, permutation) def _get_weight_and_quantization_params(module, wn): weight = getattr(module, wn) params = [weight] for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]: if hasattr(module, param_name): param = getattr(module, param_name) else: param = None params.append(param) return params def get_quantized_weight(module, wn): if not hasattr(module, wn): return None params = _get_weight_and_quantization_params(module, wn) weight = _quantize_weight(*params) return weight def _get_quantize_and_dequantized_weight(module, wn): if not hasattr(module, wn): return None params = _get_weight_and_quantization_params(module, wn) weight = _quantize_and_dequantize_weight(*params) return weight class RNNCellBase(nn.RNNCellBase): def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, device=None, dtype=None, weight_qparams_dict=None) -> None: super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype) # TODO(jerryzh168): maybe make this arg a required arg if weight_qparams_dict is None: weight_qparams = { "qscheme": torch.per_tensor_affine, "dtype": torch.quint8, "scale": 1.0, "zero_point": 0 } weight_qparams_dict = { "weight_ih": weight_qparams, "weight_hh": weight_qparams, "is_decomposed": False, } assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" self._init_weight_qparams_dict(weight_qparams_dict, device) def _init_weight_qparams_dict(self, weight_qparams_dict, device): assert weight_qparams_dict is not None self.is_decomposed = weight_qparams_dict["is_decomposed"] for key, weight_qparams in weight_qparams_dict.items(): if key == "is_decomposed": continue # TODO: refactor the duplicated code to utils.py weight_qscheme = weight_qparams["qscheme"] weight_dtype = weight_qparams["dtype"] setattr(self, key + "_qscheme", weight_qscheme) setattr(self, key + "_dtype", weight_dtype) assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}") if weight_qscheme is not None: scale = weight_qparams["scale"] scale_tensor = scale.clone().detach() \ if isinstance(scale, torch.Tensor) else \ torch.tensor(scale, dtype=torch.float, device=device) self.register_buffer(key + "_scale", scale_tensor) zp = weight_qparams["zero_point"] zp_tensor = zp.clone().detach() \ if isinstance(zp, torch.Tensor) else \ torch.tensor(zp, dtype=torch.int, device=device) self.register_buffer(key + "_zero_point", zp_tensor) if weight_qscheme == torch.per_channel_affine: axis = weight_qparams["axis"] axis_tensor = axis.clone().detach() \ if isinstance(axis, torch.Tensor) else \ torch.tensor(axis, dtype=torch.int, device=device) self.register_buffer(key + "_axis", axis_tensor) else: # added for TorchScriptability, not used self.register_buffer( key + "_axis", torch.tensor(0, dtype=torch.int, device=device)) setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) def _get_name(self): return "QuantizedRNNCellBase(Reference)" def get_quantized_weight_ih(self): return get_quantized_weight(self, "weight_ih") def get_quantized_weight_hh(self): return get_quantized_weight(self, "weight_hh") def get_weight_ih(self): return _get_quantize_and_dequantized_weight(self, "weight_ih") def get_weight_hh(self): return _get_quantize_and_dequantized_weight(self, "weight_hh") class RNNCell(RNNCellBase): """ We'll store weight_qparams for all the weights (weight_ih and weight_hh), we need to pass in a `weight_qparams_dict` that maps from weight name, e.g. weight_ih, to the weight_qparams for that weight """ def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) self.nonlinearity = nonlinearity def _get_name(self): return "QuantizedRNNCell(Reference)" # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input # and remove duplicated code, same for the other two Cell modules def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in (1, 2), \ f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) else: hx = hx.unsqueeze(0) if not is_batched else hx if self.nonlinearity == "tanh": ret = _VF.rnn_tanh_cell( input, hx, self.get_weight_ih(), self.get_weight_hh(), self.bias_ih, self.bias_hh, ) elif self.nonlinearity == "relu": ret = _VF.rnn_relu_cell( input, hx, self.get_weight_ih(), self.get_weight_hh(), self.bias_ih, self.bias_hh, ) else: ret = input # TODO: remove when jit supports exception flow raise RuntimeError( "Unknown nonlinearity: {}".format(self.nonlinearity)) if not is_batched: ret = ret.squeeze(0) return ret @classmethod def from_float(cls, mod, weight_qparams_dict): ref_mod = cls( mod.input_size, mod.hidden_size, mod.bias, mod.nonlinearity, mod.weight_ih.device, mod.weight_ih.dtype, weight_qparams_dict) ref_mod.weight_ih = mod.weight_ih ref_mod.weight_hh = mod.weight_hh ref_mod.bias_ih = mod.bias_ih ref_mod.bias_hh = mod.bias_hh return ref_mod class LSTMCell(RNNCellBase): """ We'll store weight_qparams for all the weights (weight_ih and weight_hh), we need to pass in a `weight_qparams_dict` that maps from weight name, e.g. weight_ih, to the weight_qparams for that weight """ def __init__(self, input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) def _get_name(self): return "QuantizedLSTMCell(Reference)" def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: assert input.dim() in (1, 2), \ f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) hx = (zeros, zeros) else: hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx ret = _VF.lstm_cell( input, hx, self.get_weight_ih(), self.get_weight_hh(), self.bias_ih, self.bias_hh, ) if not is_batched: ret = (ret[0].squeeze(0), ret[1].squeeze(0)) return ret @classmethod def from_float(cls, mod, weight_qparams_dict): ref_mod = cls( mod.input_size, mod.hidden_size, mod.bias, mod.weight_ih.device, mod.weight_ih.dtype, weight_qparams_dict) ref_mod.weight_ih = mod.weight_ih ref_mod.weight_hh = mod.weight_hh ref_mod.bias_ih = mod.bias_ih ref_mod.bias_hh = mod.bias_hh return ref_mod class GRUCell(RNNCellBase): """ We'll store weight_qparams for all the weights (weight_ih and weight_hh), we need to pass in a `weight_qparams_dict` that maps from weight name, e.g. weight_ih, to the weight_qparams for that weight """ def __init__(self, input_size: int, hidden_size: int, bias: bool = True, device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) def _get_name(self): return "QuantizedGRUCell(Reference)" def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: assert input.dim() in (1, 2), \ f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) else: hx = hx.unsqueeze(0) if not is_batched else hx ret = _VF.gru_cell( input, hx, self.get_weight_ih(), self.get_weight_hh(), self.bias_ih, self.bias_hh, ) if not is_batched: ret = ret.squeeze(0) return ret @classmethod def from_float(cls, mod, weight_qparams_dict): ref_mod = cls( mod.input_size, mod.hidden_size, mod.bias, mod.weight_ih.device, mod.weight_ih.dtype, weight_qparams_dict) ref_mod.weight_ih = mod.weight_ih ref_mod.weight_hh = mod.weight_hh ref_mod.bias_ih = mod.bias_ih ref_mod.bias_hh = mod.bias_hh return ref_mod class RNNBase(nn.RNNBase): def __init__(self, mode: str, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, batch_first: bool = False, dropout: float = 0., bidirectional: bool = False, proj_size: int = 0, device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: super().__init__( mode, input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, proj_size, device, dtype ) # TODO(jerryzh168): maybe make this arg a required arg if weight_qparams_dict is None: weight_qparams = { 'qscheme': torch.per_tensor_affine, 'dtype': torch.quint8, 'scale': 1.0, 'zero_point': 0 } weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item] for wn in self._flat_weights_names: if wn.startswith("weight"): weight_qparams_dict[wn] = weight_qparams self._init_weight_qparams_dict(weight_qparams_dict, device) def _init_weight_qparams_dict(self, weight_qparams_dict, device): self.is_decomposed = weight_qparams_dict["is_decomposed"] for key, weight_qparams in weight_qparams_dict.items(): if key == "is_decomposed": continue weight_qscheme = weight_qparams["qscheme"] weight_dtype = weight_qparams["dtype"] setattr(self, key + "_qscheme", weight_qscheme) setattr(self, key + "_dtype", weight_dtype) assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}") if weight_qscheme is not None: self.register_buffer( key + "_scale", torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) self.register_buffer( key + "_zero_point", torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device)) if weight_qscheme == torch.per_channel_affine: self.register_buffer( key + "_axis", torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) else: # added for TorchScriptability, not used self.register_buffer( key + "_axis", torch.tensor(0, dtype=torch.int, device=device)) setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) class LSTM(RNNBase): """ Reference Quantized LSTM Module We'll store weight_qparams for all the weights in _flat_weights, we need to pass in a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, to the weight_qparams for that weight """ def __init__(self, *args, **kwargs): super().__init__('LSTM', *args, **kwargs) # Same as above, see torch/nn/modules/module.py::_forward_unimplemented def permute_hidden(self, # type: ignore[override] hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor] ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 expected_hidden_size = (self.num_layers * num_directions, mini_batch, self.hidden_size) return expected_hidden_size # In the future, we should prevent mypy from applying contravariance rules here. # See torch/nn/modules/module.py::_forward_unimplemented def check_forward_args(self, # type: ignore[override] input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor], ): self.check_input(input, batch_sizes) self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes), 'Expected hidden[0] size {}, got {}') self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes), 'Expected hidden[1] size {}, got {}') def get_quantized_weight_bias_dict(self): """ dictionary from flat_weight_name to quantized weight or (unquantized) bias e.g. { "weight_ih_l0": quantized_weight, "bias_ih_l0": unquantized_bias, ... } """ quantized_weight_bias_dict = {} for wn in self._flat_weights_names: if hasattr(self, wn): if wn.startswith("weight"): weight_or_bias = get_quantized_weight(self, wn) else: weight_or_bias = getattr(self, wn) else: weight_or_bias = None quantized_weight_bias_dict[wn] = weight_or_bias return quantized_weight_bias_dict def get_flat_weights(self): flat_weights = [] for wn in self._flat_weights_names: if hasattr(self, wn): weight = getattr(self, wn) if wn.startswith("weight"): params = _get_weight_and_quantization_params(self, wn) weight = _quantize_and_dequantize_weight(*params) else: weight = None flat_weights.append(weight) return flat_weights def forward(self, input, hx=None): # noqa: F811 orig_input = input # xxx: isinstance check needs to be in conditional for TorchScript to compile batch_sizes = None if isinstance(orig_input, PackedSequence): input, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = batch_sizes[0] max_batch_size = int(max_batch_size) else: batch_sizes = None is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: input = input.unsqueeze(batch_dim) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size h_zeros = torch.zeros(self.num_layers * num_directions, max_batch_size, real_hidden_size, dtype=input.dtype, device=input.device) c_zeros = torch.zeros(self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, device=input.device) hx = (h_zeros, c_zeros) else: if batch_sizes is None: # If not PackedSequence input. if is_batched: if (hx[0].dim() != 3 or hx[1].dim() != 3): msg = ("For batched 3-D input, hx and cx should " f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") raise RuntimeError(msg) else: if hx[0].dim() != 2 or hx[1].dim() != 2: msg = ("For unbatched 2-D input, hx and cx should " f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") raise RuntimeError(msg) hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) if batch_sizes is None: result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first) else: result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional) output = result[0] hidden = result[1:] # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: output = output.squeeze(batch_dim) hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) return output, self.permute_hidden(hidden, unsorted_indices) def _get_name(self): return "QuantizedLSTM(Reference)" @classmethod def from_float(cls, mod, weight_qparams_dict): ref_mod = cls( mod.input_size, mod.hidden_size, mod.num_layers, mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, weight_qparams_dict=weight_qparams_dict) for wn in mod._flat_weights_names: setattr(ref_mod, wn, getattr(mod, wn)) return ref_mod class GRU(RNNBase): """ Reference Quantized GRU Module We'll store weight_qparams for all the weights in _flat_weights, we need to pass in a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, to the weight_qparams for that weight """ def __init__(self, *args, **kwargs): if 'proj_size' in kwargs: raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") super().__init__('GRU', *args, **kwargs) def get_quantized_weight_bias_dict(self): """ dictionary from flat_weight_name to quantized weight or (unquantized) bias e.g. { "weight_ih_l0": quantized_weight, "bias_ih_l0": unquantized_bias, ... } """ quantized_weight_bias_dict = {} for wn in self._flat_weights_names: if hasattr(self, wn): if wn.startswith("weight"): weight_or_bias = get_quantized_weight(self, wn) else: weight_or_bias = getattr(self, wn) else: weight_or_bias = None quantized_weight_bias_dict[wn] = weight_or_bias return quantized_weight_bias_dict def get_flat_weights(self): flat_weights = [] for wn in self._flat_weights_names: if hasattr(self, wn): weight = getattr(self, wn) if wn.startswith("weight"): params = _get_weight_and_quantization_params(self, wn) weight = _quantize_and_dequantize_weight(*params) else: weight = None flat_weights.append(weight) return flat_weights def forward(self, input, hx=None): # noqa: F811 # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # only changed self._flat_weights to self.get_flat_weights() # TODO: maybe we can try inheriting from that class and define get_flat_weights # as a @property? this might interfere with TorchScript, if we remove that # requirement in the future we should be able to do this orig_input = input # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): input, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = batch_sizes[0] max_batch_size = int(max_batch_size) else: batch_sizes = None assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: input = input.unsqueeze(batch_dim) if hx is not None: if hx.dim() != 2: raise RuntimeError( f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") hx = hx.unsqueeze(1) else: if hx is not None and hx.dim() != 3: raise RuntimeError( f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 hx = torch.zeros(self.num_layers * num_directions, max_batch_size, self.hidden_size, dtype=input.dtype, device=input.device) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. hx = self.permute_hidden(hx, sorted_indices) self.check_forward_args(input, hx, batch_sizes) if batch_sizes is None: result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional, self.batch_first) else: result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias, self.num_layers, self.dropout, self.training, self.bidirectional) output = result[0] hidden = result[1] # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: output = output.squeeze(batch_dim) hidden = hidden.squeeze(1) return output, self.permute_hidden(hidden, unsorted_indices) def _get_name(self): return "QuantizedGRU(Reference)" @classmethod def from_float(cls, mod, weight_qparams_dict): ref_mod = cls( mod.input_size, mod.hidden_size, mod.num_layers, mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, weight_qparams_dict=weight_qparams_dict) for wn in mod._flat_weights_names: setattr(ref_mod, wn, getattr(mod, wn)) return ref_mod