import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor
- """
- We will recreate all the RNN modules as we require the modules to be decomposed
- into its building blocks to be able to observe.
- """
__all__ = [
    "LSTMCell",
    "LSTM",
]
class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        # Each Linear produces the pre-activations of all four gates at once,
        # hence the `4 * hidden_dim` output features.
        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        # FloatFunctional wraps the elementwise arithmetic so that it can be
        # observed and later replaced by quantized ops.
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        # Gate pre-activations: W_ih @ x + W_hh @ h
        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        # c' = f * c + i * g;  h' = o * tanh(c')
        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed
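
# A minimal usage sketch (not part of the original module), assuming a float
# ``torch.nn.LSTMCell`` with a ``qconfig`` attached: ``LSTMCell.from_float``
# copies its weights into the decomposed cell, which can then be instrumented
# with observers::
#
#     float_cell = torch.nn.LSTMCell(10, 20)
#     float_cell.qconfig = torch.ao.quantization.default_qconfig
#     observed_cell = LSTMCell.from_float(float_cell)
#     observed_cell = torch.ao.quantization.prepare(observed_cell, inplace=False)
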

class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        for xx in x:
            hidden = self.cell(xx, hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer
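
# Shape sketch (illustrative only, not part of the original module): for a
# sequence of length T and batch size B, the single layer stacks the hidden
# state of every time step and also returns the final (h, c) pair::
#
#     layer = _LSTMSingleLayer(10, 20)
#     seq = torch.randn(7, 3, 10)    # (T, B, input_size)
#     out, (h, c) = layer(seq)       # out: (7, 3, 20); h and c: (3, 20)
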

class _LSTMLayer(torch.nn.Module):
    r"""A single LSTM layer, optionally bidirectional."""
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            # Split the stacked (forward, backward) states into one pair per direction.
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            # The backward direction consumes the time-reversed sequence; its
            # result is flipped back before concatenation.
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
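
# Shape sketch for the bidirectional case (illustrative only, not part of the
# original module): the forward and backward outputs are concatenated along
# the feature dimension, while the hidden and cell states gain a leading
# direction dimension::
#
#     layer = _LSTMLayer(10, 20, bidirectional=True)
#     seq = torch.randn(7, 3, 10)    # (T, B, input_size)
#     out, (h, c) = layer(seq)       # out: (7, 3, 40); h and c: (2, 3, 20)
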

class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # We don't want to train using this module
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please, use nn.LSTM version "
                          "followed by `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              "num_layers greater than 1, but got dropout={} "
                              "and num_layers={}".format(dropout, num_layers))

        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional, **factory_kwargs)]
        for layer in range(1, num_layers):
            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
                                     self.bias, batch_first=False,
                                     bidirectional=self.bidirectional,
                                     **factory_kwargs))
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                # Split the flat (num_layers * num_directions, B, H) states
                # into one (h, c) pair per layer.
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # The bidirectional case adds an extra direction dimension; collapse it
        # so the states have shape (num_layers * num_directions, B, H).
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first, other.dropout,
                       other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)
        observed.eval()
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please, see "
                                  "the examples on quantizable LSTMs.")
|