# utils.py

from typing import Any, List, Optional, Union

import torch
from torch import nn

def _replace_relu(module: nn.Module) -> None:
    """Recursively replace every exact nn.ReLU/nn.ReLU6 child with a fresh out-of-place nn.ReLU."""
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Checking for explicit type instead of instance
        # as we only want to replace modules of the exact type
        # not inherited classes
        if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value
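
# Usage sketch (the mobilenet_v2 import below is illustrative, not part of
# this file): after _replace_relu, no exact nn.ReLU6 instances remain, and
# every ReLU is out-of-place, which the eager-mode quantization flow expects.
#
#   import torchvision
#   m = torchvision.models.mobilenet_v2()
#   _replace_relu(m)
#   assert not any(type(x) is nn.ReLU6 for x in m.modules())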


def quantize_model(model: nn.Module, backend: str) -> None:
    """Apply post-training static quantization to ``model`` in place."""
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError(f"Quantized backend not supported: {backend}")
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]

    # Standard eager-mode flow: insert observers, calibrate on one dummy
    # batch, then convert the observed modules to quantized ones.
    torch.ao.quantization.prepare(model, inplace=True)
    model(_dummy_input_data)
    torch.ao.quantization.convert(model, inplace=True)
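
# Usage sketch (a minimal example, assuming a torchvision quantizable model
# that defines fuse_model(); the resnet18 import is an assumption, not part
# of this file):
#
#   from torchvision.models.quantization import resnet18
#   model = resnet18(quantize=False)
#   quantize_model(model, "fbgemm")  # model is converted in place
#   out = model(torch.rand(1, 3, 224, 224))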


def _fuse_modules(
    model: nn.Module, modules_to_fuse: Union[List[str], List[List[str]]], is_qat: Optional[bool], **kwargs: Any
) -> nn.Module:
    # Dispatch to the QAT-aware fuser when the model is in training mode (or
    # when explicitly requested); otherwise use the post-training fuser.
    if is_qat is None:
        is_qat = model.training
    method = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
    return method(model, modules_to_fuse, **kwargs)
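
# Usage sketch for _fuse_modules (the "conv1"/"bn1"/"relu1" names are
# illustrative): with is_qat=None the fusion mode follows model.training, so
# an eval()-ed model gets torch.ao.quantization.fuse_modules while a training
# model gets the QAT variant.
#
#   model.eval()
#   _fuse_modules(model, [["conv1", "bn1", "relu1"]], is_qat=None, inplace=True)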