# registry.py

import functools
import sys
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

import torch
from torch import fx


class CompiledFn(Protocol):
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...


CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

_BACKENDS: Dict[str, CompilerFn] = dict()


def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _BACKENDS, f"duplicate name: {name}"
    _BACKENDS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)
    return compiler_fn
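# Illustrative usage (a sketch, not part of the original module; "my_backend"
# is a made-up name):
#
#     @register_backend
#     def my_backend(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
#         # Inspect or transform the captured FX graph, then return a callable
#         # that runs it; returning gm.forward simply falls back to eager.
#         return gm.forward
#
#     torch.compile(model, backend="my_backend")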


register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)
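# These partials pre-fill ``tags``, so a backend registered with, say,
# @register_debug_backend (hypothetical example below) is tagged "debug" and
# therefore hidden from list_backends() by default:
#
#     @register_debug_backend
#     def my_debug_backend(gm, example_inputs):
#         ...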


def lookup_backend(compiler_fn):
    """Expand backend strings to functions"""
    if isinstance(compiler_fn, str):
        if compiler_fn not in _BACKENDS:
            _lazy_import()
        if compiler_fn not in _BACKENDS:
            _lazy_import_entry_point(compiler_fn)
        compiler_fn = _BACKENDS[compiler_fn]
    return compiler_fn
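# For example (a sketch; "inductor" happens to be one of the names registered
# by the modules imported in _lazy_import(), while callables pass through):
#
#     backend_fn = lookup_backend("inductor")   # str -> registered CompilerFn
#     backend_fn = lookup_backend(my_compiler)  # non-strings are returned as-is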


def list_backends(exclude_tags=("debug", "experimental")):
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")
    """
    _lazy_import()
    exclude_tags = set(exclude_tags or ())
    return sorted(
        [
            name
            for name, backend in _BACKENDS.items()
            if not exclude_tags.intersection(backend._tags)
        ]
    )
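# For example (illustrative; the exact names depend on the installed PyTorch
# build and on any third-party backends that are present):
#
#     list_backends()                   # e.g. ["cudagraphs", "inductor", ...]
#     list_backends(exclude_tags=None)  # include "debug"/"experimental" backends too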


@functools.lru_cache(None)
def _lazy_import():
    # Import every submodule of the backends package so that their
    # register_backend() decorators run and populate _BACKENDS.
    from .. import backends
    from ..utils import import_submodule

    import_submodule(backends)

    # Importing the minifier backend registers it as well.
    from ..debug_utils import dynamo_minifier_backend

    assert dynamo_minifier_backend is not None


@functools.lru_cache(None)
def _lazy_import_entry_point(backend_name: str):
    # Look up backends exposed by third-party packages through the
    # "torch_dynamo_backends" entry-point group.
    from importlib.metadata import entry_points

    compiler_fn = None
    group_name = "torch_dynamo_backends"
    if sys.version_info < (3, 10):
        # Before Python 3.10, entry_points() takes no arguments and returns
        # a dict-like mapping of group name -> list of entry points.
        backend_eps = entry_points()
        eps = [ep for ep in backend_eps[group_name] if ep.name == backend_name]
        if len(eps) > 0:
            compiler_fn = eps[0].load()
    else:
        backend_eps = entry_points(group=group_name)
        if backend_name in backend_eps.names:
            compiler_fn = backend_eps[backend_name].load()

    if compiler_fn is not None and backend_name not in list_backends(tuple()):
        register_backend(compiler_fn=compiler_fn, name=backend_name)
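

# A third-party package can expose a backend through the "torch_dynamo_backends"
# entry-point group so that _lazy_import_entry_point() can discover it by name.
# A packaging sketch (hypothetical package and module names):
#
#     # pyproject.toml
#     [project.entry-points.torch_dynamo_backends]
#     my_backend = "my_package.compiler:my_backend"
#
# Once installed, torch.compile(model, backend="my_backend") resolves the string
# via lookup_backend() -> _lazy_import_entry_point("my_backend").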