import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor

"""
We will recreate all the RNN modules as we require the modules to be decomposed
into its building blocks to be able to observe.
"""

__all__ = [
    "LSTMCell",
    "LSTM"
]


class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    Every arithmetic step of the cell is expressed through a submodule
    (``Linear``, ``Sigmoid``, ``Tanh``, ``FloatFunctional``) so that each
    intermediate value has a module it can be observed through.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        # The four gates are computed at once, hence the `4 * hidden_dim` outputs.
        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        # (scale, zero_point) and dtype used by `initialize_hidden` when the
        # default state has to be created for a quantized input.
        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        """Run one time step.

        Args:
            x: input for the current step; its first dimension is used as the
                batch size when the state has to be initialized.
            hidden: optional ``(hx, cx)`` pair. When it is missing (or either
                element is ``None``) a zero state is created.

        Returns:
            The ``(hy, cy)`` pair holding the new hidden and cell states.
        """
        if hidden is None or hidden[0] is None or hidden[1] is None:
            # Allocate the default state next to the input so that cells living
            # on a non-default device do not hit a device mismatch.
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized, x.device)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False,
                          device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor]:
        """Create an all-zero ``(h, c)`` state of shape ``(batch_size, hidden_size)``.

        Args:
            batch_size: number of rows of the created state tensors.
            is_quantized: when ``True`` the zeros are quantized per tensor with
                the ``initial_*_state_qparams`` and ``*_state_dtype`` attributes.
            device: device to allocate the state on; ``None`` keeps the
                previous behavior of using the default device.

        Returns:
            The ``(h, c)`` pair.
        """
        h = torch.zeros((batch_size, self.hidden_size), device=device)
        c = torch.zeros((batch_size, self.hidden_size), device=device)
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers

        Returns:
            A new cell whose ``igates``/``hgates`` parameters wrap the given
            tensors. The sizes are read from the second dimension of the weights.
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other):
        """Build an observed cell from a float ``torch.nn.LSTMCell`` carrying a ``qconfig``."""
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Feed ``x`` to the cell one slice (along dim 0) at a time.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` stacks the hidden state of
            every step along dim 0 and ``hidden`` is the state after the last step.
        """
        result = []
        for xx in x:
            hidden = self.cell(xx, hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        """Create a layer around a cell built by :meth:`LSTMCell.from_params`."""
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single bi-directional LSTM layer."""
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Run the forward (and, if bidirectional, the backward) pass over ``x``.

        Args:
            x: input sequence; transposed to time-major first when
                ``batch_first`` is set.
            hidden: optional ``(hx, cx)``. In the bidirectional case index 0 of
                each tensor is used for the forward direction and index 1 for
                the backward direction.

        Returns:
            ``(result, (h, c))``. For a bidirectional layer ``result``
            concatenates both directions on the last dimension and ``h``/``c``
            stack the two directions on a new leading dimension.
        """
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            # Split the incoming state into its forward / backward halves.
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            # The backward direction sees the sequence reversed in time; its
            # outputs are flipped back so both directions line up per step.
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.

        Args:
            other: float module providing ``weight_ih_l{layer_idx}`` /
                ``weight_hh_l{layer_idx}`` (and optionally the biases and the
                ``_reverse`` variants).
            layer_idx: index of the layer of ``other`` to copy.
            qconfig: fallback qconfig used when ``other`` has none.
            **kwargs: optional overrides for ``input_size``, ``hidden_size``,
                ``bias``, ``batch_first`` and ``bidirectional``.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)
        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].layer_fw.cell.igates.weight)
        tensor([[...]])
        >>> print(rnn.layers[0].layer_fw.cell.hgates.weight)
        tensor([[...]])
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # We don't want to train using this module
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please, use nn.LSTM version "
                          "followed by `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              f"num_layers greater than 1, but got dropout={dropout} "
                              f"and num_layers={num_layers}")

        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional, **factory_kwargs)]
        # Layers after the first consume the previous layer's output, which
        # concatenates both directions on the feature dimension when the
        # module is bidirectional.
        layers.extend(
            _LSTMLayer(self.hidden_size * num_directions, self.hidden_size,
                       self.bias, batch_first=False,
                       bidirectional=self.bidirectional,
                       **factory_kwargs)
            for _ in range(1, num_layers)
        )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        """Run the stacked layers over ``x``.

        Args:
            x: input sequence, time-major unless ``batch_first`` is set.
            hidden: optional ``(h0, c0)``. Tensors are reshaped to
                ``(num_layers, num_directions, batch, hidden_size)`` and split
                per layer; any other value is used as an already per-layer
                indexable collection. ``None`` yields a zero state.

        Returns:
            ``(output, (hn, cn))`` where the states of all layers (and
            directions) are collapsed into the leading dimension.
        """
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            # Drops the direction dimension in the one-directional case.
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        """Build an observed module from a float ``torch.nn.LSTM``.

        Every layer is rebuilt from the float weights, then the module is put
        in eval mode and passed through ``torch.ao.quantization.prepare``.
        """
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first, other.dropout,
                       other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)
        observed.eval()
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please, see "
                                  "the examples on quantizable LSTMs.")