import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .utils import ReferenceQuantizedModule
from typing import Optional, Dict, Any

__all__ = ['Embedding', 'EmbeddingBag']


class Embedding(nn.Embedding, ReferenceQuantizedModule):
    """ A reference quantized Embedding module that fits into the
    FX Graph Mode Quantization workflow. Activations stay floating point
    Tensors, and the module stores a floating point weight as well; in
    forward we quantize and dequantize the weight before running the
    floating point functional embedding operator.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbedding(Reference)"

    def forward(self, input: Tensor) -> Tensor:
        weight_quant_dequant = self.get_weight()
        return F.embedding(
            input, weight_quant_dequant, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.padding_idx,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.sparse,
            mod.weight,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams)
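

# Usage sketch (illustrative only, not part of the reference module API):
# converts a float nn.Embedding via from_float() and runs it. The
# weight_qparams keys below assume the per-tensor convention consumed by
# ReferenceQuantizedModule._init_weight_qparams in .utils; real FX Graph Mode
# workflows derive these values from observers rather than hard-coding them.
def _example_embedding_usage() -> Tensor:
    import torch

    float_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,  # hand-picked qparams for the sketch
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 128,
    }
    # Reuses the float module's hyperparameters and weight tensor.
    ref_emb = Embedding.from_float(float_emb, weight_qparams)
    indices = torch.tensor([[0, 2, 5], [1, 3, 9]])
    # Output stays floating point; only the weight was quantize-dequantized.
    return ref_emb(indices)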


class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
    """ A reference quantized EmbeddingBag module that fits into the
    FX Graph Mode Quantization workflow. Activations stay floating point
    Tensors, and the module stores a floating point weight as well; in
    forward we quantize and dequantize the weight before running the
    floating point functional embedding bag operator.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, padding_idx: Optional[int] = None,
                 device=None, dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None) -> None:
        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
                         scale_grad_by_freq, mode, sparse, _weight, include_last_offset,
                         padding_idx, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedEmbeddingBag(Reference)"

    def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
        weight_quant_dequant = self.get_weight()
        return F.embedding_bag(input, weight_quant_dequant, offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset,
                               self.padding_idx)

    @classmethod
    def from_float(cls, mod, weight_qparams):
        return cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.mode,
            mod.sparse,
            mod.weight,
            mod.include_last_offset,
            mod.padding_idx,
            mod.weight.device,
            mod.weight.dtype,
            weight_qparams)