import torch

__all__ = ["Linear"]


class Linear(torch.ao.nn.qat.Linear):
    r"""
    A linear module attached with FakeQuantize modules for weight,
    used for dynamic quantization aware training.

    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
    for documentation.

    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
    default.
    """
    def __init__(self, in_features, out_features, bias=True,
                 qconfig=None, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias, qconfig, device, dtype)
        # Dynamic quantization computes activation qparams on the fly at
        # inference time, so the activation observer must not accumulate
        # statistics across batches: it must be memoryless, i.e. a
        # MovingAverage observer with averaging_constant == 1.
        if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
            raise ValueError(
                "Dynamic QAT requires a memoryless observer. "
                "This means a MovingAverage observer with averaging constant equal to 1."
            )
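

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of constructing this dynamic-QAT Linear. Any qconfig whose
# activation observer is memoryless (averaging constant of 1) passes the check
# in __init__; we assume torch.ao.quantization.default_dynamic_qat_qconfig
# provides such a config, since it is the qconfig shipped for dynamic QAT.
if __name__ == "__main__":
    from torch.ao.quantization import default_dynamic_qat_qconfig

    qat_linear = Linear(16, 8, qconfig=default_dynamic_qat_qconfig)
    out = qat_linear(torch.randn(4, 16))  # weight is fake-quantized in forward
    print(out.shape)  # torch.Size([4, 8])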