# linear.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any

from .utils import ReferenceQuantizedModule

__all__ = ['Linear']

class Linear(nn.Linear, ReferenceQuantizedModule):
    """A reference quantized linear module that fits into the FX
    Graph Mode Quantization workflow.

    The activation is a floating point Tensor, and the module stores a
    floating point weight as well; in forward we quantize and dequantize
    the weight before running the floating point functional linear operator.
    """
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        super().__init__(in_features, out_features, bias_, device, dtype)
        # Record the weight quantization parameters on the module
        # (helper provided by ReferenceQuantizedModule).
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Within this module we have:

        w(float) -- quant - dequant \
        x(float) ------------------- F.linear ---

        In the full model, we will see:

        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.linear --- *quant - dequant

        and the backend should be able to fuse the ops marked with `*`
        into a quantized linear.
        """
        weight_quant_dequant = self.get_weight()
        result = F.linear(x, weight_quant_dequant, self.bias)
        return result

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        qref_linear = Linear(
            float_linear.in_features, float_linear.out_features,
            float_linear.bias is not None, device=float_linear.weight.device,
            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
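
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# weight_qparams keys below ("qscheme", "dtype", "scale", "zero_point") are
# an assumption about what _init_weight_qparams in .utils consumes; verify
# them against your torch version. Because of the relative import above,
# run this via `python -m <package>.linear`, not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    float_linear = nn.Linear(4, 2)
    # Assumed per-tensor qparams for the weight.
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    ref_linear = Linear.from_float(float_linear, weight_qparams)
    x = torch.randn(3, 4)
    y = ref_linear(x)  # quantize/dequantize weight, then fp32 F.linear
    print(ref_linear._get_name(), tuple(y.shape))  # QuantizedLinear(Reference) (3, 2)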