123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- import torch
- import torch.nn as nn
- from torch import Tensor
- from torch._jit_internal import Optional, List
- from .utils import _hide_packed_params_repr
- from .utils import _quantize_weight
- __all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag']
class EmbeddingPackedParams(torch.nn.Module):
    """Holder module for the prepacked weight of a quantized embedding.

    The weight is stored in ``_packed_weight`` as whatever
    ``torch.ops.quantized.embedding_bag_prepack`` returns, and is turned back
    into a quantized tensor with ``torch.ops.quantized.embedding_bag_unpack``
    whenever it is read, serialized or printed.

    Only ``torch.quint8`` and ``torch.quint4x2`` are accepted; any other dtype
    raises ``NotImplementedError``.
    """
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        """Pack a placeholder weight of shape ``[num_embeddings, embedding_dim]``.

        Args:
            num_embeddings: number of rows of the embedding table.
            embedding_dim: number of columns of the embedding table.
            dtype: quantized dtype of the table (quint8 or quint4x2).

        Raises:
            NotImplementedError: if ``dtype`` is not one of the supported dtypes.
        """
        super().__init__()
        self.dtype = dtype
        if self.dtype not in [torch.quint8, torch.quint4x2]:
            raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')
        # Per-row (axis=0) qparams: unit scales and zero offsets. The payload
        # itself comes from an "empty" allocator, so it is only a placeholder
        # until a real weight is installed through set_weight().
        placeholder = torch._empty_per_channel_affine_quantized(
            [num_embeddings, embedding_dim],
            scales=torch.ones(num_embeddings, dtype=torch.float),
            zero_points=torch.zeros(num_embeddings, dtype=torch.float),
            axis=0,
            dtype=self.dtype,
        )
        self.set_weight(placeholder)

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        """Prepack ``weight`` and store the result in ``_packed_weight``."""
        is_supported = self.dtype == torch.quint8 or self.dtype == torch.quint4x2
        if not is_supported:
            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')
        self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)

    @torch.jit.export
    def _weight(self):
        """Return the unpacked form of the stored packed weight."""
        is_supported = self.dtype == torch.quint8 or self.dtype == torch.quint4x2
        if not is_supported:
            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')
        return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)

    def forward(self, x):
        """Identity: this module only stores parameters."""
        return x

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Add the dtype and the *unpacked* weight to ``destination``.

        Note that the ``'_packed_weight'`` key holds the unpacked tensor
        returned by ``_weight()``, not the packed object itself.
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[f'{prefix}dtype'] = self.dtype
        destination[f'{prefix}_packed_weight'] = self._weight()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Restore dtype and weight, consuming both keys from ``state_dict``.

        The two entries are removed before delegating to the parent
        implementation, which is always invoked with ``strict=False``
        regardless of the ``strict`` argument received here.
        """
        self.dtype = state_dict.pop(f'{prefix}dtype')
        restored = state_dict.pop(f'{prefix}_packed_weight')
        self.set_weight(restored)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    def __repr__(self):
        """Show the unpacked weight tensor instead of the module tree."""
        return repr(self._weight())
class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.

    Unless a ``_weight`` is supplied, the module starts with a placeholder
    weight that is expected to be overwritten later (via :meth:`set_weight`,
    :meth:`from_float` or by loading a state dict).

    ``padding_idx``, ``max_norm``, ``norm_type``, ``scale_grad_by_freq`` and
    ``sparse`` are accepted for interface parity with the float module but are
    not read by this class.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])
    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
        """Create the module and pack either ``_weight`` or a placeholder.

        Args:
            num_embeddings: number of rows of the embedding table.
            embedding_dim: number of columns of the embedding table.
            _weight: optional quantized weight of shape
                ``[num_embeddings, embedding_dim]`` to pack immediately.
            dtype: quantized dtype of the table (quint8 or quint4x2).

        Raises:
            AssertionError: if ``_weight`` does not have the expected shape.
            NotImplementedError: if ``dtype`` is not supported by
                ``EmbeddingPackedParams``.
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.dtype = dtype

        if _weight is not None:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'

        # EmbeddingPackedParams validates ``dtype`` and already packs a
        # placeholder weight of that dtype. Previously a second placeholder was
        # allocated here with a hard-coded torch.quint8 and packed on top of
        # it, which left a quint4x2 module holding a byte-packed weight.
        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
        if _weight is not None:
            self._packed_params.set_weight(_weight)

    def forward(self, indices: Tensor) -> Tensor:
        """Look up ``indices`` with the op matching ``self.dtype``."""
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
        else:
            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

    def _get_name(self):
        return 'QuantizedEmbedding'

    def __repr__(self):
        return _hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        """Summarize table size, packed dtype and the weight's qscheme."""
        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
        )
        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        """Pack ``w`` and make it the module's weight."""
        self._packed_params.set_weight(w)

    def weight(self):
        """Return the unpacked quantized weight tensor."""
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user

        Raises:
            AssertionError: if ``mod`` has the wrong type, lacks a ``qconfig``,
                or its weight observer does not use float qparams with a
                quint8 / quint4x2 dtype.
        """
        if hasattr(mod, 'weight_fake_quant'):
            assert type(mod) == torch.ao.nn.qat.Embedding, 'nnq.' + cls.__name__ + '.from_float ' + \
                'with fake quant only works for ' + torch.ao.nn.qat.Embedding.__name__
            weight_observer = mod.weight_fake_quant
        else:
            assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
                nn.Embedding.__name__
            assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
            from torch.ao.quantization import float_qparams_weight_only_qconfig
            if mod.qconfig is not None and mod.qconfig.weight is not None:
                weight_observer = mod.qconfig.weight()
            else:
                weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype
        is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
        assert is_float_qparams_qconfig, \
            'Embedding quantization is only supported with float_qparams_weight_only_qconfig.'

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            f'The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Propagate the observer's dtype so that forward() dispatches to the
        # kernel matching the packed weight (embedding_4bit for quint4x2).
        # Without it the module silently defaults to quint8.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding.set_weight(qweight)
        return qembedding

    @classmethod
    def from_reference(cls, ref_embedding):
        """Build a quantized module from a reference embedding module.

        The reference module's hyper-parameters, its
        ``get_quantized_weight()`` result and its ``weight_dtype`` are
        forwarded positionally to the constructor.
        """
        qembedding = cls(
            ref_embedding.num_embeddings,
            ref_embedding.embedding_dim,
            ref_embedding.padding_idx,
            ref_embedding.max_norm,
            ref_embedding.norm_type,
            ref_embedding.scale_grad_by_freq,
            ref_embedding.sparse,
            ref_embedding.get_quantized_weight(),
            ref_embedding.weight_dtype,
        )
        return qembedding
class EmbeddingBag(Embedding):
    r"""
    Quantized counterpart of `torch.nn.EmbeddingBag` operating on a packed,
    quantized weight. The constructor mirrors the float module; see
    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.

    As with :class:`~torch.nn.EmbeddingBag`, the weight created at
    construction time is a placeholder that will be overwritten later.

    ``mode`` is stored on the instance, but ``forward`` does not read it.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])
    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
        """Initialize the parent embedding, then record the bag-specific options.

        Only ``num_embeddings``, ``embedding_dim``, ``_weight`` and ``dtype``
        are forwarded to the parent constructor.
        """
        super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
        self.dtype = dtype
        self.include_last_offset = include_last_offset
        # Always handed to the kernels as-is; never toggled inside this class.
        self.pruned_weights = False
        self.mode = mode

    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
        """Run the bag lookup with the kernel matching ``self.dtype``.

        The literals ``False`` and ``0`` are passed positionally to both
        kernels. NOTE(review): presumably the ops' ``scale_grad_by_freq`` and
        ``mode`` arguments; confirm against the operator schema.
        """
        packed = self._packed_params._packed_weight
        if self.dtype != torch.quint4x2:
            return torch.ops.quantized.embedding_bag_byte(
                packed, indices, offsets, False, 0,
                self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                self.include_last_offset)
        else:
            return torch.ops.quantized.embedding_bag_4bit(
                packed, indices, offsets, False, 0,
                self.pruned_weights, per_sample_weights, compressed_indices_mapping,
                self.include_last_offset)

    def _get_name(self):
        return 'QuantizedEmbeddingBag'

    @classmethod
    def from_float(cls, mod):
        r"""Convert a float embedding-bag module into its quantized form.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by user

        Raises:
            AssertionError: if ``mod`` has the wrong type, lacks a ``qconfig``,
                or its weight observer does not use float qparams with a
                quint8 / quint4x2 dtype.
        """
        if not hasattr(mod, 'weight_fake_quant'):
            assert type(mod) == nn.EmbeddingBag, \
                f'nnq.{cls.__name__}.from_float only works for {nn.EmbeddingBag.__name__}'
            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
            from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
            # Fall back to the default weight-only observer when the module's
            # qconfig does not provide one.
            has_weight_ctor = mod.qconfig is not None and mod.qconfig.weight is not None
            if has_weight_ctor:
                observer = mod.qconfig.weight()
            else:
                observer = float_qparams_weight_only_qconfig.weight()
        else:
            observer = mod.weight_fake_quant

        dtype = observer.dtype
        assert observer.qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig.'
        assert dtype in (torch.quint8, torch.quint4x2), \
            f'The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}'

        # Let the observer see the weight before it is used for quantization.
        observer(mod.weight)
        quantized_weight = _quantize_weight(mod.weight.float(), observer)

        quantized_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        quantized_bag.set_weight(quantized_weight)
        return quantized_bag

    @classmethod
    def from_reference(cls, ref_embedding_bag):
        """Build a quantized module from a reference embedding-bag module.

        The reference module's hyper-parameters, its
        ``get_quantized_weight()`` result and its ``weight_dtype`` are
        forwarded positionally to the constructor.
        """
        ref = ref_embedding_bag
        return cls(
            ref.num_embeddings,
            ref.embedding_dim,
            ref.max_norm,
            ref.norm_type,
            ref.scale_grad_by_freq,
            ref.mode,
            ref.sparse,
            ref.get_quantized_weight(),
            ref.include_last_offset,
            ref.weight_dtype,
        )
|