from typing import Any
import warnings
import sys
from functools import lru_cache as _lru_cache
from contextlib import contextmanager

from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None

@_lru_cache()
def is_available() -> bool:
    r"""Returns a bool indicating if opt_einsum is currently available."""
    return _opt_einsum is not None


def get_opt_einsum() -> Any:
    r"""Returns the opt_einsum package if opt_einsum is currently available, else None."""
    return _opt_einsum
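
# Illustrative usage of the two helpers above (a sketch, not part of this
# module): callers can guard direct use of the underlying package on
# availability. `a` and `b` below are hypothetical tensors.
#
#   if torch.backends.opt_einsum.is_available():
#       oe = torch.backends.opt_einsum.get_opt_einsum()
#       path, path_info = oe.contract_path("ij,jk->ik", a, b)
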

def _set_enabled(_enabled: bool) -> None:
    if not is_available() and _enabled:
        raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
                         'the benefits of calculating an optimal path for einsum. torch.einsum will '
                         'fall back to contracting from left to right. To enable this optimal path '
                         'calculation, please install opt-einsum.')
    global enabled
    enabled = _enabled


def _get_enabled() -> bool:
    return enabled

def _set_strategy(_strategy: str) -> None:
    if not is_available():
        raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please install opt-einsum or unset `strategy`.')
    if not enabled:
        raise ValueError(f'opt_einsum is not enabled, so setting `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please set `enabled` to `True` as well or unset `strategy`.')
    if _strategy not in ['auto', 'greedy', 'optimal']:
        raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}')
    global strategy
    strategy = _strategy


def _get_strategy() -> str:
    return strategy

def set_flags(_enabled=None, _strategy=None):
    orig_flags = (enabled, None if not is_available() else strategy)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags

@contextmanager
def flags(enabled=None, strategy=None):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, strategy)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)
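
# Illustrative usage of the context manager above (a sketch, assuming opt-einsum
# is installed): the flags are restored to their previous values on exit, even
# if the body raises. `a`, `b`, and `c` are hypothetical tensors.
#
#   with torch.backends.opt_einsum.flags(enabled=True, strategy='greedy'):
#       out = torch.einsum('ij,jk,kl->il', a, b, c)
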

# The magic here is to allow us to intercept code like this:
#
#   torch.backends.opt_einsum.enabled = True

class OptEinsumModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

enabled = True if is_available() else False
strategy = 'auto' if is_available() else None
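
# Illustrative only: with the module replaced above, downstream code can toggle
# these flags as plain attributes, e.g.
#
#   torch.backends.opt_einsum.enabled = True
#   torch.backends.opt_einsum.strategy = 'optimal'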