tracer.py

import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.proxy import Scope
from torch.ao.nn.intrinsic import _FusedModule
from typing import List, Callable

__all__ = [
    "QuantizationTracer",
]
# Thin wrapper around torch.fx.proxy.ScopeContextManager that builds the new
# Scope from a module instance and its qualified path before delegating.
class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
    def __init__(
        self,
        scope: Scope,
        current_module: torch.nn.Module,
        current_module_path: str,
    ):
        super().__init__(scope, Scope(current_module_path, type(current_module)))
class QuantizationTracer(Tracer):
    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: the module_type of the top-level module is initialized to None.
        # We assume nobody configures qconfig by the type of the top-level
        # module, since the empty key "" already serves as the global config.
        # This can change if a use case arises that configures qconfig by
        # top-level module type.
        self.scope = Scope("", None)
        self.record_stack_traces = True
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # A module is a leaf (i.e. not traced through) if it is a standard
        # torch.nn / torch.ao.nn module (containers like Sequential excluded),
        # was explicitly skipped by name or class, or is a fused module.
        return (
            (
                (
                    m.__module__.startswith("torch.nn")
                    or m.__module__.startswith("torch.ao.nn")
                )
                and not isinstance(m, torch.nn.Sequential)
            )
            or module_qualified_name in self.skipped_module_names
            or type(m) in self.skipped_module_classes
            or isinstance(m, _FusedModule)
        )
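
For context, here is a minimal usage sketch (not part of tracer.py): it traces a
small model with QuantizationTracer and wraps the result in a GraphModule. The
TinyModel class and the empty skip lists are illustrative assumptions, not part
of the file above.

import torch
from torch.fx import GraphModule

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
graph = tracer.trace(TinyModel())
traced = GraphModule(tracer.root, graph, "TinyModel")
# Linear and ReLU are torch.nn modules, so is_leaf_module keeps them as
# call_module nodes instead of tracing into their forward() implementations.
print(traced.graph)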