1234567891011121314151617 |
- import torch
- from typing import TypeVar
- from contextlib import contextmanager
# Module-level generic type variable (not referenced in the visible portion
# of this file; kept for annotations elsewhere in the module).
T = TypeVar("T")
def all_same_mode(modes):
    """Return whether every element of ``modes`` compares equal to the first.

    ``modes`` must support both iteration and ``modes[0]`` indexing (e.g. a
    list or tuple). Each element is tested as ``mode == modes[0]``, so the
    first element is also compared against itself. An empty ``modes`` returns
    ``True`` because ``modes[0]`` is never evaluated in that case.

    Args:
        modes: indexable sequence of objects to compare.

    Returns:
        ``True`` if every element equals ``modes[0]`` (or ``modes`` is empty),
        otherwise ``False``.
    """
    # Hand the generator straight to all() so evaluation stops at the first
    # mismatch; wrapping it in tuple() forced every comparison to run first.
    return all(mode == modes[0] for mode in modes)
@contextmanager
def no_dispatch():
    """Hold a ``torch._C._DisableTorchDispatch`` guard for the ``with`` body.

    Usage::

        with no_dispatch():
            ...  # runs while the guard object is alive

    The guard is constructed on entry and its only local reference is deleted
    on exit, whether the body finishes normally or raises. Presumably the
    guard disables ``__torch_dispatch__`` handling while alive and restores
    the previous state when destroyed. NOTE(review): confirm against the
    guard's C++ definition.

    Yields:
        None.
    """
    # Private C binding; the ignore is needed because the type stubs do not
    # declare this attribute.
    guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
    try:
        yield
    finally:
        # Drop the reference so the guard can be destroyed right after the
        # body exits. Under CPython reference counting this happens at the
        # `del`. NOTE(review): relies on nothing else holding a reference to
        # the guard.
        del guard
|