import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Embedding', 'EmbeddingBag']


class Embedding(nn.Embedding):
- r"""
- An embedding bag module attached with FakeQuantize modules for weight,
- used for quantization aware training.
- We adopt the same interface as `torch.nn.Embedding`, please see
- https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
- for documentation.
- Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
- default.
- Attributes:
- weight: fake quant module for weight
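
    A minimal usage sketch; ``default_embedding_qat_qconfig`` is assumed to be
    importable from `torch.ao.quantization` (any qconfig whose weight
    fake-quantizer uses ``torch.per_channel_affine_float_qparams`` would work)::

        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> m = Embedding(10, 12, qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> output = m(indices)  # weights are fake-quantized on the fly
        >>> output.shape
        torch.Size([2, 4, 12])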
- """
    _FLOAT_MODULE = nn.Embedding

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
                 sparse=False, _weight=None, device=None, dtype=None, qconfig=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams. Got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input) -> Tensor:
        return F.embedding(input, self.weight_fake_quant(self.weight), self.padding_idx,
                           self.max_norm, self.norm_type, self.scale_grad_by_freq,
                           self.sparse)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or provided directly by the user
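
        A minimal sketch of the expected call pattern; the qconfig name below is
        an assumption (any qconfig whose weight uses
        ``torch.per_channel_affine_float_qparams`` is accepted)::

            >>> from torch.ao.quantization import default_embedding_qat_qconfig
            >>> float_emb = nn.Embedding(10, 12)
            >>> float_emb.qconfig = default_embedding_qat_qconfig
            >>> qat_emb = Embedding.from_float(float_emb)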
- """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams. Got ' + \
            str(weight_qscheme)
        qconfig = mod.qconfig
        qat_embedding = cls(mod.num_embeddings, mod.embedding_dim, mod.padding_idx,
                            mod.max_norm, mod.norm_type, mod.scale_grad_by_freq,
                            mod.sparse, mod.weight, qconfig=qconfig)
        return qat_embedding

    def to_float(self):
        embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq,
                                       self.sparse, None)
        embedding.weight = torch.nn.Parameter(self.weight.detach())
        embedding.train(self.training)
        return embedding


class EmbeddingBag(nn.EmbeddingBag):
    r"""
    An embedding bag module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
    for documentation.

    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight: fake quant module for weight
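
    A minimal usage sketch; ``default_embedding_qat_qconfig`` is assumed to be
    importable from `torch.ao.quantization` (any qconfig whose weight
    fake-quantizer uses ``torch.per_channel_affine_float_qparams`` would work)::

        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> m = EmbeddingBag(10, 12, qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])
        >>> output = m(indices, offsets)  # weights are fake-quantized on the fly
        >>> output.shape
        torch.Size([2, 12])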
- """
    _FLOAT_MODULE = nn.EmbeddingBag

    def __init__(self, num_embeddings, embedding_dim, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, mode='mean',
                 sparse=False, _weight=None, include_last_offset=False,
                 padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
                         scale_grad_by_freq, mode, sparse, _weight,
                         include_last_offset, padding_idx, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams. Got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
        return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset,
                               self.padding_idx)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or provided directly by the user
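
        A minimal sketch of the expected call pattern; the qconfig name below is
        an assumption (any qconfig whose weight uses
        ``torch.per_channel_affine_float_qparams`` is accepted)::

            >>> from torch.ao.quantization import default_embedding_qat_qconfig
            >>> float_emb_bag = nn.EmbeddingBag(10, 12)
            >>> float_emb_bag.qconfig = default_embedding_qat_qconfig
            >>> qat_emb_bag = EmbeddingBag.from_float(float_emb_bag)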
- """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams. Got ' + \
            str(weight_qscheme)
        qconfig = mod.qconfig
        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
        return qat_embedding_bag

    def to_float(self):
        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                              None, self.include_last_offset, self.padding_idx)
        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
        embedding_bag.train(self.training)
        return embedding_bag