_mode_utils.py 381 B

1234567891011121314151617
  1. import torch
  2. from typing import TypeVar
  3. from contextlib import contextmanager
  4. T = TypeVar('T')
  5. # returns if all are the same mode
  6. def all_same_mode(modes):
  7. return all(tuple(mode == modes[0] for mode in modes))
  8. @contextmanager
  9. def no_dispatch():
  10. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
  11. try:
  12. yield
  13. finally:
  14. del guard