import torch._C
from contextlib import contextmanager
import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']


@contextmanager
def no_python_dispatcher():
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g


@contextmanager
def enable_python_dispatcher():
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g


CROSSREF_FUNCTIONALIZE = False


def all_known_overloads():
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)


@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")


class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
    else:
        return a


def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler


# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_known_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
                'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
            yield
    finally:
        for op in all_known_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
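

# Usage sketch (illustrative only): a minimal demonstration of the two
# exported context managers.  The tensor shapes and ops below are arbitrary
# examples, not anything this module requires; the __main__ guard keeps the
# module's import-time behavior unchanged.  The cross-ref checker above is
# driven the same way, by wrapping a functionalized call in
# enable_crossref_functionalize() (see the NB above about its cost).
if __name__ == "__main__":
    x = torch.randn(2, 3)

    # While the Python dispatcher is enabled, ops with Python-registered
    # kernels or decompositions for a dispatch key are routed through them.
    with enable_python_dispatcher():
        y = x + 1

    # Temporarily disable the Python dispatcher so ops take the default
    # C++ dispatch path again.
    with no_python_dispatcher():
        z = y * 2

    print(y.sum().item(), z.sum().item())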