import functools
import sys
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

import torch
from torch import fx


class CompiledFn(Protocol):
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...


CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# Global registry mapping backend names to compiler functions.
_BACKENDS: Dict[str, CompilerFn] = dict()


def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _BACKENDS, f"duplicate name: {name}"
    _BACKENDS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)
    return compiler_fn

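# Usage sketch (illustrative only; `my_compiler` is a hypothetical backend name):
#
#     @register_backend
#     def my_compiler(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
#         # Inspect or transform the captured FX graph here, then return a
#         # callable with the same calling convention as `gm.forward`.
#         return gm.forward
#
#     torch.compile(model, backend="my_compiler")
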
register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)


def lookup_backend(compiler_fn):
    """Expand backend strings to functions"""
    if isinstance(compiler_fn, str):
        if compiler_fn not in _BACKENDS:
            _lazy_import()
        if compiler_fn not in _BACKENDS:
            _lazy_import_entry_point(compiler_fn)
        compiler_fn = _BACKENDS[compiler_fn]
    return compiler_fn

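# Example (sketch, reusing the hypothetical `my_compiler` name from above):
# strings are expanded to the registered callable, while callables pass
# through unchanged:
#
#     backend_fn = lookup_backend("my_compiler")   # str -> CompilerFn
#     backend_fn = lookup_backend(my_compiler)     # already a CompilerFn, returned as-is
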
def list_backends(exclude_tags=("debug", "experimental")):
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")
    """
    _lazy_import()
    exclude_tags = set(exclude_tags or ())
    return sorted(
        [
            name
            for name, backend in _BACKENDS.items()
            if not exclude_tags.intersection(backend._tags)
        ]
    )

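# Example (sketch; the actual names depend on which backends are installed
# and registered):
#
#     list_backends()                  # e.g. ["cudagraphs", "inductor", ...]
#     list_backends(exclude_tags=())   # also include debug/experimental backends
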
@functools.lru_cache(None)
def _lazy_import():
    from .. import backends
    from ..utils import import_submodule

    import_submodule(backends)

    from ..debug_utils import dynamo_minifier_backend

    assert dynamo_minifier_backend is not None


@functools.lru_cache(None)
def _lazy_import_entry_point(backend_name: str):
    from importlib.metadata import entry_points

    compiler_fn = None
    group_name = "torch_dynamo_backends"
    if sys.version_info < (3, 10):
        # Before Python 3.10, entry_points() returns a dict-like object keyed
        # by group name and does not accept a `group=` keyword argument.
        backend_eps = entry_points()
        eps = [ep for ep in backend_eps[group_name] if ep.name == backend_name]
        if len(eps) > 0:
            compiler_fn = eps[0].load()
    else:
        backend_eps = entry_points(group=group_name)
        if backend_name in backend_eps.names:
            compiler_fn = backend_eps[backend_name].load()

    if compiler_fn is not None and backend_name not in list_backends(tuple()):
        register_backend(compiler_fn=compiler_fn, name=backend_name)

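# Third-party packages can make a backend discoverable by name by declaring it
# under the "torch_dynamo_backends" entry-point group looked up above, so it is
# loaded on demand without being imported eagerly. A minimal pyproject.toml
# sketch (package and module names are hypothetical):
#
#     [project.entry-points."torch_dynamo_backends"]
#     my_compiler = "my_package.backend:my_compiler"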