import torch._C
from contextlib import contextmanager
import unittest.mock
import torch
import torch.utils._pytree as pytree
import torch._prims_common  # accessed as an attribute in check_tensor_metadata_matches below
import itertools

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']


@contextmanager
def no_python_dispatcher():
    # The C++ guard keeps the Python dispatcher disabled until it is destroyed
    # in the finally block.
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g


@contextmanager
def enable_python_dispatcher():
    # Mirror image: the Python dispatcher stays enabled for the duration of the
    # with-block.
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g
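
# Illustrative usage (sketch, not exercised by this module; `x` and `y` stand for
# arbitrary tensors and aten.add.Tensor is just an example overload):
#
#     with enable_python_dispatcher():
#         out = torch.ops.aten.add.Tensor(x, y)   # dispatch goes through the Python dispatcher
#
#     with no_python_dispatcher():
#         out = torch.ops.aten.add.Tensor(x, y)   # Python dispatcher temporarily disabled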


CROSSREF_FUNCTIONALIZE = False


def all_known_overloads():
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
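
# Consuming the generator is straightforward but slow, since it walks every
# registered overload; a hypothetical sketch:
#
#     n_overloads = sum(1 for _ in all_known_overloads())
#     print(f"{n_overloads} overloads registered")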


@contextmanager
def suspend_functionalization():
    # If functionalization is currently enabled in TLS, turn it off for the body
    # of the with-block and restore it (with the original reapply_views setting)
    # on exit.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
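
# Illustrative sketch (hypothetical caller, not part of this file): code running
# under functionalization can drop back to ordinary tensors for a region, e.g.
#
#     with suspend_functionalization():
#         plain = torch.empty(4)   # created without functionalization active
#
# and functionalization is re-enabled exactly as it was found on exit.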


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
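
# Intended usage sketch (hypothetical values): `n` and `r` are two pytrees of
# outputs for the same call, e.g. one computed under FakeTensorMode and one by
# the real kernel; only tensor metadata (sizes, dtype, significant strides) is
# compared:
#
#     check_metadata_matches(fake_out, real_out, lambda: "aten.add.Tensor(...)")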


class Lit:
    # Thin wrapper whose repr is the raw string, so formatted argument dumps in
    # desc() below print as code rather than as quoted strings.
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
    else:
        return a
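
# For instance (illustrative): for a contiguous float32 tensor `t` of shape (2, 3),
#
#     repr(pytree.tree_map(_fmt, (t,)))
#
# yields "(torch.empty_strided((2, 3), (3, 1), dtype=torch.float32),)", which keeps
# crossref failure messages compact and copy-pasteable.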


# Build a replacement handler for `op` that cross-checks the real kernel
# (dispatched via `final_key`) against a reference run of the op on fake tensors,
# comparing only output metadata via check_metadata_matches.
def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # Reference run: fakeify the (possibly functional) inputs and run the op
        # under FakeTensorMode with functionalization suspended.
        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        # Real run: dispatch to the original kernel for final_key.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
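
# Rough hand-off sketch (hypothetical, simplified; the real wiring lives in the
# Python dispatcher outside this file): when CROSSREF_FUNCTIONALIZE is patched to
# True, the Functionalize entry for an overload is resolved to the checking
# handler instead of the plain kernel, roughly:
#
#     handler = make_crossref_functionalize(
#         torch.ops.aten.add.Tensor, torch._C.DispatchKey.Functionalize)
#     out = handler(x, y)   # runs the real kernel, cross-checks fake-tensor metadata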


# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    # Flush the cached Functionalize kernels so the Python dispatcher re-resolves
    # them with crossref checking on; flush again on exit so the checking
    # handlers don't linger in the cache.
    for op in all_known_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True
        ):
            yield
    finally:
        for op in all_known_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
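
# Typical debugging usage (sketch; `f` and `x` are placeholders for the code under
# test, and torch.func.functionalize is just one way to route calls through the
# Functionalize key):
#
#     with enable_crossref_functionalize():
#         out = torch.func.functionalize(f)(x)   # metadata mismatches raise AssertionError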