123456789101112131415161718192021222324252627282930313233343536373839 |
- import importlib
- import logging
- import torch
- from torch._dynamo import register_backend
- from .common import fake_tensor_unsupported
- log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def ipex(model, inputs):
    """Compile ``model`` for the 'ipex' backend via TorchScript trace + freeze.

    Args:
        model: The module to compile. It is switched to eval mode in place
            (``model.eval()``) before being traced.
        inputs: Example inputs forwarded to ``torch.jit.trace``.

    Returns:
        The frozen traced module when tracing and freezing succeed; otherwise
        the original ``model`` (best-effort fallback).

    Raises:
        ImportError: If ``intel_extension_for_pytorch`` cannot be imported.
            The error is logged with installation guidance and re-raised.
    """
    try:
        # The imported name is never used below; only the import itself matters.
        # NOTE(review): presumably importing IPEX hooks its optimizations into
        # TorchScript as a side effect -- confirm against the IPEX docs.
        import intel_extension_for_pytorch  # type: ignore[import]  # noqa: F401
    except ImportError:
        log.exception(
            "Unable to import Intel Extension for PyTorch (IPEX). "
            "Please install the right version of IPEX that matches the PyTorch version being used. "
            "Refer to https://github.com/intel/intel-extension-for-pytorch for details."
        )
        raise

    try:
        # Tracing is done without autograd bookkeeping.
        with torch.no_grad():
            traced_model = torch.jit.trace(model.eval(), inputs)
            traced_model = torch.jit.freeze(traced_model)
        return traced_model
    except Exception:
        # Deliberate best-effort fallback: hand back the untraced model, but
        # attach the traceback so the reason for the failure is not lost.
        log.warning(
            "JIT trace failed during the 'ipex' optimize process.",
            exc_info=True,
        )
        return model
def has_ipex():
    """Report whether ``intel_extension_for_pytorch`` can be imported.

    Returns:
        ``True`` if importing the module succeeds, ``False`` if the import
        raises ``ImportError``. Any other exception raised during the import
        propagates to the caller.
    """
    try:
        importlib.import_module("intel_extension_for_pytorch")
    except ImportError:
        return False
    return True
|