import torch
from contextlib import contextmanager
from typing import Iterator, Sequence, TypeVar

# Module-level TypeVar; not referenced in this file, presumably imported by
# other modules — verify before removing.
T = TypeVar('T')


def all_same_mode(modes: Sequence[object]) -> bool:
    """Return whether every element of ``modes`` compares equal to the first.

    Each element is compared with ``==`` against ``modes[0]``, so ``modes``
    must support indexing. An empty sequence yields ``True`` (``modes[0]`` is
    never evaluated because the loop body never runs).
    """
    # Pass the generator straight to ``all`` (no ``tuple(...)`` wrapper) so it
    # stops at the first mismatch instead of evaluating every comparison first.
    return all(mode == modes[0] for mode in modes)


@contextmanager
def no_dispatch() -> Iterator[None]:
    """Hold a ``torch._C._DisableTorchDispatch`` guard for the ``with`` body.

    The guard object is created on entry and its reference is dropped on
    exit, including when the body raises.
    """
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield
    finally:
        # NOTE(review): presumably the guard restores dispatch when it is
        # destroyed; ``del`` drops the local reference so that happens at
        # exit rather than at some later collection — confirm against the
        # C++ binding.
        del guard