proxy_tensor.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
  6. import contextlib
  7. import functools
  8. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  9. import torch
  10. import torch.utils._pytree as pytree
  11. from torch.fx import Tracer, GraphModule
  12. from torch._subclasses.fake_tensor import FakeTensorMode
  13. from torch._dispatch.python import enable_python_dispatcher
  14. import torch.fx as fx
  15. from torch.fx.passes.shape_prop import _extract_tensor_metadata
  16. from contextlib import contextmanager, nullcontext
  17. import inspect
  18. from dataclasses import dataclass
  19. import weakref
  20. import operator
  21. from torch.utils._stats import count
  22. from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
  23. from torch._subclasses import FakeTensor
  24. from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode
  25. from torch.fx import Proxy
  26. from torch import SymInt, SymFloat, SymBool
  27. from torch.utils.weak import WeakTensorKeyDictionary
  28. __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
  29. aten = torch.ops.aten
  30. prim = torch.ops.prim
  31. CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
  32. CONSTANT_NUMEL_LIMIT = 1
  33. # We currently convert all SymInt to proxies before we use them.
  34. # This could plausibly be handled at the Dynamo level.
  35. pytree._register_pytree_node(torch.Size, lambda x: (list(x), None), lambda xs, _: tuple(xs))
  36. def fake_signature(fn, nargs):
  37. """FX gets confused by varargs, de-confuse it"""
  38. argnames = ",".join(f"arg{i}" for i in range(nargs))
  39. return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
  40. @contextmanager
  41. def decompose(decomposition_table):
  42. global CURRENT_DECOMPOSITION_TABLE
  43. old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
  44. CURRENT_DECOMPOSITION_TABLE = decomposition_table
  45. try:
  46. yield CURRENT_DECOMPOSITION_TABLE
  47. finally:
  48. CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
  49. # ensure we cannot collide with other properties
  50. proxy_slot = object()
  51. no_default = object()
  52. py_sym_types = (SymInt, SymFloat, SymBool)
  53. def is_sym_node(node):
  54. assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta"
  55. return "val" in node.meta and isinstance(node.meta['val'], py_sym_types)
  56. def set_proxy_slot(obj, tracer, proxy):
  57. if isinstance(obj, torch.Tensor):
  58. # We DO want to clobber proxies whenever we run an inplace operation
  59. # on a tensor, and it affects the metadata on the proxy.
  60. tracer.tensor_tracker[obj] = proxy
  61. else:
  62. # NB: Never clobber pre-existing proxy. Although the proxies
  63. # are in principle equivalent, when we do graph partitioning
  64. # we need there not to be spurious dependencies on tangent inputs.
  65. # This works because primals get their SymInts set first, and
  66. # THEN later we allocate tangent inputs. Make sure if a SymInt
  67. # is derivable from a primal that we use that.
  68. assert isinstance(obj, SymNode), type(obj)
  69. if obj not in tracer.symnode_tracker:
  70. tracer.symnode_tracker[obj] = proxy
  71. def has_proxy_slot(obj, tracer):
  72. assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
  73. return get_proxy_slot(obj, tracer, False, lambda _: True)
  74. # the default argument is what to return if the slot is not set.
  75. # the transform argument is handy if you need to extract a subfield from
  76. # the successfully looked up result (but NOT the default.)
  77. def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
  78. if isinstance(obj, torch.Tensor):
  79. tracker = tracer.tensor_tracker
  80. else:
  81. assert isinstance(obj, SymNode), type(obj)
  82. tracker = tracer.symnode_tracker
  83. if obj not in tracker:
  84. if default is no_default:
  85. raise RuntimeError(f"{obj} is not tracked with proxy for {tracer}")
  86. return default
  87. return transform(tracker[obj])
  88. def snapshot_fake(val):
  89. return val.detach()
  90. def unwrap_proxy(proxy_mode, e):
  91. if isinstance(e, torch.Tensor):
  92. return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)
  93. elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
  94. return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e())
  95. else:
  96. return e
  97. # What invariants do we have for the 'val' set on the FX node? It has accurate
  98. # metadata... but only for metadata that exists "below" all other subsystems
  99. # (most notably autograd, but also vmap, functorch transforms, etc). This means
  100. # you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad,
  101. # grad_fn, _base (_base actually may be set due to recursive call to
  102. # ADInplaceOrView, but you shouldn't rely on it.)
  103. def set_meta(proxy, val):
  104. if isinstance(val, FakeTensor):
  105. proxy.node.meta['val'] = snapshot_fake(val)
  106. proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
  107. elif isinstance(val, py_sym_types):
  108. proxy.node.meta['val'] = val
  109. elif isinstance(val, (list, tuple)):
  110. if all(isinstance(x, FakeTensor) for x in val):
  111. proxy.node.meta['val'] = [snapshot_fake(x) for x in val]
  112. elif isinstance(val, torch.Tensor):
  113. if not val.is_sparse:
  114. proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
  115. # NB: Kinda hacky, but we should try to get val as the metadata
  116. # everywhere
  117. # TODO: This doesn't properly track storages. A more robust
  118. # approach would be to maintain a per-trace FakeTensorMode and
  119. # from_real_tensor to create fake values (don't forget to
  120. # snapshot_fake)
  121. fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
  122. with fake_tensor_mode:
  123. proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
  124. return proxy
  125. def thunkify(f, *args, **kwargs):
  126. """
  127. Delays computation of f until it's called again
  128. Also caches the result
  129. """
  130. return functools.lru_cache(1)(functools.partial(f, *args, **kwargs))
  131. def track_tensor(tensor, proxy, *, constant, tracer):
  132. def try_set_proxy_slot(outer_s, proxy_callable, *args):
  133. assert callable(proxy_callable)
  134. if isinstance(outer_s, SymInt):
  135. inner_s = outer_s.node
  136. set_proxy_slot(inner_s, tracer, thunkify(proxy_callable, outer_s, *args))
  137. # The basic idea is that we need to associate each tensor/SymInt
  138. # with a Proxy. How do we setup this association? We just store
  139. # the proxy on the proxy slot of the object, keyed on the tracer
  140. # (so that if we have multiple tracers at the same time, they
  141. # don't clobber each other.)
  142. for i, s in enumerate(tensor.shape):
  143. try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_size(proxy, i), x), i)
  144. for i, s in enumerate(tensor.stride()):
  145. try_set_proxy_slot(s, lambda x, i: set_meta(torch.ops.aten.sym_stride(proxy, i), x), i)
  146. try_set_proxy_slot(tensor.numel(), lambda x: set_meta(torch.ops.aten.sym_numel(proxy), x))
  147. try_set_proxy_slot(tensor.storage_offset(), lambda x: set_meta(torch.ops.aten.sym_storage_offset(proxy), x))
  148. set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
  149. def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
  150. def wrap_with_proxy(e, proxy, constant):
  151. if isinstance(e, torch.Tensor):
  152. track_tensor(e, proxy, tracer=tracer, constant=constant)
  153. set_meta(proxy, e)
  154. elif isinstance(e, py_sym_types):
  155. # NB: eagerly set meta here, so that the numbering is in order
  156. set_meta(proxy, e)
  157. set_proxy_slot(e.node, tracer, lambda: proxy)
  158. elif isinstance(e, list):
  159. # example use case: allreduce_ returns ([tensor], work)
  160. for idx, ee in enumerate(e):
  161. wrap_with_proxy(ee, proxy[idx], get_constant(idx))
  162. def get_constant(idx):
  163. if constant is None:
  164. return None
  165. else:
  166. return constant[idx]
  167. # Unfortunately, tree_map cannot directly be used here. As the resulting
  168. # object may be a proxy that represents a tuple, we may need to
  169. # explicitly unwrap the proxy by simulating the flattening operations.
  170. if isinstance(inner_res, (tuple, list)):
  171. if isinstance(proxy_res, fx.Proxy):
  172. set_meta(proxy_res, inner_res)
  173. for idx, e in enumerate(inner_res):
  174. wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
  175. elif isinstance(inner_res, py_sym_types + (torch.Tensor,)):
  176. wrap_with_proxy(inner_res, proxy_res, constant)
  177. return inner_res
  178. def maybe_disable_fake_tensor_mode():
  179. # TODO: figure out if this API generally makes sense and bake it into the
  180. # library
  181. mb_fake_mode = _get_current_dispatch_mode()
  182. if isinstance(mb_fake_mode, FakeTensorMode):
  183. return _pop_mode_temporarily()
  184. else:
  185. return nullcontext()
  186. @dataclass
  187. class _ProxyTensor:
  188. proxy: Proxy
  189. constant: Optional[torch.Tensor]
  190. def fetch_sym_proxy(tracer):
  191. def inner(e):
  192. n = e.node
  193. if n.constant is not None:
  194. return n.constant
  195. else:
  196. # NB: we REQUIRE all symints to be tracked
  197. return get_proxy_slot(n, tracer)()
  198. return inner
  199. def fetch_tensor_proxy(tracer):
  200. return lambda t: get_proxy_slot(t, tracer, t)
  201. HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter)
  202. def proxy_call(proxy_mode, func, args, kwargs):
  203. def can_handle_tensor(x):
  204. return type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
  205. # If there are any tensor subclasses, we need to handle those tensor subclasses first
  206. # TODO: we could use types to test this
  207. if not pytree.tree_all_only(torch.Tensor, can_handle_tensor, (args, kwargs)):
  208. return NotImplemented
  209. if func in CURRENT_DECOMPOSITION_TABLE:
  210. with proxy_mode:
  211. r = CURRENT_DECOMPOSITION_TABLE[func](*args, **kwargs)
  212. if r is not NotImplemented:
  213. return r
  214. with proxy_mode:
  215. r = func.decompose(*args, **kwargs)
  216. if r is not NotImplemented:
  217. return r
  218. tracer = proxy_mode.tracer
  219. f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))
  220. # If there are SymInts, we also should not consider this constant.
  221. # However, fake tensor handling of SymInts is sufficiently broken that
  222. # I couldn't write a test for this case
  223. all_constant = (
  224. pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
  225. # TODO: maybe constant SymInts should also be allowed? Not sure if
  226. # this can happen
  227. and pytree.tree_all_only((SymInt, SymFloat, SymBool), lambda _: False, (args, kwargs))
  228. )
  229. if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined]
  230. # Check if all of the Tensor inputs are constants
  231. if all_constant:
  232. const_args, const_kwargs = pytree.tree_map_only(
  233. _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
  234. )
  235. with maybe_disable_fake_tensor_mode():
  236. return func(*const_args, **const_kwargs)
  237. # If any of the Tensor inputs are "real" (not FakeTensor), we may
  238. # incorrectly burn in constants by allowing this access. Raise
  239. # an error in this case
  240. if pytree.tree_all_only(torch.Tensor, lambda t: not isinstance(t, FakeTensor), (args, kwargs)):
  241. raise RuntimeError(
  242. f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
  243. "It's likely that this is caused by data-dependent control flow or similar. "
  244. "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
  245. "in your make_fx call."
  246. )
  247. proxy_args, proxy_kwargs = pytree.tree_map_only(
  248. (SymInt, SymFloat, SymBool),
  249. fetch_sym_proxy(proxy_mode.tracer),
  250. pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
  251. )
  252. # When we trace through a torch.tensor invocation, you never actually
  253. # see a torch.ops.aten.tensor call. Instead, the way this function is
  254. # implemented internally is that we allocate a plain tensor (this is
  255. # *guaranteed* to be a plain tensor, we disable all modes when doing
  256. # so), and then call at::lift_fresh on it (to give modes a chance to do
  257. # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
  258. # to be freshly allocated, so we want lift_fresh to be a no-op (directly
  259. # returning the input argument).
  260. #
  261. # Here is the basic problem: when we trace this sequence of executions
  262. # into an FX graph, what happens to this call sequence? Traditionally,
  263. # tensor constants get interned as buffers on the FX GraphModule. But
  264. # this is dangerous. Consider:
  265. #
  266. # x = torch.tensor(1)
  267. # x.add_(2)
  268. #
  269. # Naively, this traces into:
  270. #
  271. # t = self._tensor_constant0 # initialized to torch.tensor(1)
  272. # x = torch.ops.aten.lift_fresh(t)
  273. # x.add_(2)
  274. #
  275. # If lift_fresh returns t directly, the subsequent add_ call will
  276. # modify the tensor constant. Really, the problem is we've violated
  277. # the invariant the the argument to lift is fresh. So what we should
  278. # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
  279. #
  280. # t = self._tensor_constant0 # initialized to torch.tensor(1)
  281. # x = torch.ops.aten.lift_fresh_copy(t)
  282. # x.add_(2)
  283. #
  284. # This is what the overload modification does.
  285. if func is torch.ops.aten.lift_fresh.default:
  286. func = torch.ops.aten.lift_fresh_copy.default
  287. proxy_out = proxy_mode.tracer.create_proxy('call_function', func, proxy_args, proxy_kwargs,
  288. name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__))
  289. # This makes DCE marginally less likely to DCE inplace operations.
  290. # It is not strictly necessary
  291. # Kind of a hacky way to test if an op is in-place or not
  292. if func.overloadpacket.__name__[-1] == "_" and func.overloadpacket.__name__[0] != "_":
  293. if isinstance(args[0], List):
  294. # e.g., c10d::allreduce_ returns a list of tensors as the first element
  295. # in the output.
  296. for i, a in enumerate(args[0]):
  297. a.proxy = proxy_out[0][i]
  298. else:
  299. args[0].proxy = proxy_out
  300. out = func(*args, **kwargs)
  301. # In some circumstances, we will be tracing in a situation where a tensor
  302. # is *statically* known to be a constant (currently, this only happens if
  303. # you run torch.tensor; deterministic factory functions like torch.arange
  304. # don't get this treatment). When the tensor in question is small, it's
  305. # helpful to due constant propagation in case we call item() (in which
  306. # case we can return the constant value that is known, rather than give
  307. # an error.) The logic here tests if constant propagation is possible
  308. # (because all of the inputs are constant). If so, we disable fake tensor
  309. # mode (if it is on) and do true compute on the constant.
  310. #
  311. # It's worth highlighting that we're making a policy decision here.
  312. # There is a potential that the tensor is actually quite large, and we
  313. # don't actually want to run the compute. The tensor being quite large
  314. # is one of the reasons why factory functions don't get this treatment
  315. # (since they can be quite large; if a parameter is initialized to a
  316. # constant value it will be!) Similarly, there is also a potential
  317. # to run an operator that blows up the size of a small tensor; we don't
  318. # protect against this case, but we could force, e.g., only single
  319. # element constant computation by testing the numel of the result before
  320. # propagating const-ness. Similarly, we don't require the constant to
  321. # live on CPU, but we could.
  322. any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
  323. constant = None
  324. # If this is a lift, the input tensor is guaranteed to be a
  325. # constant, so we keep a copy of the original argument along so
  326. # we can query it if we're asked to item() it at some later point
  327. if func is torch.ops.aten.lift_fresh_copy.default and out.numel() <= CONSTANT_NUMEL_LIMIT:
  328. with maybe_disable_fake_tensor_mode():
  329. constant = args[0].clone()
  330. elif (
  331. torch.Tag.nondeterministic_seeded not in func.tags # type: ignore[attr-defined]
  332. and all_constant
  333. and any_constant
  334. and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)
  335. ):
  336. # NB: do NOT include factories as constants
  337. with maybe_disable_fake_tensor_mode():
  338. const_args, const_kwargs = pytree.tree_map_only(
  339. _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
  340. )
  341. constant = func(*const_args, **const_kwargs)
  342. else:
  343. constant = None
  344. track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
  345. return out
  346. class PythonKeyTracer(Tracer):
  347. def __init__(self):
  348. super().__init__(autowrap_modules=())
  349. self.tensor_tracker = WeakTensorKeyDictionary()
  350. self.symnode_tracker = weakref.WeakKeyDictionary() # type: ignore[var-annotated]
  351. # In general, we don't want to make modules leaves. In principle, users of
  352. # this tracer might want to override this in order to turn a couple specific
  353. # modules into leaves in the traced graph.
  354. def call_module(
  355. self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
  356. ) -> Any:
  357. return forward(*args, **kwargs)
  358. # We don't want to turn getattr calls into proxies. So we just return the actual value.
  359. def getattr(self, attr, attr_val, parameter_proxy_cache):
  360. return attr_val
  361. def create_arg(self, a: Any):
  362. if isinstance(a, torch.nn.Parameter):
  363. for n, p in self.root.named_parameters():
  364. if a is p:
  365. return self.create_node('get_attr', n, (), {})
  366. qualname: Optional[str] = None
  367. if not qualname:
  368. i = 0
  369. while True:
  370. qualname = f'_param_constant{i}'
  371. if not hasattr(self.root, qualname):
  372. break
  373. i += 1
  374. setattr(self.root, qualname, a)
  375. return self.create_node('get_attr', qualname, (), {})
  376. elif isinstance(a, (SymInt, SymFloat, SymBool)):
  377. assert a.node.constant is not None
  378. return a.node.constant
  379. return super().create_arg(a)
  380. def dispatch_trace(
  381. root: Union[torch.nn.Module, Callable],
  382. tracer: Tracer,
  383. concrete_args: Optional[Tuple[Any, ...]] = None,
  384. ) -> GraphModule:
  385. graph = tracer.trace(root, concrete_args)
  386. name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  387. return GraphModule(tracer.root, graph, name)
  388. def wrap_key(f, tensors, tracer):
  389. flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
  390. @functools.wraps(f)
  391. def wrapped(*proxies):
  392. flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
  393. assert len(flat_proxies) == len(flat_tensors)
  394. assert isinstance(_get_current_dispatch_mode(), ProxyTorchDispatchMode)
  395. with _pop_mode_temporarily():
  396. track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
  397. out = f(*tensors)
  398. out = pytree.tree_map_only(
  399. torch.Tensor,
  400. lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),
  401. out
  402. )
  403. out = pytree.tree_map_only(
  404. (SymInt, SymFloat, SymBool),
  405. lambda t: get_proxy_slot(t.node, tracer)(),
  406. out
  407. )
  408. return out
  409. return wrapped
  410. class ProxyTorchDispatchMode(TorchDispatchMode):
  411. def __init__(self, tracer, tracing_mode):
  412. self.tracer = tracer
  413. self.tracing_mode = tracing_mode
  414. self.enable_tracing = True
  415. self.sym_mode = ProxySymDispatchMode(tracer)
  416. self.trace_state = {}
  417. self._managers = []
  418. @count
  419. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  420. with self.sym_mode.enable(False):
  421. return self.inner_torch_dispatch(func, types, args, kwargs)
  422. def __enter__(self):
  423. # sym mode first, then us...
  424. m = self.sym_mode.enable(True)
  425. self._managers.append(m)
  426. m.__enter__()
  427. return super().__enter__()
  428. def __exit__(self, exc_type, exc_value, traceback):
  429. m = self._managers.pop()
  430. # ...exit us first, then sym mode
  431. b = super().__exit__(exc_type, exc_value, traceback)
  432. if not b:
  433. return m.__exit__(exc_type, exc_value, traceback)
  434. else:
  435. return m.__exit__(None, None, None)
  436. def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
  437. if not self.enable_tracing:
  438. return func(*args, **kwargs)
  439. if func in [prim.device.default]:
  440. return func(*args, **kwargs)
  441. out = proxy_call(self, func, args, kwargs)
  442. return out
  443. class ProxySymDispatchMode(SymDispatchMode):
  444. def __init__(self, tracer):
  445. super().__init__()
  446. self.tracer = tracer
  447. # When false, we don't trace operations. If you do this, you MUST
  448. # call track_tensor/track_tensor_tree on all results of the operation
  449. # to ensure we can adeduately track the results
  450. self.enable_tracing = True
  451. @contextmanager
  452. def enable(self, b):
  453. old = self.enable_tracing
  454. self.enable_tracing = b
  455. try:
  456. yield
  457. finally:
  458. self.enable_tracing = old
  459. def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat, SymBool]):
  460. n_args = tuple(
  461. get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a
  462. for a in args
  463. )
  464. # func doesn't have a __torch_function__ that Proxy can interpose, so
  465. # we gotta do it manually
  466. n_out = self.tracer.create_node("call_function", func, n_args, {})
  467. p_out = fx.Proxy(n_out, self.tracer)
  468. set_meta(p_out, out)
  469. return p_out
  470. def __sym_dispatch__(self, func, types, args, kwargs):
  471. if not self.enable_tracing:
  472. return func(*args, **kwargs)
  473. # Peephole optimize multiply by one
  474. # NB: be careful not to trigger guards here!
  475. if func == operator.mul:
  476. if isinstance(args[1], int) and args[1] == 1:
  477. return args[0]
  478. elif isinstance(args[0], int) and args[0] == 1:
  479. return args[1]
  480. # For speed, we assume there are no nested data structures
  481. # (otherwise we could use tree_map)
  482. # We also assume there are no keyword arguments.
  483. assert not kwargs
  484. out = func(*args, **kwargs)
  485. # If func returned a constant, we don't need to trace; we have
  486. # determined that the result is constant (no matter if the inputs
  487. # were symbolic) and it is no longer necessary to trace the
  488. # computation. This could occur if func triggered some guards.
  489. if isinstance(out, py_sym_types):
  490. # Delays tracing out the proxies on this op until we actually need it
  491. p_out_thunk = thunkify(self._compute_proxy, func=func, args=args, out=out)
  492. set_proxy_slot(out.node, self.tracer, p_out_thunk)
  493. return out
  494. # TODO: I'm not sure what the point of this class is; you can just
  495. # make_fx through a regular Interpreter
  496. class DecompositionInterpreter(torch.fx.Interpreter):
  497. def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
  498. super().__init__(module, **kwargs)
  499. self.new_graph = new_graph
  500. self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
  501. # Blegh
  502. self.tracer.tensor_tracker = WeakTensorKeyDictionary() # type: ignore[attr-defined]
  503. self.tracer.symnode_tracker = weakref.WeakKeyDictionary() # type: ignore[attr-defined]
  504. self.decomposition_table = decomposition_table
  505. if self.decomposition_table is None:
  506. self.decomposition_table = {}
  507. self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
  508. def placeholder(self, target, args, kwargs):
  509. out = super().placeholder(target, args, kwargs)
  510. proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer)
  511. track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
  512. # TODO handle case where the first character of target is '*'
  513. return out
  514. def get_attr(self, target, args, kwargs):
  515. out = super().get_attr(target, args, kwargs)
  516. proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer)
  517. track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
  518. return out
  519. # call_function, call_method, call_module get traced automatically by the outer mode.
  520. def output(self, target, args, kwargs):
  521. out = super().output(target, args, kwargs)
  522. def unwrap(e):
  523. return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node)
  524. self.new_graph.output(pytree.tree_map(unwrap, out))
  525. return out
  526. def run(self, *args, **kwargs):
  527. # Should enter the mode at least once for being able to restore it later
  528. # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025
  529. with decompose(self.decomposition_table), self.mode:
  530. return super().run(*args, **kwargs)
  531. def wrapper_and_args_for_make_fx(func, args, kwargs):
  532. # make_fx doesn't support kwargs, so we need to do this flattening
  533. # and then unflatten the args before calling func
  534. flat_args, spec = pytree.tree_flatten((args, kwargs))
  535. def wrapped(flat_args):
  536. fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
  537. return func(*fn_args, **fn_kwargs)
  538. return wrapped, flat_args
  539. @contextmanager
  540. def disable_autocast_cache():
  541. old_value = torch.is_autocast_cache_enabled()
  542. torch.set_autocast_cache_enabled(False)
  543. try:
  544. yield
  545. finally:
  546. torch.set_autocast_cache_enabled(old_value)
  547. def make_fx(f, decomposition_table=None, tracing_mode="real", _allow_non_fake_inputs=False):
  548. assert tracing_mode in ["real", "fake", "symbolic"]
  549. if decomposition_table is None:
  550. decomposition_table = {}
  551. @functools.wraps(f)
  552. def wrapped(*args):
  553. phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined]
  554. fx_tracer = PythonKeyTracer()
  555. fake_tensor_mode: Any = nullcontext()
  556. if tracing_mode == "real":
  557. fake_tensor_mode = nullcontext()
  558. elif tracing_mode == "fake":
  559. fake_tensor_mode = FakeTensorMode(
  560. allow_fallback_kernels=True,
  561. allow_non_fake_inputs=_allow_non_fake_inputs)
  562. elif tracing_mode == "symbolic":
  563. shape_env = ShapeEnv()
  564. fake_tensor_mode = FakeTensorMode(
  565. allow_fallback_kernels=False,
  566. allow_non_fake_inputs=_allow_non_fake_inputs,
  567. shape_env=shape_env)
  568. else:
  569. raise AssertionError(f"Unexpected tracing type: {tracing_mode}")
  570. python_dispatcher_mode: Any = nullcontext()
  571. if tracing_mode == "symbolic":
  572. python_dispatcher_mode = enable_python_dispatcher()
  573. proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode)
  574. arg_count = 0
  575. def wrap_fake(x):
  576. nonlocal arg_count
  577. if isinstance(x, torch.Tensor):
  578. # TODO: it would be nice to line these up with the names
  579. # FX will choose for the placeholders, but we don't
  580. # actually know what the names will be at this point yet
  581. # NB: the Source here is actually meaningless
  582. from torch._dynamo.source import ConstantSource
  583. source = ConstantSource(f"input{arg_count}")
  584. arg_count += 1
  585. return fake_tensor_mode.from_tensor(x, source=source) # type: ignore[attr-defined]
  586. return x
  587. sym_mode = proxy_mode.sym_mode
  588. wrap_fn_map = {
  589. "real": lambda x: x,
  590. "fake": wrap_fake,
  591. "symbolic": wrap_fake,
  592. }
  593. args = pytree.tree_map(wrap_fn_map[tracing_mode], args)
  594. if not hasattr(inspect.unwrap(f), '__code__') or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS:
  595. # FX doesn't support varargs, so we gotta fake up a wrapper
  596. # TODO: Would be nice to fix this at the source...
  597. func = fake_signature(f, len(phs))
  598. else:
  599. func = f
  600. # We disable the autocast cache as the autocast cache causes type conversions on parameters to
  601. # check a cache, which introduces untracked tensors into the graph
  602. #
  603. # We also disable tracing by any other tensor proxy-based tracers except the current. The
  604. # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
  605. # thus irrelevant to any external functional trace.
  606. with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
  607. sym_mode, proxy_mode, disable_autocast_cache(), disable_proxy_modes_tracing(enable_current=True):
  608. t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  609. # TODO: kind of a bad way to do it, should maybe figure out a better way
  610. if tracing_mode == "symbolic":
  611. t.shape_env = shape_env # type: ignore[assignment]
  612. return t
  613. return wrapped
  614. def get_torch_dispatch_modes():
  615. return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
  616. def get_innermost_proxy_mode():
  617. for m in reversed(torch.utils._python_dispatch._get_current_dispatch_mode_stack()):
  618. if isinstance(m, ProxyTorchDispatchMode):
  619. return m
  620. return None
  621. @contextlib.contextmanager
  622. def disable_proxy_modes_tracing(enable_current=False):
  623. modes = get_torch_dispatch_modes()
  624. proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
  625. if enable_current:
  626. proxy_tensor_modes = proxy_tensor_modes[:-1]
  627. olds = [(m.enable_tracing, m.sym_mode.enable_tracing) for m in proxy_tensor_modes]
  628. for proxy_mode in proxy_tensor_modes:
  629. proxy_mode.enable_tracing = False
  630. proxy_mode.sym_mode.enable_tracing = False
  631. try:
  632. yield
  633. finally:
  634. for proxy_mode, (old, old_sym) in zip(proxy_tensor_modes, olds):
  635. proxy_mode.enable_tracing = old
  636. proxy_mode.sym_mode.enable_tracing = old_sym
  637. def get_isolated_graphmodule(func, args, kwargs, tracing_mode="real"):
  638. """A helper function used to get the GraphModule for the given func.
  639. It's expected to be used in the ProxyTensor tracing context.
  640. It detaches the args and kwargs from the current tracer so that the trace of
  641. the current graph module can be created without any side-effects.
  642. """
  643. wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
  644. with disable_proxy_modes_tracing():
  645. gm = make_fx(wrapped, tracing_mode=tracing_mode)(all_args)
  646. return gm