reinplace.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. import torch
  2. from torch.fx import Node
  3. from torch.fx._compatibility import compatibility
  4. from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
  5. from torch.utils._pytree import tree_map, tree_flatten, tree_map_only
  6. from torch.multiprocessing.reductions import StorageWeakRef
  7. import _operator
  8. from enum import Enum
  9. import itertools
  10. from typing import Set, Dict
  11. from collections import defaultdict
  12. __all__ = ['reinplace']
  13. class _ViewType(Enum):
  14. NonView = 0
  15. SingleOutputView = 1
  16. MultiOutputView = 2
  17. def _is_view_op(tgt):
  18. if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
  19. schema = tgt._schema
  20. if len(schema.arguments) > 0:
  21. first_arg = schema.arguments[0]
  22. # check if op is a view
  23. return first_arg.alias_info is not None and not first_arg.alias_info.is_write
  24. def _get_view_type(tgt) -> _ViewType:
  25. if tgt is not None and isinstance(tgt, torch._ops.OpOverload):
  26. schema = tgt._schema
  27. if len(schema.arguments) > 0:
  28. first_arg = schema.arguments[0]
  29. # check if op is a view
  30. if first_arg.alias_info is not None and not first_arg.alias_info.is_write:
  31. # check if op is a multi-output view
  32. if '*' in first_arg.alias_info.after_set:
  33. return _ViewType.MultiOutputView
  34. else:
  35. return _ViewType.SingleOutputView
  36. return _ViewType.NonView
  37. # Stores a bunch of metadata related to functionalization each node.
  38. # Relevant metadata:
  39. # n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors)
  40. # The fake tensor output from running the current node
  41. # n.meta['view_of']: Node
  42. # If the current node n is a view of some base tensor, the 'view_of' field tells us which
  43. # view node was used to generate the current node (a view tensor).
  44. # This information actually makes `fake_result` redundant, but we can use `fake_result`
  45. # to sanity check that our aliasing information is correct.
  46. @compatibility(is_backward_compatible=False)
  47. class _FunctionalizationMetadataProp(torch.fx.Interpreter):
  48. def run_node(self, node: Node):
  49. self.node_counter += 1
  50. result = super().run_node(node)
  51. node.meta['fake_result'] = result
  52. node.meta['node_idx'] = self.node_counter
  53. # (1) Update metadata with the list of nodes that are used by this node
  54. # copy_() doesn't read from its first argument; it writes to it, overwriting previous data.
  55. # We don't want to treat it as "being used as an input".
  56. node_args = node.args
  57. if node.target is torch.ops.aten.copy_.default:
  58. node_args = node_args[1:]
  59. # (2) Update metadata to track aliasing information about view tensor nodes.
  60. if node.op == 'call_function':
  61. view_type = _get_view_type(node.target)
  62. if view_type == _ViewType.SingleOutputView:
  63. assert isinstance(node.args[0], Node)
  64. node.meta['view_of'] = node.args[0]
  65. elif view_type == _ViewType.MultiOutputView:
  66. self.multi_output_view_nodes[node] = node.args[0]
  67. # Check if we returned a multi-output view,
  68. # and we're now grabbing the individual views from the output.
  69. #
  70. # For multi-output views, we want to map each output view to the base,
  71. # but this mapping involves two separate nodes in FX IR.
  72. # e.g. "a, b = x_1.split(...)" becomes:
  73. # %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
  74. # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
  75. # %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
  76. # And we'd like to set:
  77. # getitem1.meta['view_of'] = x_1
  78. elif node.target is _operator.getitem:
  79. list_arg = node.args[0]
  80. maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None)
  81. if maybe_base_of_view is not None:
  82. # Note: we could also track indexing info here for multi-output views.
  83. # I don't think this metadata is strictly needed for de-functionalization.
  84. assert isinstance(maybe_base_of_view, Node)
  85. node.meta['view_of'] = maybe_base_of_view
  86. if 'view_of' in node.meta:
  87. # We're linking the current node with its first argument as views.
  88. # Assert here that this is actually the case, and their storages are the same.
  89. assert isinstance(node.meta['fake_result'], FakeTensor)
  90. assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor)
  91. view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
  92. base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage())
  93. assert view_storage == base_storage
  94. return result
  95. def propagate(self, *args):
  96. self.multi_output_view_nodes = {}
  97. self.node_counter = -1
  98. with FakeTensorMode() as mode:
  99. fake_args = [mode.from_tensor(a) for a in args]
  100. return super().run(*fake_args)
  101. def _schemas_match(functional_schema, inplace_schema):
  102. names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name
  103. arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all(
  104. a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments))
  105. # for the inplace op, its first argument should be mutable
  106. assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write
  107. # and its remaining arguments shouldn't be.
  108. assert all(a.alias_info is None for a in inplace_schema.arguments[1:])
  109. return names_match and arg_types_match
  110. # TODO: this should be beefed up to be able to properly re-inplace with:
  111. # - mutating ops (e.g. _fused_moving_avg_obs_fq_helper)
  112. # - out= ops (e.g. angle -> angle.out)
  113. # TODO: we should also figure this info out using torchgen.
  114. def _maybe_get_inplace_op(op):
  115. # __module__ seems broken; it returns torch._ops.aten which doesn't exist
  116. if not isinstance(op, torch._ops.OpOverload):
  117. return None
  118. # Some view ops have inplace variants (as_strided_, etc),
  119. # but we do NOT want the reinplacing pass to directly add these into the program.
  120. # (they'll require extra special handling, aren't aren't really useful for perf anyway)
  121. if _is_view_op(op):
  122. return None
  123. op_namespace = op.__module__.split(".")[-1]
  124. op_base_name = op.overloadpacket.__name__
  125. maybe_namespace_module = getattr(torch.ops, op_namespace)
  126. maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None)
  127. if maybe_inplace_op is None:
  128. return None
  129. inplace_overloads = [
  130. getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads()
  131. ]
  132. inplace_overloads_with_matching_schemas = [
  133. f
  134. for f in inplace_overloads
  135. if _schemas_match(op._schema, f._schema)
  136. ]
  137. # Just becuase foo() and foo_() are both existing operators,
  138. # They aren't guaranteed to have compatible schemas.
  139. # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant,
  140. # Even though several overloads of pow_ exist.
  141. if len(inplace_overloads_with_matching_schemas) == 0:
  142. return None
  143. assert len(inplace_overloads_with_matching_schemas) == 1
  144. inplace_op = inplace_overloads_with_matching_schemas[0]
  145. return inplace_op
  146. _VIEW_INVERSE_MAP = {
  147. torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
  148. torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
  149. torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
  150. torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
  151. }
  152. # This function, given a set of set of (aliased) tensor nodes,
  153. # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index
  154. # in the node ordering.
  155. def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int):
  156. def _add_if_tensor(x, set_):
  157. if isinstance(x, FakeTensor):
  158. set_.add(StorageWeakRef(x._typed_storage()))
  159. nodes_used_after = set()
  160. for t in tensor_aliases:
  161. # get all nodes that use the current alias
  162. usage_nodes = t.users
  163. for n in usage_nodes:
  164. # We only care about usages after the current node
  165. if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index:
  166. continue
  167. # We also don't care about intermediate view ops.
  168. # They only matter if their output is then used elsewhere
  169. # (either in an out-of-place op, or as an output to the function).
  170. if n in tensor_aliases:
  171. if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem:
  172. continue
  173. nodes_used_after.add(n)
  174. return nodes_used_after
  175. # Given an op that we're trying to re-inplace, "b = foo(a)",
  176. # And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)"
  177. # Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF:
  178. # If there are any aliases in the alias_set(a) that satisfy:
  179. # (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base"
  180. # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata
  181. # as "alias"
  182. def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]:
  183. def matching_view_metadata(a, b):
  184. return a.size() == b.size() and \
  185. a.stride() == b.stride() and \
  186. a.storage_offset() == b.storage_offset()
  187. view_inverse_nodes = set()
  188. # Go through them in node order, so we can see chains of view_scatter ops.
  189. for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']):
  190. if n.target not in _VIEW_INVERSE_MAP:
  191. continue
  192. base = n.args[0]
  193. mutated_view = n.args[1]
  194. assert isinstance(base, Node)
  195. assert isinstance(base.meta['fake_result'], FakeTensor)
  196. assert isinstance(mutated_view, Node)
  197. assert isinstance(mutated_view.meta['fake_result'], FakeTensor)
  198. # Check that this view_inverse op actually corresponds to taking doing the inverse
  199. # of one of our existing self_alias nodes.
  200. original_view = _VIEW_INVERSE_MAP[n.target]
  201. for self_alias in self_aliases:
  202. # We're looking for some alias of the self arg, "alias",
  203. # that was created from some op `alias = foo(base, args...)`
  204. # such that the current _scatter op "inverts" that foo call.
  205. # We can check that by running the original op again, and checking that the strides match.
  206. if 'view_of' not in self_alias.meta:
  207. continue
  208. self_alias_base = self_alias.meta['view_of']
  209. try:
  210. # The we're trying to re-use the args from the view_scatter call inside of the corresponding
  211. # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
  212. # of the current alias we're looking at.
  213. view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs)
  214. expected_metadata = self_alias.meta['fake_result']
  215. # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace.
  216. if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \
  217. matching_view_metadata(view_replay_metadata, expected_metadata):
  218. view_inverse_nodes.add(n)
  219. except Exception:
  220. continue
  221. return view_inverse_nodes
  222. @compatibility(is_backward_compatible=True)
  223. def reinplace(gm, *sample_args):
  224. """
  225. Given an fx.GraphModule, modifies it to perform "reinplacing",
  226. mutating the nodes of the graph.
  227. We look for out-of-place op call sites like `b = a.add(...)`,
  228. and convert them to be inplace (`b = a.add_(...)`),
  229. as long as the input to the current operator ("a") isn't re-used
  230. anywhere later in the graph.
  231. This pass currently expects to operate on a **functional, ATen** graph.
  232. This can be obtained by running `make_fx(functionalize(f))`.
  233. Sample inputs are needed to determine aliasing relationships of the inputs.
  234. In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the
  235. inputs to the program.
  236. Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows:
  237. (1) Perform some initial checks on the metadata of "a" and "args..."
  238. that can disqualify them from being reinplaced.
  239. (1a) Check that the self argument we're attempting to reinplace
  240. has acceptable dtype/size metadata to reinplace with.
  241. For example, if we have:
  242. a = torch.ones(1)
  243. b = torch.ones(10)
  244. out = torch.add(a, b)
  245. We can't turn that into
  246. a.add_(b)
  247. Because that would require resizing "a".
  248. Similarly, we can't convert torch.ge(a, b) into a.ge_(b),
  249. beause that would require changing a's dtype (from e.g. float32 to bool).
  250. Note that in this specific example, we could technically do better..
  251. If we see the pattern:
  252. a_1 = a.ge(b)
  253. a_2 = aten._to_copy(a_1, a.dtype)
  254. Then we this should be valid to completely re-inplace
  255. (this is exactly what functionalization will emit when it sees a.ge_(b)).
  256. This optimization is only really important for user programs
  257. that directly use inplace comparison ops though.
  258. We also cannot re-inplace on tensors that have overlapping memory,
  259. e.g. torch.ones(1).expand(4, 4).add_(1)
  260. (1b) Check if "a" is an alias of any of the program inputs.
  261. If it is, skip and move to the next node.
  262. Inplace'ing an op that would cause it to mutate a program is not sound,
  263. because that would be a side effect visible to the user.
  264. NOTE: there's a future optimization that we should make:
  265. if "a" is a (alias of a) program input, but later in the program
  266. there is a node that looks like "a.copy_(...)",
  267. Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
  268. which will later be overwritten by the copy_() call.
  269. This will be an important optimization to have for programs that mutate
  270. their inputs. It currently isn't implemented though.
  271. (1c) Check if "a" and "args..." alias
  272. For example, re-inplacing to create code like the below
  273. isn't guaranteed to be sound:
  274. aten.mul_(a, a)
  275. (2) Check that "a" and all of its outstanding aliases are not used anywhere
  276. later in the graph. If this is the case, then it's safe to re-inplace
  277. to "b = foo_(a)".
  278. There are a few caveats to this, explained in more detail below:
  279. (a) If "a" is used later as an argument to a view op, that is okay.
  280. It's only a problem if "a" (or that view) is later passed
  281. into a normal operator, or if it is returned as the program output.
  282. (b) If "a" is a repeat argument in `foo()`, then don't reinplace.
  283. Most ATen kernels don't make any guarantees that this is sound,
  284. e.g. if you do aten.mul_(a, a).
  285. So we'll just ban re-inplacing in this case.
  286. It's only a problem if "a" (or that view) is later passed
  287. (c) If "a" is used as an input into a view "inverse" / "scatter"
  288. operator, it is potentially fine to re-inplace
  289. (and remove that scatter operator from the graph).
  290. See below for a more detailed example.
  291. NOTE: there is an optimization in this step that is crucial
  292. to fully recovering performance from functionalization.
  293. Given this program:
  294. def f(x):
  295. a = torch.ops.aten.add(x, x)
  296. b = torch.ops.aten.diagonal(a)
  297. torch.ops.aten.fill_(b, 0)
  298. return d
  299. Functionalization will emit the following:
  300. def f(x):
  301. a = torch.ops.aten.add(x, x)
  302. b = torch.ops.aten.diagonal(a, 0, 1)
  303. b_updated = torch.ops.aten.fill(b, 0)
  304. a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1)
  305. return a_updated
  306. Ordinarily, we would not be able to reinplace the fill,
  307. because "b" aliases with "a" which is used by the diagonal_scatter call.
  308. "re-inplacing" is on the hook for figuring out that it is ok to
  309. completely, the expensive diagonal_scatter call, if we re-inplace the add().
  310. So, for every `alias in alias_set(a)`, instead of checking
  311. that "alias" is not used anywhere later in the graph,
  312. we check that
  313. EITHER:
  314. (a) alias is not used anywhere later in the graph
  315. OR:
  316. (b) alias is used exactly once later on in the graph,
  317. in the following op:
  318. out = foo_scatter(alias, x, args...)
  319. where the following must hold:
  320. (i) "foo_scatter" is the "inverse" operator for foo.
  321. This only applies to "foo" ops that are view operators,
  322. which view into a subset of the original tensor's memory.
  323. In practice, there are ~4 operators where this applies:
  324. diagonal -> diagonal_scatter
  325. slice -> slice_scatter
  326. select -> select_scatter
  327. as_strided -> as_strided_scatter
  328. (ii) "args..." are the same between the foo() and foo_scatter() calls.
  329. (3) Perform the actual re-inplacing on foo!
  330. (3b) is the common case, but special care is needed for {view}_scatter (3a)
  331. (3a) {view}_scatter ops.
  332. Consider this program:
  333. a = torch.zeros(2, 2)
  334. b = torch.ones(2)
  335. a[0] = b
  336. Post functionalization, that will look like:
  337. a = torch.zeros(2)
  338. b = torch.ones(1)
  339. a_updated = torch.select_scatter(a, b, 0, 0)
  340. In this case though, there is no "functional" op to re-inplace!
  341. Instead, we'd like to directly remove toe select_scatter call.
  342. We already know from (3) that this is valid,
  343. because "a" has no later usages in the graph.
  344. We perform the re-inplacing on the {view}_scatter op like so
  345. Before:
  346. a_updated = torch.select_scatter(a, b, args...)
  347. After:
  348. a_slice = a.select(a, args...)
  349. a_slice.copy_(b)
  350. (3b) Otherwise, replace the functional op with its inplace variant.
  351. Before:
  352. b = foo(a, args...)
  353. After:
  354. a.foo_(args...)
  355. (4) Finally, after converting either:
  356. Before:
  357. b = foo(a)
  358. After:
  359. foo_(a)
  360. or
  361. Before:
  362. b = {slice}_scatter(a, mutated_slice, args...)
  363. After:
  364. slice = {slice}(a, args...)
  365. slice.copy_(mutated_slice)
  366. We now need to find all later nodes that use "b" as an argument
  367. and update them to take in "a" instead.
  368. Note that for the majority of inplace ops, this isn't actually necessary
  369. (because most inplace ops return "self" as their output).
  370. This isn't generally true for all mutable ops though, which is why
  371. we need to actually replace all of the arguments.
  372. We also need to update our metadata of Dict[StorageWeakRef, Set[Node]],
  373. That maps a given tensor storage to the set of all nodes that take in that storage
  374. as an input.
  375. Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused
  376. together.
  377. (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them"
  378. during step (3) get manually deleted from the graph.
  379. Their outputs are no longer used, so technically standard DCE would be able
  380. to do this, but we can no longer run FX's DCE pass now that we have mutable
  381. ops in the graph.
  382. """
  383. _FunctionalizationMetadataProp(gm).propagate(*sample_args)
  384. # Useful debug printing
  385. # def _print(x):
  386. # if isinstance(x, FakeTensor):
  387. # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}')
  388. # for n in gm.graph.nodes:
  389. # print(n.format_node())
  390. # if hasattr(n, 'meta'):
  391. # print(f'node_idx: {n.meta["node_idx"]}')
  392. # if 'fake_result' in n.meta:
  393. # tree_map(_print, n.meta['fake_result'])
  394. # if 'view_of' in n.meta:
  395. # print(f'view_of: {str(n.meta["view_of"])}')
  396. # print()
  397. # We need to know which nodes correspond to inputs (or their aliases)
  398. # so we know not to re-inplace them.
  399. # NOTE: later, we'll need to add an optimization for fully recovering performance
  400. # on programs that mutate inputs.
  401. input_storages = {
  402. StorageWeakRef(
  403. node.meta['fake_result']._typed_storage()
  404. ) for node in gm.graph.nodes if node.op == 'placeholder'}
  405. # We also need to know for a given node, what are all of its aliasing nodes.
  406. storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set)
  407. for n in gm.graph.nodes:
  408. if 'fake_result' in n.meta:
  409. # Tree-mapping because some ops can return lists of tensors.
  410. def _add_to_map(x):
  411. if isinstance(x, FakeTensor):
  412. storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n)
  413. tree_map(_add_to_map, n.meta['fake_result'])
  414. # inplace-ify functional ops, subject to the constraints written below.
  415. all_later_view_inverse_nodes_to_delete = set()
  416. for idx, node in enumerate(gm.graph.nodes):
  417. if node.op == 'call_function':
  418. # Today, the re-inplace pass on directly acts on:
  419. # - functional ops with an inplace variant
  420. # - {view}_scatter ops that can be potentially removed from the graph.
  421. # Both of these ops take in tensor first args, so filtering on this condition
  422. # makes the later code simpler.
  423. # We should revisit this at some point though, particularly when we also want
  424. # the reinplacer to be able to handle out= and mutable operators
  425. # and tensorlist first args (like `_foreach_` ops).
  426. if not isinstance(node.target, torch._ops.OpOverload):
  427. continue
  428. if len(node.target._schema.arguments) < 1:
  429. continue
  430. if type(node.target._schema.arguments[0].type) != torch.TensorType:
  431. continue
  432. # Step 1a: Check that the self argument we're attempting to reinplace
  433. # has the same size/stride as the output.
  434. # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor)
  435. # As it would require resizing scalar_tensor.
  436. # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor),
  437. # this is probably an optimization to revisit later).
  438. self_arg = node.args[0]
  439. self_flattened, _ = tree_flatten(self_arg.meta['fake_result'])
  440. node_flattened, _ = tree_flatten(node.meta['fake_result'])
  441. self_has_wrong_metadata = False
  442. if len(self_flattened) == len(node_flattened):
  443. for self_meta, node_meta in zip(self_flattened, node_flattened):
  444. if self_meta.numel() != node_meta.numel():
  445. self_has_wrong_metadata = True
  446. if self_meta.dtype != node_meta.dtype:
  447. self_has_wrong_metadata = True
  448. # We also cannot re-inplace on tensors that have internal memory overlap.
  449. # e.g. torch.ones(1).expand(4, 4).add_(1)
  450. if torch._debug_has_internal_overlap(self_meta) == 1:
  451. self_has_wrong_metadata = True
  452. # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace,
  453. # Since users should never really be calling the functional "torch.ops.aten.resize"
  454. # op directly in their programs.
  455. if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default:
  456. continue
  457. # Step 1b: ensure that the op we're trying to re-inplace isn't a program input
  458. self_arg_name = self_arg.name
  459. self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
  460. if self_arg_storage in input_storages:
  461. # TODO: later, add the optimization for handling `copy_()` calls in the graph.
  462. continue
  463. if len([x for x in node.args if x is self_arg]) > 1:
  464. # Step 1c:
  465. # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound,
  466. # so we prevent re-inplacing in this case.
  467. continue
  468. self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage())
  469. self_aliases = storage_to_nodes[self_arg_storage]
  470. # First, we find all later usages of any of the aliases of self_arg.
  471. later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx'])
  472. # Then, we check if any of those later usages are actually view_scatter ops
  473. # that are safe to fully remove.
  474. later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases)
  475. # Step 2: Check to see if the input to the op is re-used later in the graph.
  476. # If not (same goes for its aliases), then this op is safe to re-in place.
  477. # This is a slightly roundabout way to check that there are no later usages of the current self argument.
  478. # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
  479. can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0
  480. if not can_reinplace:
  481. continue
  482. # Step 3a: Special handling for when we see *_scatter operators.
  483. # When we see an operator like `b = torch.slice_scatter(a, ...)`,
  484. # instead of trying to "inplace" it into a.slice_scatter_(..._),
  485. # we would prefer to remove it from the graph entirely,
  486. # and instead copy_() the slice directly into the larger tensor.
  487. # See the description of the algorithm for a full example.
  488. if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete:
  489. view_op = _VIEW_INVERSE_MAP[node.target]
  490. # Before:
  491. # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...)
  492. # After:
  493. # slice = torch.ops.aten.slice.default(base, args...)
  494. # slice.copy_(mutated_slice)
  495. with gm.graph.inserting_before(node):
  496. mutated_slice_node = node.args[1]
  497. remaining_slice_args = node.args[2:]
  498. slice_node = gm.graph.create_node(
  499. 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs)
  500. copy_node = gm.graph.create_node(
  501. 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {})
  502. # Add the slice_scatter node to our "nodes to delete" list.
  503. all_later_view_inverse_nodes_to_delete.add(node)
  504. else:
  505. # Step 3b: Check to see if this operator has an inplace variant.
  506. maybe_inplace_op = _maybe_get_inplace_op(node.target)
  507. if maybe_inplace_op is None:
  508. continue
  509. # And if so, replace it with its inplace variant.
  510. node.target = maybe_inplace_op
  511. # At this point, 'storage_to_nodes' will be stale.
  512. # Now that we're inplacing `b = foo(a)`, we need to effectively
  513. # union together the dict values for b and a's storage.
  514. # Hmm... morally I think we also want to keep the `fake_result` metadata
  515. # up to date here, but I'm not sure how easy it is to do.
  516. # Maybe it's fine to wait until the end of the pass to update it.
  517. curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage())
  518. storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage])
  519. storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage])
  520. # Need to remember the view_scatter view nodes we found so we can remove them alter.
  521. all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages)
  522. # Step 4:
  523. # Now that we've replaced b = a.foo() with a.foo_(),
  524. # We need to replace any later usages of "b" with "a"
  525. for old in itertools.chain([node], later_view_inverse_node_usages):
  526. new = old.args[0]
  527. nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']]
  528. for node_to_update in nodes_to_update:
  529. new_args = []
  530. args = node_to_update.args
  531. def replace_arg(a):
  532. if a == old:
  533. return new
  534. return a
  535. # First, replace usages of "b" with "a"
  536. node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args)
  537. node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs)
  538. # Second, update our storage_to_nodes data structure.
  539. old_flattened_res, _ = tree_flatten(old.meta['fake_result'])
  540. node_flattened_res, _ = tree_flatten(node_to_update.meta['fake_result'])
  541. old_res_storage = {
  542. StorageWeakRef(
  543. x._typed_storage()
  544. ) for x in old_flattened_res if isinstance(x, FakeTensor)}
  545. node_res_storage = {
  546. StorageWeakRef(
  547. x._typed_storage()
  548. ) for x in node_flattened_res if isinstance(x, FakeTensor)}
  549. # This will happen if we're updating a view op, e.g.
  550. # e.g. replacing
  551. # x = view(old)
  552. # x = view(new)
  553. # When that happens, we need to make sure to keep our
  554. # storage mapping up to date.
  555. #
  556. # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor,
  557. # or multiple tensors that all share the same storage.
  558. # We can't just check equality because we might encounter FX nodes that return zero tensor outputs.
  559. if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage:
  560. new_flattened_res, _ = tree_flatten(new.meta['fake_result'])
  561. new_res_storage = {
  562. StorageWeakRef(
  563. x._typed_storage()
  564. ) for x in new_flattened_res if isinstance(x, FakeTensor)}
  565. assert len(new_res_storage) == 1
  566. (old_ref,) = old_res_storage
  567. (new_ref,) = new_res_storage
  568. (node_ref,) = node_res_storage
  569. # Technically, "old_ref" and all its aliases will remain
  570. # in our mapping.
  571. # That should be fine though, since we deleted "old"
  572. # from the graph at this point.
  573. storage_to_nodes[node_ref].update(storage_to_nodes[new_ref])
  574. storage_to_nodes[new_ref].update(storage_to_nodes[node_ref])
  575. # Step 4: delete any _scatter nodes that we de-functionalized
  576. # Need to take care not to delete any of these nodes until after *all* modifications
  577. # to the graph are finished.
  578. for to_delete in all_later_view_inverse_nodes_to_delete:
  579. gm.graph.erase_node(to_delete)
  580. gm.recompile()
  581. return gm