import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .utils import ReferenceQuantizedModule
from typing import Optional, Dict, Any

__all__ = ['Embedding', 'EmbeddingBag']

class Embedding(nn.Embedding, ReferenceQuantizedModule):
    """A reference quantized Embedding module that fits into the
    FX Graph Mode Quantization workflow. Activations are floating point
    Tensors, and the module also stores the floating point weight; in
    forward we quantize and dequantize the weight before running the
    floating point functional embedding operator.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbedding(Reference)"

    def forward(self, input: Tensor) -> Tensor:
        # Quantize-dequantize the fp32 weight, then run the float embedding op.
        weight_quant_dequant = self.get_weight()
        return F.embedding(
            input, weight_quant_dequant, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.padding_idx,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.sparse,
            mod.weight,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams)
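

# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch, assuming the ``weight_qparams`` dict uses the
# qscheme/dtype/scale/zero_point keys consumed by
# ``ReferenceQuantizedModule._init_weight_qparams``; the helper name below is
# hypothetical and only demonstrates ``Embedding.from_float``.
def _example_embedding_from_float():
    import torch
    float_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,  # assumed key/value layout
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    ref_emb = Embedding.from_float(float_emb, weight_qparams)
    indices = torch.tensor([0, 2, 5])
    # The weight is quantize-dequantized inside forward; the output stays fp32.
    return ref_emb(indices)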


class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
    """A reference quantized EmbeddingBag module that fits into the
    FX Graph Mode Quantization workflow. Activations are floating point
    Tensors, and the module also stores the floating point weight; in
    forward we quantize and dequantize the weight before running the
    floating point functional embedding bag operator.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, padding_idx: Optional[int] = None,
                 device=None, dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
                         scale_grad_by_freq, mode, sparse, _weight, include_last_offset,
                         padding_idx, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbeddingBag(Reference)"

    def forward(self, input: Tensor, offsets: Optional[Tensor] = None,
                per_sample_weights: Optional[Tensor] = None) -> Tensor:
        # Quantize-dequantize the fp32 weight, then run the float embedding_bag op.
        weight_quant_dequant = self.get_weight()
        return F.embedding_bag(input, weight_quant_dequant, offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset,
                               self.padding_idx)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.mode,
            mod.sparse,
            mod.weight,
            mod.include_last_offset,
            mod.padding_idx,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams)
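

# --- Illustrative usage (not part of the original module) -------------------
# A minimal sketch of ``EmbeddingBag.from_float`` called with a flat index
# tensor plus bag offsets, mirroring ``F.embedding_bag`` semantics. The
# ``weight_qparams`` layout is assumed, as in the Embedding sketch above;
# the helper name is hypothetical.
def _example_embedding_bag_from_float():
    import torch
    float_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='mean')
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,  # assumed key/value layout
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    ref_bag = EmbeddingBag.from_float(float_bag, weight_qparams)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])  # two bags: indices[0:4] and indices[4:8]
    # Each bag is reduced with mode='mean' after the fake-quantized weight lookup.
    return ref_bag(indices, offsets)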