__init__.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import sys
  2. import torch
  3. import types
  4. from typing import List
  5. # This function should correspond to the enums present in c10/core/QEngine.h
  6. def _get_qengine_id(qengine: str) -> int:
  7. if qengine == 'none' or qengine == '' or qengine is None:
  8. ret = 0
  9. elif qengine == 'fbgemm':
  10. ret = 1
  11. elif qengine == 'qnnpack':
  12. ret = 2
  13. elif qengine == 'onednn':
  14. ret = 3
  15. elif qengine == 'x86':
  16. ret = 4
  17. else:
  18. ret = -1
  19. raise RuntimeError("{} is not a valid value for quantized engine".format(qengine))
  20. return ret
  21. # This function should correspond to the enums present in c10/core/QEngine.h
  22. def _get_qengine_str(qengine: int) -> str:
  23. all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn', 4 : 'x86'}
  24. return all_engines.get(qengine, '*undefined')
  25. class _QEngineProp:
  26. def __get__(self, obj, objtype) -> str:
  27. return _get_qengine_str(torch._C._get_qengine())
  28. def __set__(self, obj, val: str) -> None:
  29. torch._C._set_qengine(_get_qengine_id(val))
  30. class _SupportedQEnginesProp:
  31. def __get__(self, obj, objtype) -> List[str]:
  32. qengines = torch._C._supported_qengines()
  33. return [_get_qengine_str(qe) for qe in qengines]
  34. def __set__(self, obj, val) -> None:
  35. raise RuntimeError("Assignment not supported")
  36. class QuantizedEngine(types.ModuleType):
  37. def __init__(self, m, name):
  38. super().__init__(name)
  39. self.m = m
  40. def __getattr__(self, attr):
  41. return self.m.__getattribute__(attr)
  42. engine = _QEngineProp()
  43. supported_engines = _SupportedQEnginesProp()
  44. # This is the sys.modules replacement trick, see
  45. # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
  46. sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
  47. engine: str
  48. supported_engines: List[str]