# torch/_dispatch/python.py

import torch._C
from contextlib import contextmanager
import unittest.mock
import torch
import torch.utils._pytree as pytree
import itertools

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']

@contextmanager
def no_python_dispatcher():
    # RAII guard: constructing it disables the Python dispatcher for this
    # thread; deleting it restores the previous state.
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g

@contextmanager
def enable_python_dispatcher():
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g
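
# Usage sketch (an illustrative addition, not from the original file): the
# guards nest, and each restores the previous dispatcher state on exit.
#
#   with enable_python_dispatcher():
#       # ops here are routed through the Python dispatcher
#       with no_python_dispatcher():
#           pass  # temporarily back on the C++ fast path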

# Consulted by the Python dispatcher when resolving Functionalize kernels;
# flipped to True via unittest.mock.patch in enable_crossref_functionalize
# below.
CROSSREF_FUNCTIONALIZE = False

def all_known_overloads():
    # Yield every OpOverload currently registered under torch.ops
    # (e.g. torch.ops.aten.add.Tensor), across all namespaces.
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
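
# Illustrative use (not part of the original file): enumerate the registered
# overloads, e.g. to count them or locate a specific one.
#
#   n_overloads = sum(1 for _ in all_known_overloads())
#   assert torch.ops.aten.add.Tensor in list(all_known_overloads())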

@contextmanager
def suspend_functionalization():
    # Snapshot the functionalization TLS state, turn functionalization off for
    # the duration of the block, then restore it on exit.
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize)
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
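
# Illustrative nesting (a sketch only; these are private APIs that may
# change between releases):
#
#   torch._enable_functionalization(reapply_views=True)
#   try:
#       with suspend_functionalization():
#           ...  # plain, non-functionalized dispatch in here
#   finally:
#       torch._disable_functionalization()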

def check_tensor_metadata_matches(nv, rv, desc):
    import torch._prims_common  # local import, mirroring the lazy imports used elsewhere in this file

    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False)
    assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"

def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
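
# Example (an illustrative addition): metadata-equal tensors pass, mismatched
# sizes trip the assert with the desc() prefix.
#
#   a = torch.empty_strided((2, 3), (3, 1))
#   b = torch.empty_strided((2, 3), (3, 1))
#   check_metadata_matches([a], [b], lambda: "demo")  # passes
#   c = torch.empty_strided((3, 2), (2, 1))
#   check_metadata_matches([a], [c], lambda: "demo")
#   # AssertionError: demo output 0: sizes torch.Size([2, 3]) != torch.Size([3, 2])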

class Lit:
    # A string that repr()s as itself, so formatted tensor constructors print
    # without surrounding quotes.
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s

def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})")
    else:
        return a
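
# _fmt renders tensors as copy-pasteable constructor calls for error messages,
# e.g. (illustrative):
#
#   >>> pytree.tree_map(_fmt, torch.ones(2, 2))
#   torch.empty_strided((2, 2), (2, 1), dtype=torch.float32)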

def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily
                    # have to be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # Run the op twice: once on fake tensors with functionalization
        # suspended (the reference run), and once for real via the final
        # dispatch key; then compare the output metadata.
        with suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs))
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
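
# When the metadata check fails, desc() renders the offending call with _fmt
# so the failure is reproducible from the message alone, roughly like
# (illustrative reconstruction):
#
#   aten.add.Tensor(torch.empty_strided((2, 3), (3, 1), dtype=torch.float32))
#   output 0: sizes torch.Size([2, 3]) != torch.Size([3, 2])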

# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    # Evict any cached Functionalize kernels so they are re-resolved with
    # CROSSREF_FUNCTIONALIZE patched to True, then evict again on exit to
    # drop the crossref handlers.
    for op in all_known_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
                'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True):
            yield
    finally:
        for op in all_known_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
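
# A small smoke test (an illustrative addition, not part of the upstream
# module; it relies only on the helpers defined above).
if __name__ == "__main__":
    x = torch.randn(2, 3)
    with enable_python_dispatcher():
        y = torch.ops.aten.mul.Tensor(x, x)  # via the Python dispatcher
        with no_python_dispatcher():
            z = torch.ops.aten.add.Tensor(x, x)  # via the C++ dispatcher
    # Same sizes/dtypes/strides, so the metadata check should pass.
    check_metadata_matches(y, z, lambda: "smoke test")
    print("ok:", y.shape, z.shape)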