proxy.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. import dis
  2. import copy
  3. import sys
  4. import torch
  5. import inspect
  6. import operator
  7. import traceback
  8. import collections
  9. from .graph import magic_methods, reflectable_magic_methods, Graph
  10. from typing import Tuple, Dict, OrderedDict, Optional, Iterable, Any, Iterator, Callable
  11. from .node import Target, Node, Argument, base_types, map_aggregate
  12. from ._compatibility import compatibility
  13. from .operator_schemas import check_for_mutable_operation
  14. import torch.fx.traceback as fx_traceback
  # Names exported by ``from torch.fx.proxy import *``.
  __all__ = [
      'TracerBase',
      'GraphAppendingTracer',
      'TraceError',
      'Proxy',
      'Attribute',
      'ParameterProxy',
      'Scope',
      'ScopeContextManager',
  ]
  @compatibility(is_backward_compatible=False)
  class Scope:
      """ Scope object that records the module path and the module type
      of a module. Scope is used to track the information of the module
      that contains a Node in a Graph of GraphModule. For example::

          class Sub(torch.nn.Module):
              def forward(self, x):
                  # This will be a call_method Node in GraphModule,
                  # scope for this would be (module_path="sub", module_type=Sub)
                  return x.transpose(1, 2)

          class M(torch.nn.Module):
              def __init__(self):
                  super().__init__()
                  self.sub = Sub()

              def forward(self, x):
                  # This will be a call_method Node as well,
                  # scope for this would be (module_path="", module_type=None)
                  x = x.transpose(1, 2)
                  x = self.sub(x)
                  return x

      Args:
          module_path: Dotted path of the module from the traced root
              (``""`` for the root itself, as in the example above).
          module_type: The module's type; may be ``None``, as for the
              root scope in the example above.
      """

      def __init__(self, module_path: str, module_type: Any):
          super().__init__()
          self.module_path = module_path
          self.module_type = module_type
  @compatibility(is_backward_compatible=False)
  class ScopeContextManager:
      """Context manager that temporarily points a shared ``Scope`` at a module.

      Used during symbolic tracing so that nodes created while a module's
      forward function runs are attributed to that module. The ``scope``
      object passed in is mutated in place: its ``module_path`` and
      ``module_type`` are overwritten with those of ``current_scope`` as soon
      as the manager is constructed, and the previous values are written back
      on exit. ``__enter__`` returns the (already updated) ``scope`` object.
      """

      def __init__(
          self,
          scope: Scope,
          current_scope: Scope,
      ):
          super().__init__()
          # Shallow snapshot of the outer scope, restored by __exit__.
          self._prev_scope = copy.copy(scope)
          self._scope = scope
          # Switch the shared scope object over to the module being entered.
          scope.module_path = current_scope.module_path
          scope.module_type = current_scope.module_type

      def __enter__(self):
          return self._scope

      def __exit__(self, *args):
          # Put back the fields captured at construction time. Returning None
          # means any exception raised inside the ``with`` body propagates.
          prev = self._prev_scope
          self._scope.module_path = prev.module_path
          self._scope.module_type = prev.module_type
          return None
  @compatibility(is_backward_compatible=True)
  class TracerBase:
      """Shared node/proxy construction logic for FX tracers.

      This base class declares, but does not assign, ``graph``, ``scope``,
      ``module_stack`` and ``node_name_to_scope``. A concrete tracer must set
      them before ``create_node`` (which reads all four) is called.
      """
      graph: Graph
      # When set, ``create_proxy`` records the user's Python stack trace on
      # each node (only if no preserved node meta is active)
      record_stack_traces : bool = False
      # Feature flag for mutable schema checking
      # Enable by default in 1.12
      check_mutable_operations : bool = False
      # Feature flag for assert tracing
      trace_asserts : bool = False
      # Feature flag for proxying accesses to buffer values
      proxy_buffer_attributes : bool = False

      # Name of the function to be traced. It will only be used when
      # ``root`` is an instance of ``nn.Module``
      traced_func_name: str = "forward"

      # Scope (module path and module type) of the module currently being
      # traced; ``create_node`` records it for every new node
      scope : Scope

      # Records the module call stack
      module_stack: OrderedDict[str, str]

      # Mapping of node name to module scope
      node_name_to_scope: Dict[str, Tuple[str, type]]

      @compatibility(is_backward_compatible=True)
      def create_node(self, kind : str, target : Target,
                      args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                      type_expr : Optional[Any] = None) -> Node:
          """
          Inserts a graph node given target, args, kwargs, and name.

          This method can be overridden to do extra checking, validation, or
          modification of values used in node creation. For example, one might
          want to disallow in-place operations from being recorded.

          When ``check_mutable_operations`` is set, ``call_function`` targets are
          first validated with ``check_for_mutable_operation``. The current
          ``scope`` is recorded in ``node_name_to_scope`` and, when
          ``module_stack`` is non-empty, a shallow copy of it is stored in
          ``node.meta['nn_module_stack']``.
          """
          if kind == 'call_function' and self.check_mutable_operations:
              check_for_mutable_operation(target, args, kwargs)

          node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
          # TODO node_name_to_scope will be deprecated in favor of
          # node.meta['nn_module_stack']
          self.node_name_to_scope[node.name] = (
              self.scope.module_path,
              self.scope.module_type,
          )
          # Copy so later mutations of self.module_stack are not reflected in
          # this node's metadata.
          if self.module_stack:
              node.meta['nn_module_stack'] = copy.copy(self.module_stack)
          return node

      @compatibility(is_backward_compatible=True)
      def proxy(self, node: Node) -> 'Proxy':
          """Wrap ``node`` in a plain ``Proxy`` bound to this tracer."""
          return Proxy(node, self)

      @compatibility(is_backward_compatible=True)
      def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                       name: Optional[str] = None, type_expr : Optional[Any] = None,
                       proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):
          '''
          Create a Node from the given arguments, then return the Node
          wrapped in a Proxy object.

          If kind = 'placeholder', then we're creating a Node that
          represents the parameter of a function. If we need to encode
          a default parameter, we use the ``args`` tuple. ``args`` is
          otherwise empty for ``placeholder`` Nodes.

          ``args`` and ``kwargs`` are lowered with ``create_arg`` first. The
          node is wrapped with ``proxy_factory_fn`` when given, otherwise with
          ``self.proxy``. Debug metadata is then attached: taken from the
          preserved node meta when ``fx_traceback`` has one, otherwise (only if
          ``record_stack_traces`` is set) built from the first user frame.
          '''
          args_ = self.create_arg(args)
          kwargs_ = self.create_arg(kwargs)
          assert isinstance(args_, tuple)
          assert isinstance(kwargs_, dict)

          node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

          if not proxy_factory_fn:
              proxy = self.proxy(node)
          else:
              proxy = proxy_factory_fn(node)

          # Optionally set stack trace on the created Node for debugging purposes
          if fx_traceback.has_preserved_node_meta():
              current_meta: Dict[str, Any] = fx_traceback.get_current_meta()

              # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
              # If other meta fields are needed, they can be added here
              stack_trace = current_meta.get("stack_trace")
              if stack_trace:
                  proxy.node.stack_trace = stack_trace

              nn_module_stack = current_meta.get("nn_module_stack")
              if nn_module_stack:
                  proxy.node.meta["nn_module_stack"] = nn_module_stack

              source_fn = current_meta.get("source_fn")
              if source_fn:
                  proxy.node.meta["source_fn"] = source_fn

          elif self.record_stack_traces:
              user_frame = self._find_user_frame()
              if user_frame:
                  walk_stack_gen = traceback.walk_stack(user_frame)
                  summary = traceback.StackSummary.extract(walk_stack_gen)  # type: ignore[arg-type]
                  tb_lines = summary.format()
                  proxy.node.stack_trace = ''.join(tb_lines)

          return proxy

      def _find_user_frame(self):
          """
          Find the Python stack frame executing the user code during
          symbolic tracing.

          Returns the first caller frame whose filename does not end with one
          of the PyTorch tracing-infrastructure paths listed below, or ``None``
          if the stack is exhausted first.
          """
          # We have to do a little dance here. Basically, walk up the callstack and
          # record the first frame not in the pytorch source. This is the frame executing
          # the user code during tracing.
          pt_files = ('torch/fx/proxy.py',
                      'torch/fx/_symbolic_trace.py',
                      'torch/fx/experimental/proxy_tensor.py',
                      'torch/_ops.py',
                      'torch/_tensor.py',
                      'torch/utils/_python_dispatch.py',
                      'torch/_prims_common/wrappers.py',
                      'torch/_refs/__init__.py',
                      'torch/_refs/nn/functional/__init__.py'
                      )

          frame = inspect.currentframe()
          while frame:
              frame = frame.f_back
              # str.endswith accepts a tuple and is true if any suffix matches.
              if frame and not frame.f_code.co_filename.endswith(pt_files):
                  break
          # ``frame`` is None here if no user frame was found.
          return frame

      @compatibility(is_backward_compatible=True)
      def create_arg(self, a: Any) -> Argument:
          """
          A method that lowers the objects seen as arguments during symbolic evaluation
          into Argument types that can be stored in IR.

          Can be overridden to support more trace-specific types.

          Raises:
              RuntimeError: if a dict key contains a ``Node``.
              NotImplementedError: for unsupported argument types.
          """
          if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
              return a.__fx_create_arg__(self)
          # aggregates
          elif isinstance(a, tuple) and hasattr(a, '_fields'):
              # NamedTuple constructors don't seem to like getting a generator
              # expression as an argument to their constructor, so build this
              # intermediate tuple and unpack it into the NamedTuple constructor
              args = tuple(self.create_arg(elem) for elem in a)
              return type(a)(*args)  # type: ignore[arg-type]
          elif isinstance(a, (tuple, list)):
              return type(a)(self.create_arg(elem) for elem in a)
          elif isinstance(a, dict):
              r = {}
              for k, v in a.items():
                  # Check for invalid dict keys. We do not want a Proxy to appear
                  # anywhere within the key. Since keys can be collection types,
                  # we iterate through the key with map_aggregate
                  k = self.create_arg(k)

                  def no_node(arg):
                      if isinstance(arg, Node):
                          raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                             f"Node. Got key: {k}")
                  map_aggregate(k, no_node)

                  r[k] = self.create_arg(v)
              return r
          elif isinstance(a, slice):
              return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

          elif isinstance(a, range):
              return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

          if isinstance(a, Proxy):
              # base case: we unwrap the Proxy object
              return a.node
          elif isinstance(a, base_types) or a is None or a is ...:
              return a
          raise NotImplementedError(f"argument of type: {type(a)}")

      @compatibility(is_backward_compatible=True)
      def to_bool(self, obj: 'Proxy') -> bool:
          """Called when a proxy object is being converted to a boolean, such as
          when used in control flow. Normally we don't know what to do because
          we don't know the value of the proxy, but a custom tracer can attach more
          information to the graph node using create_node and can choose to return a value.
          """
          raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

      @compatibility(is_backward_compatible=True)
      def iter(self, obj: 'Proxy') -> Iterator:
          """Called when a proxy object is being iterated over, such as
          when used in control flow. Normally we don't know what to do because
          we don't know the value of the proxy, but a custom tracer can attach more
          information to the graph node using create_node and can choose to return an iterator.
          """
          raise TraceError('Proxy object cannot be iterated. This can be '
                           'attempted when the Proxy is used in a loop or'
                           ' as a *args or **kwargs function argument. '
                           'See the torch.fx docs on pytorch.org for a '
                           'more detailed explanation of what types of '
                           'control flow can be traced, and check out the'
                           ' Proxy docstring for help troubleshooting '
                           'Proxy iteration errors')

      @compatibility(is_backward_compatible=True)
      def keys(self, obj: 'Proxy') -> Any:
          """Called when the keys() method is called on a proxy object.
          This is what happens when ** is applied to a proxy. This should return an
          iterator if ** is supposed to work in your custom tracer.

          By default this records an ``obj.keys()`` method call and returns its proxy.
          """
          return Attribute(obj, 'keys')()
  # Tracer used by Proxy objects that simply append to an existing graph
  # (e.g. a Proxy built around a raw Node while not tracing).
  @compatibility(is_backward_compatible=True)
  class GraphAppendingTracer(TracerBase):
      """Tracer that records operations directly into an existing ``graph``.

      All bookkeeping starts empty and the scope starts at the root
      (empty module path, no module type).
      """

      def __init__(self, graph: Graph):
          super().__init__()
          self.node_name_to_scope = {}
          self.module_stack = collections.OrderedDict()
          self.scope = Scope("", None)
          self.graph = graph
  @compatibility(is_backward_compatible=False)
  def assert_fn(x):
      """Assert that ``x`` is truthy.

      ``Proxy.__bool__`` records a ``call_function`` node targeting this
      function when, with ``trace_asserts`` enabled, it detects that a proxy is
      being tested by an ``assert`` statement, so the check is kept in the
      traced graph. Like any ``assert``, it is skipped under ``python -O``.
      """
      assert x
  @compatibility(is_backward_compatible=True)
  class TraceError(ValueError):
      """Raised when symbolic tracing meets an operation it cannot record,
      e.g. a traced value used as a control-flow condition or iterated over."""
  @compatibility(is_backward_compatible=True)
  class Proxy:
      """
      ``Proxy`` objects are ``Node`` wrappers that flow through the
      program during symbolic tracing and record all the operations
      (``torch`` function calls, method calls, operators) that they touch
      into the growing FX Graph.

      If you're doing graph transforms, you can wrap your own ``Proxy``
      method around a raw ``Node`` so that you can use the overloaded
      operators to add additional things to a ``Graph``.

      ``Proxy`` objects cannot be iterated. In other words, the symbolic
      tracer will throw an error if a ``Proxy`` is used in a loop or as
      an ``*args``/``**kwargs`` function argument.

      There are two main ways around this:

      1. Factor out the untraceable logic into a top-level function and
         use ``fx.wrap`` on it.
      2. If the control flow is static (i.e. the loop trip count is
         based on some hyperparameter), the code can be kept in its original
         position and refactored into something like::

             for i in range(self.some_hyperparameter):
                 indexed_item = proxied_value[i]

      For a more detailed description into the Proxy internals, check out
      the "Proxy" section in `torch/fx/OVERVIEW.md`
      """

      @compatibility(is_backward_compatible=True)
      def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
          if tracer is None:
              # This allows you to create a Proxy object around a raw Node
              tracer = GraphAppendingTracer(node.graph)
          self.tracer = tracer
          self.node = node

      def __repr__(self) -> str:
          return f'Proxy({self.node.name})'

      def __getattr__(self, k) -> 'Attribute':
          # note: not added to the graph yet, if this is a method call
          # we peephole optimize to the method invocation
          return Attribute(self, k)

      def __call__(self, *args, **kwargs) -> 'Proxy':
          return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

      @staticmethod
      def _current_instruction(frame):
          """Disassemble ``frame``'s code and locate the instruction it is executing.

          Returns ``(insts, offsets, index)``: the ``dis.Instruction`` list, the
          matching (ascending) byte offsets, and the index of the instruction
          at ``frame.f_lasti``. ``offsets`` lets callers map any byte offset
          back to a list index.
          """
          # Searching by byte offset works on every Python version: before 3.11
          # every instruction occupies 2 bytes (so this equals f_lasti // 2),
          # and from 3.11 on, inline CACHE entries are left out of the listing,
          # leaving gaps between offsets.
          from bisect import bisect_left
          insts = list(dis.get_instructions(frame.f_code))
          offsets = [inst.offset for inst in insts]
          return insts, offsets, bisect_left(offsets, frame.f_lasti)

      def __iter__(self) -> Iterator['Proxy']:
          frame = inspect.currentframe()
          assert frame is not None
          calling_frame = frame.f_back
          assert calling_frame is not None
          inst_list, _, inst_idx = self._current_instruction(calling_frame)
          inst = inst_list[inst_idx]
          # ``a, b = proxy`` compiles to UNPACK_SEQUENCE whose argument is the
          # number of targets, so the unpack can be expressed as indexing.
          if inst.opname == 'UNPACK_SEQUENCE':
              return (self[i] for i in range(inst.argval))  # type: ignore[index]

          return self.tracer.iter(self)

      def __bool__(self) -> bool:
          if self.tracer.trace_asserts:
              # check if this boolean is used in an assertion: ``assert x``
              # compiles to a conditional jump over a sequence that loads
              # AssertionError and raises it.
              frame = inspect.currentframe()
              assert frame is not None
              calling_frame = frame.f_back
              assert calling_frame is not None
              insts, offsets, cur = self._current_instruction(calling_frame)
              inst = insts[cur]

              # Python 3.11 names this forward jump POP_JUMP_FORWARD_IF_TRUE.
              if inst.opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_FORWARD_IF_TRUE'):
                  from bisect import bisect_left
                  first = insts[cur + 1]
                  assert inst.argval is not None
                  # ``argval`` is the resolved byte offset of the jump target on
                  # every version, whereas the raw ``arg`` is in bytes before
                  # 3.10, in instructions in 3.10, and relative from 3.11 on.
                  last = insts[bisect_left(offsets, inst.argval) - 1]
                  starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                        or first.opname == 'LOAD_ASSERTION_ERROR')
                  if starts_with_assert and last.opname == 'RAISE_VARARGS':
                      self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                      return True

          return self.tracer.to_bool(self)

      @compatibility(is_backward_compatible=True)
      def keys(self):
          return self.tracer.keys(self)

      def __len__(self):
          raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                             "this call to be recorded, please call torch.fx.wrap('len') at "
                             "module scope")

      @classmethod
      def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
          args = args if args else ()
          kwargs = kwargs if kwargs else {}

          # Collect the distinct tracers of all Proxy arguments (dict keeps order).
          tracers : Dict[Any, None] = {}

          def find_tracer(a):
              if isinstance(a, cls):
                  tracers[a.tracer] = None
          torch.fx.node.map_aggregate(args, find_tracer)
          torch.fx.node.map_aggregate(kwargs, find_tracer)

          if len(tracers) > 1:
              raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                                 f'trying to trace operations {orig_method}')
          tracer = next(iter(tracers.keys()))

          if isinstance(orig_method, torch._C.ScriptMethod):
              args = (orig_method.owner,) + args
              return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
          if torch.overrides.is_tensor_method_or_property(orig_method):
              return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
          else:
              if isinstance(orig_method, torch._ops.PyOperator):
                  # TODO: Define how to symbolically trace PyOperators
                  raise RuntimeError("Unable to symbolically trace PyOperators")
              return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                         name=tracer.graph._target_to_str(orig_method.__name__))
  @compatibility(is_backward_compatible=True)
  class Attribute(Proxy):
      """Proxy for ``root.attr`` whose graph node is created lazily.

      Calling an ``Attribute`` records a single ``call_method`` node on
      ``root``; a separate ``getattr`` node is only inserted if ``node`` is
      actually accessed.
      """

      @compatibility(is_backward_compatible=True)
      def __init__(self, root: Proxy, attr: str):
          self.root = root
          self.attr = attr
          self.tracer = root.tracer
          self._node: Optional[Node] = None

      @property
      def node(self):
          # the node for attributes is added lazily, since most will just be method calls
          # which do not rely on the getattr call
          if self._node is None:
              self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
          return self._node

      def __call__(self, *args, **kwargs):
          return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
  @compatibility(is_backward_compatible=False)
  class ParameterProxy(Proxy):
      """
      Proxy around a module's ``torch.nn.Parameter`` that answers common
      metadata queries ("shape", "ndim", "size", "dim", "numel", "nelement")
      from the real parameter instead of recording graph nodes, so that
      conditionals on these values can be evaluated during tracing without
      raising.
      """

      def __init__(self, tracer: TracerBase, node: Node, name, param):
          super().__init__(node, tracer)
          assert isinstance(param, torch.nn.Parameter)
          self.param = param
          self.name = name

      def __repr__(self) -> str:
          return f'ParameterProxy({self.name})'

      # Attribute-style metadata, read straight from the wrapped parameter.
      @property
      def shape(self):
          return self.param.shape

      @property
      def ndim(self):
          return self.param.ndim

      # Method-style metadata, forwarded to the wrapped parameter (no graph
      # nodes are recorded for these calls).
      def size(self):
          return self.param.size()

      def dim(self):
          return self.param.dim()

      def numel(self):
          return self.param.numel()

      def nelement(self):
          return self.param.nelement()
  def _scope(method):
      """Install ``Proxy.__<method>__`` so it records ``operator.<method>``.

      The generated dunder forwards all of its arguments (the proxy first) to
      the tracer of the first argument as a ``call_function`` node targeting
      ``operator.<method>``.
      """
      def impl(*args, **kwargs):
          tracer = args[0].tracer
          target = getattr(operator, method)
          return tracer.create_proxy('call_function', target, args, kwargs)
      impl.__name__ = method
      as_magic = f'__{method.strip("_")}__'
      setattr(Proxy, as_magic, impl)


  # Define the installer once and call it per entry (rather than redefining it
  # on every iteration); the call gives each ``impl`` its own ``method`` binding.
  for method in magic_methods:
      _scope(method)
  def _define_reflectable(orig_method_name):
      """Install ``Proxy.__r<name>__`` for the reflected form of an operator.

      For ``rhs <op> proxy`` the generated method records a ``call_function``
      node for ``operator.<orig_method_name>(rhs, proxy)``, i.e. with the
      operands in their original left-to-right order.
      """
      reflected_name = f'__r{orig_method_name.strip("_")}__'

      def impl(self, rhs):
          op = getattr(operator, orig_method_name)
          return self.tracer.create_proxy('call_function', op, (rhs, self), {})

      impl.__name__ = impl.__qualname__ = reflected_name
      setattr(Proxy, reflected_name, impl)


  for orig_method_name in reflectable_magic_methods:
      _define_reflectable(orig_method_name)