embedding_ops.py

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Embedding', 'EmbeddingBag']

class Embedding(nn.Embedding):
    r"""
    An embedding module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Embedding`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
    for documentation.

    Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Embedding

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
                 sparse=False, _weight=None, device=None, dtype=None, qconfig=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input) -> Tensor:
        return F.embedding(input, self.weight_fake_quant(self.weight), self.padding_idx,
                           self.max_norm, self.norm_type, self.scale_grad_by_freq,
                           self.sparse)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args: `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)
        qconfig = mod.qconfig
        qat_embedding = cls(mod.num_embeddings, mod.embedding_dim, mod.padding_idx,
                            mod.max_norm, mod.norm_type, mod.scale_grad_by_freq,
                            mod.sparse, mod.weight, qconfig=qconfig)
        return qat_embedding

    def to_float(self):
        embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq,
                                       self.sparse, None)
        embedding.weight = torch.nn.Parameter(self.weight.detach())
        embedding.train(self.training)
        return embedding
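

# A minimal usage sketch for the QAT Embedding above, assuming
# `torch.ao.quantization.default_embedding_qat_qconfig` is available in your
# PyTorch build (its weight fake-quant uses the required
# per_channel_affine_float_qparams qscheme). The names `float_emb`, `qat_emb`,
# and `restored` are illustrative, not part of this module:
#
#     float_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
#     float_emb.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
#     qat_emb = Embedding.from_float(float_emb)   # weight is fake-quantized in forward
#     out = qat_emb(torch.tensor([[0, 2, 5]]))    # same call signature as nn.Embedding
#     restored = qat_emb.to_float()               # back to a plain nn.Embedding
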
class EmbeddingBag(nn.EmbeddingBag):
    r"""
    An embedding bag module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
    for documentation.

    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.EmbeddingBag

    def __init__(self, num_embeddings, embedding_dim, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, mode='mean',
                 sparse=False, _weight=None, include_last_offset=False,
                 padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
                         scale_grad_by_freq, mode, sparse, _weight,
                         include_last_offset, padding_idx, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
        return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset,
                               self.padding_idx)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args: `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)
        qconfig = mod.qconfig
        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
        return qat_embedding_bag

    def to_float(self):
        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
                                              self.norm_type, self.scale_grad_by_freq, self.mode,
                                              self.sparse, None, self.include_last_offset,
                                              self.padding_idx)
        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
        embedding_bag.train(self.training)
        return embedding_bag
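

if __name__ == "__main__":
    # Illustrative smoke test for the QAT EmbeddingBag, not part of the library
    # API. It assumes `torch.ao.quantization.default_embedding_qat_qconfig` is
    # available in your PyTorch build; its weight fake-quant uses the required
    # per_channel_affine_float_qparams qscheme.
    from torch.ao.quantization import default_embedding_qat_qconfig

    float_bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum')
    float_bag.qconfig = default_embedding_qat_qconfig
    qat_bag = EmbeddingBag.from_float(float_bag)

    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])      # two bags: indices[0:4] and indices[4:]
    out = qat_bag(indices, offsets)     # weight is fake-quantized inside forward
    print(out.shape)                    # torch.Size([2, 4])

    restored = qat_bag.to_float()       # convert back to a plain nn.EmbeddingBag
    print(type(restored).__name__)      # EmbeddingBag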