from typing import Any, Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .utils import ReferenceQuantizedModule

__all__ = ['Embedding', 'EmbeddingBag']


class Embedding(nn.Embedding, ReferenceQuantizedModule):
    """ A reference quantized Embedding module that fits into the
    FX Graph Mode Quantization workflow, activation will be floating point Tensor,
    we will store floating point weight as well in the module, but in forward we'll
    quantize and dequantize the weight before running the floating point functional
    embedding operator.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 sparse: bool = False,
                 _weight: Optional[Tensor] = None,
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        """Build the float ``nn.Embedding`` and record the weight qparams.

        All arguments except ``weight_qparams`` are forwarded unchanged to
        ``nn.Embedding.__init__``. ``weight_qparams`` and ``device`` are then
        handed to ``_init_weight_qparams`` (not defined in this file;
        presumably provided by ``ReferenceQuantizedModule``).
        """
        super().__init__(
            num_embeddings, embedding_dim, padding_idx, max_norm,
            norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        """Return the display name used for this module."""
        return "QuantizedEmbedding(Reference)"

    def forward(self, input: Tensor) -> Tensor:
        """Run ``F.embedding`` on ``input`` using the weight from ``get_weight()``.

        Per the class docstring, ``get_weight()`` is expected to yield the
        quantize-then-dequantize version of the stored float weight.
        """
        weight_quant_dequant = self.get_weight()
        return F.embedding(
            input, weight_quant_dequant, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        """Create a reference module from a float module ``mod``.

        The constructor arguments are read from ``mod``'s attributes; the
        weight, device and dtype all come from ``mod.weight``.
        """
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.padding_idx,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.sparse,
            mod.weight,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams)


class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
    """ A reference quantized EmbeddingBag module that fits into the
    FX Graph Mode Quantization workflow, activation will be floating point Tensor,
    we will store floating point weight as well in the module, but in forward we'll
    quantize and dequantize the weight before running the floating point functional
    embedding_bag operator.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 max_norm: Optional[float] = None,
                 norm_type: float = 2.,
                 scale_grad_by_freq: bool = False,
                 mode: str = 'mean',
                 sparse: bool = False,
                 _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False,
                 padding_idx: Optional[int] = None,
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        """Build the float ``nn.EmbeddingBag`` and record the weight qparams.

        All arguments except ``weight_qparams`` are forwarded unchanged to
        ``nn.EmbeddingBag.__init__``. ``weight_qparams`` and ``device`` are
        then handed to ``_init_weight_qparams`` (not defined in this file;
        presumably provided by ``ReferenceQuantizedModule``).
        """
        super().__init__(
            num_embeddings, embedding_dim, max_norm, norm_type,
            scale_grad_by_freq, mode, sparse, _weight, include_last_offset,
            padding_idx, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        """Return the display name used for this module."""
        # Must differ from Embedding's name so printed module trees show
        # which kind of reference module this is.
        return "QuantizedEmbeddingBag(Reference)"

    def forward(self,
                input: Tensor,
                offsets: Optional[Tensor] = None,
                per_sample_weights: Optional[Tensor] = None) -> Tensor:
        """Run ``F.embedding_bag`` using the weight from ``get_weight()``.

        ``offsets`` and ``per_sample_weights`` are passed straight through to
        ``F.embedding_bag``; the remaining options come from this module's
        attributes.
        """
        weight_quant_dequant = self.get_weight()
        return F.embedding_bag(
            input, weight_quant_dequant, offsets,
            self.max_norm, self.norm_type,
            self.scale_grad_by_freq, self.mode, self.sparse,
            per_sample_weights, self.include_last_offset,
            self.padding_idx)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        """Create a reference module from a float module ``mod``.

        The constructor arguments are read from ``mod``'s attributes; the
        weight, device and dtype all come from ``mod.weight``.
        """
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.mode,
            mod.sparse,
            mod.weight,
            mod.include_last_offset,
            mod.padding_idx,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams
        )