partitioners.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
  2. from torch.fx.experimental.symbolic_shapes import hint_int
  3. import torch
  4. import torch.fx as fx
  5. import operator
  6. import math
  7. import torch.utils._pytree as pytree
  8. import copy
  9. import os
  10. from collections import defaultdict
  11. from torch.fx.passes import graph_drawer
  12. from typing import Tuple
  13. from .compile_utils import fx_graph_cse, get_aten_target
  14. from . import config
  15. import functools
  16. AOT_PARTITIONER_DEBUG = config.debug_partitioner
  17. class InvalidNodeBase:
  18. def __repr__(self):
  19. return "Invalid Node"
  20. InvalidNode = InvalidNodeBase()
  21. def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs):
  22. """
  23. Given a graph, extracts out a subgraph that takes the specified nodes as
  24. inputs and returns the specified outputs.
  25. This includes specifying non-placeholder nodes as inputs.
  26. The general strategy is to initialize all inputs with proxies as we
  27. encounter them, and trace through the graph, only keeping values which take
  28. in valid proxies. Then, all dead code is eliminated.
  29. """
  30. new_graph = fx.Graph()
  31. env = {}
  32. # Add new placeholder nodes in the order specified by the inputs
  33. for node in inputs:
  34. new_node = new_graph.placeholder(node.name)
  35. # Can't use node_copy here as we may be turning previous call_function into placeholders
  36. new_node.meta = node.meta
  37. env[node] = new_node
  38. for node in joint_graph.nodes:
  39. if node in inputs:
  40. continue
  41. elif node.op == 'placeholder':
  42. env[node] = InvalidNode
  43. elif node.op == 'call_function':
  44. all_args = pytree.tree_flatten((node.args, node.kwargs))[0]
  45. all_args = [isinstance(env[x], InvalidNodeBase) for x in all_args if isinstance(x, fx.Node)]
  46. if any(all_args):
  47. env[node] = InvalidNode
  48. continue
  49. env[node] = new_graph.node_copy(node, lambda x: env[x])
  50. elif node.op == 'get_attr':
  51. env[node] = new_graph.node_copy(node, lambda x: env[x])
  52. elif node.op == 'output':
  53. pass
  54. output_values = []
  55. for x in outputs:
  56. if isinstance(x, fx.Node):
  57. if x not in env:
  58. raise RuntimeError(f"Node {x} couldn't be found in env")
  59. output_values.append(env[x])
  60. else:
  61. output_values.append(x)
  62. new_graph.output(output_values)
  63. new_graph.eliminate_dead_code()
  64. new_graph.lint()
  65. return new_graph
  66. def _is_primal(node):
  67. return node.op == "placeholder" and "tangents" not in node.target
  68. def _is_tangent(node):
  69. return node.op == "placeholder" and "tangents" in node.target
  70. def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs):
  71. outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if node.op == 'output'])[0]
  72. fwd_outputs = outputs[:num_fwd_outputs]
  73. bwd_outputs = outputs[num_fwd_outputs:]
  74. return fwd_outputs, bwd_outputs
  75. def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs):
  76. fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  77. primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
  78. tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
  79. # Construct the forward module
  80. # Keep symints separate from tensors, passed between fwd/bwd graphs, and in the right order.
  81. fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs + saved_values + saved_sym_nodes)
  82. bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_sym_nodes + saved_values + tangent_inputs, bwd_outputs)
  83. # This is to filter out saved values that don't actually end up being used by the backwards pass
  84. for node in bwd_graph.nodes:
  85. if node.op == 'placeholder' and not node.users:
  86. for saved_value in saved_values:
  87. if saved_value.name == node.name:
  88. saved_values.remove(saved_value)
  89. break
  90. for saved_sym in saved_sym_nodes:
  91. if saved_sym.name == node.name:
  92. saved_sym_nodes.remove(saved_sym)
  93. break
  94. # Now, we re-generate the fwd/bwd graphs.
  95. # NB: This might increase compilation time, but I doubt it matters
  96. fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs + saved_values + saved_sym_nodes)
  97. bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_sym_nodes + saved_values + tangent_inputs, bwd_outputs)
  98. fwd_module = fx.GraphModule(joint_module, fwd_graph)
  99. bwd_module = fx.GraphModule(joint_module, bwd_graph)
  100. return fwd_module, bwd_module
  101. def default_partition(
  102. joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
  103. ) -> Tuple[fx.GraphModule, fx.GraphModule]:
  104. """
  105. Partitions the :attr:`joint_module` in a manner that closely resembles the
  106. behavior observed in the original ``.forward()`` and ``.backward()`` of the
  107. callable, i.e., the resulting forward graph contains those operators that
  108. are executed in the original ``.forward()`` callable passed to
  109. :func:`aot_function`.
  110. The default partitioner collects the operators that are between the forward
  111. inputs and the forward outputs. This helps in finding the tensors which have
  112. to be stashed for the backward pass. These stashed tensors become the output
  113. of the generated forward graph. The remaining operators are then placed in
  114. the backward graph.
  115. .. warning::
  116. This API is experimental and likely to change.
  117. Args:
  118. joint_module(fx.GraphModule): The joint forward and backward graph. This
  119. is the result of AOT Autograd tracing.
  120. Returns:
  121. Returns the generated forward and backward Fx graph modules.
  122. """
  123. primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
  124. fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  125. forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
  126. forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'}
  127. saved_values = []
  128. saved_sym_nodes = []
  129. for node in joint_module.graph.nodes:
  130. if node.name not in forward_node_names:
  131. continue
  132. if is_sym_node(node):
  133. # Symints must be kept separate from tensors so that PythonFunction only calls
  134. # save_for_backward on tensors and stashes symints in autograd .ctx
  135. saved_sym_nodes.append(node)
  136. elif (
  137. 'tensor_meta' not in node.meta
  138. and node.op == 'call_function'
  139. ):
  140. # Since we can't save tuple of tensor values, we need to flatten out what we're saving
  141. users = node.users
  142. assert all(user.target == operator.getitem for user in users)
  143. for user in users:
  144. saved_values.append(user)
  145. else:
  146. backward_usages = [n for n in node.users if n.name not in forward_node_names]
  147. if 'tensor_meta' in node.meta and all(is_sym_node(n) for n in backward_usages):
  148. # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
  149. # and not the actual tensor data,
  150. # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
  151. #
  152. # Note that saving the tensor could also cause compilation problems:
  153. # If the user mutated an input in the forward and uses its sizes/strides in the backward,
  154. # then we would be obligated to clone the input before saving it to appease autograd.
  155. # (This is how we originally found this bug).
  156. for user in backward_usages:
  157. saved_sym_nodes.append(user)
  158. else:
  159. saved_values.append(node)
  160. saved_values = list(set(saved_values))
  161. saved_sym_nodes = list(set(saved_sym_nodes))
  162. return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
  163. def _prod(x):
  164. s = 1
  165. for i in x:
  166. s *= i
  167. return s
  168. def _tensor_nbytes(numel, dtype):
  169. sizes = {
  170. torch.float: 4,
  171. torch.float16: 2,
  172. torch.bfloat16: 2,
  173. torch.float32: 4,
  174. torch.float64: 8,
  175. torch.int: 4,
  176. torch.int8: 1,
  177. torch.int16: 2,
  178. torch.int32: 4,
  179. torch.int64: 8,
  180. torch.uint8: 1,
  181. torch.bool: 1,
  182. }
  183. if dtype not in sizes:
  184. raise NotImplementedError("Don't know the size of dtype ", dtype)
  185. return numel * sizes[dtype]
  186. def _size_of(node: fx.Node) -> int:
  187. if 'val' in node.meta:
  188. val = node.meta['val']
  189. if isinstance(val, py_sym_types):
  190. return 1
  191. elif isinstance(val, (list, tuple)):
  192. return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
  193. elif isinstance(val, torch.Tensor):
  194. return _tensor_nbytes(hint_int(val.numel()), val.dtype)
  195. raise RuntimeError(f"Unknown metadata type {type(val)}")
  196. # Only needed since we don't always trace with fake tensors.
  197. if 'tensor_meta' in node.meta:
  198. metadata = node.meta['tensor_meta']
  199. numel = _prod(map(to_size_hint, metadata.shape))
  200. dtype = metadata.dtype
  201. else:
  202. return 0
  203. return _tensor_nbytes(numel, dtype)
  204. # Used for some investigative purposes
  205. def _count_ops(graph):
  206. from collections import defaultdict
  207. cnt = defaultdict(int)
  208. for node in graph.nodes:
  209. if node.op == 'call_function':
  210. cnt[node.target.__name__] += 1
  211. print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
  212. @functools.lru_cache(None)
  213. def pointwise_ops():
  214. ops = []
  215. for attr_name in dir(torch.ops.aten):
  216. opoverloadpacket = getattr(torch.ops.aten, attr_name)
  217. if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
  218. continue
  219. for overload in opoverloadpacket.overloads():
  220. op_overload = getattr(opoverloadpacket, overload)
  221. if torch.Tag.pointwise in op_overload.tags:
  222. # currently aot autograd uses packet not overload
  223. ops.append(opoverloadpacket)
  224. break
  225. return ops
  226. def min_cut_rematerialization_partition(
  227. joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None,
  228. *, num_fwd_outputs
  229. ) -> Tuple[fx.GraphModule, fx.GraphModule]:
  230. """
  231. Partitions the joint graph such that the backward recomputes the forward.
  232. Recomputing helps in trading off memory bandwidth with computation.
  233. To create the fwd and bwd graph, we copy the joint graph, manually set the
  234. outputs to just original forward or backward outputs. And then we run the
  235. resulting graphs through dead code elimintation.
  236. .. warning::
  237. This API is experimental and likely to change.
  238. Args:
  239. joint_module(fx.GraphModule): The joint forward and backward graph. This
  240. is the result of AOT Autograd tracing.
  241. _joint_inputs: The inputs to the joint graph. This is unused.
  242. compiler: This option determines the default set of recomputable ops.
  243. Currently, there are two options: ``nvfuser`` and ``inductor``.
  244. recomputable_ops: This is an optional set of recomputable ops. If this
  245. is not None, then this set of ops will be used instead of the
  246. default set of ops.
  247. num_fwd_outputs: The number of outputs from the forward graph.
  248. Returns:
  249. Returns the generated forward and backward Fx graph modules.
  250. """
  251. try:
  252. import networkx as nx
  253. except ImportError as e:
  254. raise RuntimeError("Need networkx installed to perform smart recomputation "
  255. "heuristics") from e
  256. joint_module.graph.eliminate_dead_code()
  257. joint_module.recompile()
  258. fx_g = joint_module.graph
  259. # add the CSE pass
  260. if config.cse:
  261. cse_graph = fx_graph_cse(fx_g)
  262. joint_module.graph = cse_graph
  263. full_bw_graph = joint_module.graph
  264. name_to_node = {}
  265. for node in joint_module.graph.nodes:
  266. name_to_node[node.name] = node
  267. def classify_nodes(joint_module):
  268. required_bw_nodes = set()
  269. for node in joint_module.graph.nodes:
  270. if node.op == 'placeholder' and "tangents" in node.target:
  271. required_bw_nodes.add(node)
  272. if node in required_bw_nodes:
  273. for user in node.users:
  274. required_bw_nodes.add(user)
  275. primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
  276. fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  277. forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
  278. required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes
  279. if node.op != 'output'}
  280. unclaimed_nodes = {node for node in joint_module.graph.nodes
  281. if node not in required_fw_nodes and node not in required_bw_nodes}
  282. return fwd_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes
  283. orig_fw_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module)
  284. def is_tensor_node(x):
  285. # When dynamic shapes are not enabled, fw outputs can be raw ints and not fx nodes
  286. if not isinstance(x, fx.Node):
  287. return False
  288. # It would be nice if we could guarantee that all fx nodes from make_fx get a 'val'
  289. # key in their meta dict, but that isn't always true today (see proxy_tensor.py)
  290. return 'tensor_meta' in x.meta or ('val' in x.meta and isinstance(x.meta['val'], torch.Tensor))
  291. # networkx blows up on graphs with no required backward nodes
  292. # Since there's nothing to partition anyway, and the default partitioner can "handle"
  293. # this case, send our graph over to the default partitioner.
  294. if len(required_bw_nodes) == 0:
  295. return default_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
  296. for node in reversed(joint_module.graph.nodes):
  297. if node not in required_fw_nodes:
  298. node.dist_from_bw = 0
  299. else:
  300. node.dist_from_bw = int(1e9)
  301. for user in node.users:
  302. node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)
  303. aten = torch.ops.aten
  304. prims = torch.ops.prims
  305. # compiler == "nvfuser" is the default set of recomputable ops
  306. default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501
  307. view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
  308. if compiler == "inductor":
  309. default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum] # noqa: E501
  310. view_ops += [aten.view, aten.slice, aten.permute, aten.t, prims.broadcast_in_dim, aten.expand, aten.as_strided]
  311. # Natalia said that we should allow recomputing indexing :)
  312. default_recomputable_ops += [aten.index]
  313. default_recomputable_ops += view_ops
  314. default_recomputable_ops += pointwise_ops()
  315. recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops)
  316. random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
  317. compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501
  318. unrecomputable_ops = random_ops + compute_intensive_ops
  319. fusible_ops = recomputable_ops | set(random_ops)
  320. if AOT_PARTITIONER_DEBUG:
  321. joint_module_ops = {
  322. str(node.target._overloadpacket)
  323. for node in joint_module.graph.nodes
  324. if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
  325. }
  326. ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops}
  327. print("Ops banned from rematerialization: ", ops_ignored)
  328. print()
  329. AGGRESSIVE_RECOMPUTATION = False
  330. def is_materialized_backwards(node):
  331. cur_nodes = {node}
  332. while len(cur_nodes) > 0:
  333. cur = cur_nodes.pop()
  334. for user in cur.users:
  335. if user not in required_fw_nodes and not is_fusible(cur, user):
  336. return True
  337. if user not in required_fw_nodes and get_aten_target(user) in view_ops:
  338. cur_nodes.add(user)
  339. return False
  340. def ban_recomputation(node):
  341. if AGGRESSIVE_RECOMPUTATION:
  342. return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
  343. else:
  344. if node.op != 'call_function':
  345. return False
  346. if get_aten_target(node) not in recomputable_ops:
  347. return True
  348. if node.target == operator.getitem:
  349. return False
  350. if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
  351. return False
  352. # If a node *must* be materialized in the backwards pass, then we
  353. # should never recompute it. This is a pretty subtle point. In
  354. # general, the assumption we make is that recomputing a node in the
  355. # backwards pass is "free". However, if a node must be materialized
  356. # in the backwards pass, then recomputing it is never free.
  357. if is_materialized_backwards(node):
  358. return True
  359. # Arbitrary hack that sometimes seems to help things. The above
  360. # modification appears to have made this heuristic a lot less critical
  361. # for performance.
  362. # TODO: Investigate why this hack helps.
  363. if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw:
  364. return True
  365. # If the output of an op is 4x smaller (arbitrary choice),
  366. # then we don't allow recomputation.
  367. input_tensors_size = sum(_size_of(i) for i in node.args if isinstance(i, fx.Node))
  368. output_size = _size_of(node)
  369. return (output_size * 4 < input_tensors_size)
  370. def is_fusible(a, b):
  371. return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops
  372. def is_materialized(node):
  373. if node.op == 'placeholder':
  374. return True
  375. return not all(is_fusible(node, user) for user in node.users)
  376. def get_node_weight(node) -> int:
  377. mem_sz = _size_of(node)
  378. # Heuristic to bias towards nodes closer to the backwards pass
  379. # Complete guess about current value
  380. mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))
  381. # mem_sz = int(mem_sz + node.dist_from_bw)
  382. if is_materialized(node):
  383. return mem_sz
  384. else:
  385. return mem_sz * 2
  386. nx_graph = nx.DiGraph()
  387. for node in full_bw_graph.nodes:
  388. if node.op == 'output':
  389. continue
  390. if node in required_bw_nodes:
  391. nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
  392. continue
  393. if node.op == 'placeholder' and "primals" in node.target:
  394. nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
  395. # If a node can't be recomputed (too expensive or involves randomness),
  396. # we prevent it from being recomputed by adding an inf edge to the source
  397. # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
  398. if ban_recomputation(node) and node in required_fw_nodes:
  399. nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
  400. # Checks if a node is actually a tuple. Can be simplified to just an isisinstance check if we always use faketensors.
  401. is_non_tensor_node = (('val' not in node.meta and 'tensor_meta' not in node.meta) or
  402. ('val' in node.meta and not isinstance(node.meta['val'], torch.Tensor)))
  403. if is_sym_node(node):
  404. weight = 1
  405. elif is_non_tensor_node:
  406. weight = math.inf
  407. else:
  408. weight = get_node_weight(node)
  409. # Creates the weights on the "node" edge
  410. nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
  411. for user in node.users:
  412. nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)
  413. cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
  414. reachable, non_reachable = partition
  415. cutset = set()
  416. for u, nbrs in ((n, nx_graph[n]) for n in reachable):
  417. cutset.update((u, v) for v in nbrs if v in non_reachable)
  418. cut_nodes = set()
  419. for node_in, node_out in cutset:
  420. assert node_in[:-3] == node_out[:-4]
  421. node_name = node_in[:-3]
  422. cut_nodes.add(node_name)
  423. # To make this stuff deterministic
  424. node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
  425. saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
  426. # Symints must be kept separate from tensors so that PythonFunction only calls
  427. # save_for_backward on tensors and stashes symints in autograd .ctx
  428. saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values))
  429. saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
  430. fw_module, bw_module = _extract_fwd_bwd_modules(
  431. joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs)
  432. if AOT_PARTITIONER_DEBUG:
  433. print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9)
  434. fw_module_nodes = {node.name for node in fw_module.graph.nodes if node.op == 'call_function'}
  435. bw_module_nodes = {node.name for node in bw_module.graph.nodes if node.op == 'call_function'}
  436. remat_nodes = fw_module_nodes & bw_module_nodes
  437. counts = defaultdict(int)
  438. for node in fw_module.graph.nodes:
  439. if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'):
  440. counts[str(node.target._overloadpacket)] += 1
  441. print(f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}")
  442. print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True))
  443. return fw_module, bw_module
  444. def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
  445. if clear_meta:
  446. new_graph = copy.deepcopy(traced.graph)
  447. traced = fx.GraphModule(traced, new_graph)
  448. for node in traced.graph.nodes:
  449. node.meta = {}
  450. base, ext = os.path.splitext(fname)
  451. if not ext:
  452. ext = ".svg"
  453. print(f"Writing FX graph to file: {base}{ext}")
  454. g = graph_drawer.FxGraphDrawer(traced, figname)
  455. x = g.get_main_dot_graph()
  456. getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}")
  457. def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"):
  458. draw_graph(graph, file_name)
  459. return default_partition(graph, joint_inputs)