import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from .utils import ReferenceQuantizedModule

__all__ = ['Linear']

class Linear(nn.Linear, ReferenceQuantizedModule):
    """A reference quantized linear module that fits into the FX
    Graph Mode Quantization workflow.

    The activation is a floating-point Tensor. The module also stores the
    weight in floating point; in ``forward`` the weight is quantized and
    dequantized before the floating-point functional linear operator runs.
    """
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        super().__init__(in_features, out_features, bias_, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        In this module we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.linear --- *quant - dequant

        and the backend should be able to fuse the ops marked with `*` into a
        single quantized linear.
        """
        weight_quant_dequant = self.get_weight()
        result = F.linear(x, weight_quant_dequant, self.bias)
        return result
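
    # Note: ``get_weight()`` and ``_init_weight_qparams()`` are provided by the
    # ReferenceQuantizedModule mixin in ``.utils``.  As a rough sketch (not the
    # exact implementation), for a per-tensor qscheme ``get_weight()`` is
    # expected to do something like:
    #
    #     q = torch.quantize_per_tensor(self.weight, scale, zero_point, torch.qint8)
    #     return q.dequantize()
    #
    # i.e. it returns a fake-quantized copy of the weight so that F.linear
    # still runs in floating point.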

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        qref_linear = Linear(
            float_linear.in_features, float_linear.out_features,
            float_linear.bias is not None, device=float_linear.weight.device,
            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
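

# Minimal usage sketch (not part of the original module).  It assumes the
# weight_qparams dict uses the keys consumed by ReferenceQuantizedModule
# ("qscheme", "dtype", "scale", "zero_point"); the values below are made up
# for illustration only.
if __name__ == "__main__":
    float_linear = nn.Linear(4, 2)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint8,
        "scale": 0.1,
        "zero_point": 0,
    }
    qref_linear = Linear.from_float(float_linear, weight_qparams)

    x = torch.randn(3, 4)
    # F.linear runs on the quantize->dequantize'd weight; the output is float.
    print(qref_linear(x).shape)  # torch.Size([3, 2])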