# ipex.py (1.1 KB)

  1. import importlib
  2. import logging
  3. import torch
  4. from torch._dynamo import register_backend
  5. from .common import fake_tensor_unsupported
  6. log = logging.getLogger(__name__)
  7. @register_backend
  8. @fake_tensor_unsupported
  9. def ipex(model, inputs):
  10. try:
  11. import intel_extension_for_pytorch # type: ignore[import] # noqa: F401
  12. except ImportError:
  13. log.exception(
  14. "Unable to import Intel Extension for PyTorch (IPEX). "
  15. "Please install the right version of IPEX that matches the PyTorch version being used. "
  16. "Refer to https://github.com/intel/intel-extension-for-pytorch for details."
  17. )
  18. raise
  19. try:
  20. with torch.no_grad():
  21. traced_model = torch.jit.trace(model.eval(), inputs)
  22. traced_model = torch.jit.freeze(traced_model)
  23. return traced_model
  24. except Exception:
  25. log.warning("JIT trace failed during the 'ipex' optimize process.")
  26. return model
  27. def has_ipex():
  28. try:
  29. importlib.import_module("intel_extension_for_pytorch")
  30. return True
  31. except ImportError:
  32. return False