from typing import Any
import warnings
import sys
from functools import lru_cache as _lru_cache
from contextlib import contextmanager

from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

try:
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None

@_lru_cache()
def is_available() -> bool:
    r"""Returns a bool indicating if opt_einsum is currently available."""
    return _opt_einsum is not None


def get_opt_einsum() -> Any:
    r"""Returns the opt_einsum package if opt_einsum is currently available, else None."""
    return _opt_einsum

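# A minimal usage sketch (assuming torch is installed, which exposes this
# module as torch.backends.opt_einsum):
#
#     import torch
#
#     if torch.backends.opt_einsum.is_available():
#         pkg = torch.backends.opt_einsum.get_opt_einsum()
#         print(pkg.__version__)  # the underlying opt_einsum package
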
def _set_enabled(_enabled: bool) -> None:
    if not is_available() and _enabled:
        raise ValueError(f'opt_einsum is not available, so setting `enabled` to {_enabled} will not reap '
                         'the benefits of calculating an optimal path for einsum. torch.einsum will '
                         'fall back to contracting from left to right. To enable this optimal path '
                         'calculation, please install opt-einsum.')
    global enabled
    enabled = _enabled


def _get_enabled() -> bool:
    return enabled

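# _set_enabled/_get_enabled are not meant to be called directly by user code;
# they are wired into a ContextProp below so that plain attribute assignment
# goes through the validation above. A minimal sketch, assuming torch is
# installed:
#
#     import torch
#
#     torch.backends.opt_einsum.enabled = False  # contract strictly left to right
#     assert torch.backends.opt_einsum.enabled is False
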
def _set_strategy(_strategy: str) -> None:
    if not is_available():
        raise ValueError(f'opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please install opt_einsum or unset `strategy`.')
    if not enabled:
        raise ValueError(f'opt_einsum is not enabled, so setting `strategy` to {_strategy} will not be meaningful. '
                         'torch.einsum will bypass path calculation and simply contract from left to right. '
                         'Please set `enabled` to `True` as well or unset `strategy`.')
    if _strategy not in ['auto', 'greedy', 'optimal']:
        raise ValueError(f'`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}')
    global strategy
    strategy = _strategy


def _get_strategy() -> str:
    return strategy

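# The strategy names mirror opt_einsum's `optimize` modes: 'auto' defers the
# choice to the backend, 'greedy' picks contractions one pair at a time, and
# 'optimal' searches over contraction orders (much more expensive as the
# number of operands grows). A minimal sketch, assuming opt_einsum is
# installed and the backend is enabled:
#
#     import torch
#
#     torch.backends.opt_einsum.strategy = "optimal"
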
def set_flags(_enabled=None, _strategy=None):
    orig_flags = (enabled, None if not is_available() else strategy)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags

@contextmanager
def flags(enabled=None, strategy=None):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, strategy)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)

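# A minimal sketch of scoped flag changes, assuming opt_einsum is installed;
# the saved values are restored on exit even if the body raises:
#
#     import torch
#
#     a, b, c = (torch.rand(8, 8) for _ in range(3))
#     with torch.backends.opt_einsum.flags(enabled=True, strategy="greedy"):
#         out = torch.einsum("ij,jk,kl->il", a, b, c)
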
# The magic here is to allow us to intercept code like this:
#
#     torch.backends.opt_einsum.enabled = True
class OptEinsumModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)
    global strategy
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick; see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

enabled = is_available()
strategy = 'auto' if is_available() else None

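# A minimal standalone sketch of the same sys.modules trick (hypothetical
# names, not part of this module): because properties live on the class,
# replacing the module object with a ModuleType subclass lets plain attribute
# access run arbitrary getter/setter code.
#
#     import sys
#     import types
#
#     class _InterceptedModule(types.ModuleType):
#         @property
#         def verbose(self) -> bool:
#             return self.__dict__.get("_verbose", False)
#
#         @verbose.setter
#         def verbose(self, value: bool) -> None:
#             # write to the instance __dict__ to avoid recursing into the property
#             self.__dict__["_verbose"] = bool(value)
#
#     sys.modules[__name__] = _InterceptedModule(__name__)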