# fx_minifier.py
import copy
import math
from dataclasses import dataclass
from functools import wraps, partial
from typing import Callable, List

import torch
import torch.fx as fx
from torch.fx.node import map_aggregate

from .compile_utils import get_placeholders, get_outputs
  9. class ConcreteProp(torch.fx.Interpreter):
  10. def run_node(self, n):
  11. result = super().run_node(n)
  12. found_tensor = False
  13. def extract_tensor_meta(obj):
  14. if isinstance(obj, torch.Tensor):
  15. nonlocal found_tensor
  16. found_tensor = True
  17. return obj
  18. else:
  19. return obj
  20. from torch.fx.node import map_aggregate
  21. concrete_value = map_aggregate(result, extract_tensor_meta)
  22. if found_tensor:
  23. n.meta['concrete_value'] = concrete_value
  24. return result
  25. def propagate(self, *args):
  26. return super().run(*args)
  27. # inplace modifies node/inps
  28. def _convert_node_to_placeholder(node, inps):
  29. if node.op == 'output' or node.op == "placeholder":
  30. return
  31. node.op = 'placeholder'
  32. node.args = ()
  33. node.kwargs = {}
  34. node.target = node.name
  35. concrete_val = node.meta.get('concrete_value', None)
  36. if isinstance(concrete_val, torch.Tensor):
  37. inps.append(concrete_val)
  38. else:
  39. inps.append(torch.zeros(()))
  40. for tuple_user in list(node.users):
  41. _convert_node_to_placeholder(tuple_user, inps)
  42. def dump_state(fx_g, inps):
  43. print(f"""
  44. # Working Repro with {len(fx_g.graph.nodes)} nodes
  45. inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
  46. inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
  47. {fx_g.code}
  48. """)
  49. @dataclass
  50. class ReproState:
  51. graph: fx.Graph
  52. inps: List[torch.Tensor]
  53. def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable = dump_state):
  54. """
  55. Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.
  56. Does 2 main strategies:
  57. 1. Truncates suffix: Removes some suffix from the graph and sets a new output.
  58. 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
  59. tries replacing quarter of the graph, etc.
  60. >>> # xdoctest: +SKIP(failing)
  61. >>> failing_function = fx.symbolic_trace(f)
  62. >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
  63. note: module_fails returns True if it fails.
  64. """
  65. failing_graph = fail_f.graph
  66. cur_size = len(failing_graph.nodes)
  67. num_queries = 0
  68. def deepcopy_fx_graph(fx_graph):
  69. return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph
  70. def graph_fails(graph, inps):
  71. nonlocal num_queries
  72. graph = copy.deepcopy(graph)
  73. num_queries += 1
  74. mod = fx.GraphModule(fail_f, graph)
  75. mod.graph.lint()
  76. return module_fails(mod, inps)
  77. ConcreteProp(fail_f).propagate(*inps)
  78. if not graph_fails(failing_graph, inps):
  79. raise RuntimeError("Input graph did not fail the tester")
  80. print(f"Started off with {cur_size} nodes")
  81. def _register_strategy(strategy: Callable, name: str):
  82. @wraps(strategy)
  83. def new_func(old_state: ReproState, granularity=1):
  84. print()
  85. print(f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)")
  86. new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity)
  87. if new_state is not None:
  88. new_nodes = len(new_state.graph.nodes)
  89. old_nodes = len(old_state.graph.nodes)
  90. new_inps = len(new_state.inps)
  91. old_inps = len(old_state.inps)
  92. new_outs = len(get_outputs(new_state.graph))
  93. old_outs = len(get_outputs(old_state.graph))
  94. progress_made = False
  95. if new_nodes < old_nodes:
  96. progress_made = True
  97. print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
  98. if new_inps > old_inps:
  99. progress_made = True
  100. print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
  101. if new_outs < old_outs:
  102. progress_made = True
  103. print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs")
  104. if not progress_made:
  105. raise RuntimeError("Success raised but no progress made?")
  106. if not graph_fails(new_state.graph, new_state.inps):
  107. print("WARNING: Something went wrong, not applying this minification")
  108. return None
  109. return new_state
  110. else:
  111. print(f"FAIL: {name}")
  112. return None
  113. return new_func
  114. def register_strategy(name: str):
  115. return partial(_register_strategy, name=name)
  116. @register_strategy("Truncate suffix")
  117. def remove_suffix(cur_graph, cur_inps, granularity):
  118. tested = set()
  119. new_graph = fx.Graph()
  120. env = {}
  121. for idx, node in enumerate(cur_graph.nodes):
  122. new_node = new_graph.node_copy(node, lambda x: env[x])
  123. if node.op not in ['placeholder', 'output']:
  124. # If idx is divisible by (granularity * 2), it would have been checked already.
  125. if idx % granularity == 0 and (idx % (granularity * 2) != 0) and idx not in tested:
  126. output_node = new_graph.output((new_node,))
  127. if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(new_graph, cur_inps):
  128. return ReproState(new_graph, cur_inps)
  129. else:
  130. tested.add(idx)
  131. new_graph.erase_node(output_node)
  132. env[node] = new_node
  133. return None
  134. @register_strategy("Remove outputs")
  135. def remove_outputs(cur_graph, cur_inps, granularity):
  136. granularity = max(1, granularity // 2)
  137. for idx, node in enumerate(cur_graph.nodes):
  138. node.idx = idx
  139. if node.op == 'output':
  140. output = node
  141. break
  142. output_args = sorted(output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9))
  143. if len(output_args) == 1:
  144. return None
  145. for idx in range(0, len(output_args), granularity):
  146. output.args = (output_args[:idx] + output_args[idx + granularity:],)
  147. if graph_fails(cur_graph, cur_inps):
  148. return ReproState(cur_graph, cur_inps)
  149. return None
  150. def remove_unused_inputs_unchecked(cur_state: ReproState):
  151. cur_graph = cur_state.graph
  152. cur_inps = cur_state.inps
  153. ph_nodes = get_placeholders(cur_graph)
  154. assert len(ph_nodes) == len(cur_inps)
  155. new_inps = []
  156. for idx in range(len(ph_nodes)):
  157. if len(ph_nodes[idx].users) == 0:
  158. cur_graph.erase_node(ph_nodes[idx])
  159. else:
  160. new_inps.append(cur_inps[idx])
  161. if len(new_inps) < len(cur_inps):
  162. return ReproState(cur_graph, new_inps)
  163. return None
  164. def remove_unused_inputs_checked(cur_state: ReproState):
  165. new_state = remove_unused_inputs_unchecked(cur_state)
  166. if new_state is not None and graph_fails(new_state.graph, new_state.inps):
  167. return new_state
  168. return None
  169. def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
  170. return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))
  171. remove_unused_inputs = register_strategy("Remove unused inputs")(_remove_unused_wrapper)
  172. @register_strategy("Eliminate dead code")
  173. def eliminate_dead_code(cur_graph, cur_inps, granularity):
  174. if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
  175. return ReproState(cur_graph, cur_inps)
  176. return None
  177. def _consolidate_placeholders(cur_graph):
  178. new_graph = fx.Graph()
  179. env = {}
  180. for node in cur_graph.nodes:
  181. if node.op == 'placeholder':
  182. new_node = new_graph.node_copy(node, lambda x: env[x])
  183. env[node] = new_node
  184. for node in cur_graph.nodes:
  185. if node.op != 'placeholder':
  186. new_node = new_graph.node_copy(node, lambda x: env[x])
  187. env[node] = new_node
  188. return new_graph
  189. @register_strategy("Delta Debugging")
  190. def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
  191. num_nodes = len(cur_graph.nodes)
  192. for start_range in range(0, num_nodes, granularity):
  193. is_removing = False
  194. new_graph = deepcopy_fx_graph(cur_graph)
  195. new_inps = cur_inps[:]
  196. end_range = min(num_nodes, start_range + granularity)
  197. for idx in range(start_range, end_range):
  198. new_node = list(new_graph.nodes)[idx]
  199. if new_node.op not in ['placeholder', 'output']:
  200. is_removing = True
  201. _convert_node_to_placeholder(new_node, new_inps)
  202. if not is_removing:
  203. continue
  204. new_graph = _consolidate_placeholders(new_graph)
  205. new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
  206. if new_state is None:
  207. new_state = ReproState(new_graph, new_inps)
  208. if graph_fails(new_state.graph, new_state.inps):
  209. return ReproState(new_state.graph, new_state.inps)
  210. return None
  211. failing_state = ReproState(failing_graph, inps)
  212. def try_granularity(failing_state, granularity, use_non_granular):
  213. print(f"Trying granularity {granularity}")
  214. strategies = []
  215. num_nodes = len(failing_state.graph.nodes)
  216. num_outputs = len(get_outputs(failing_state.graph))
  217. if num_outputs > num_nodes // 2:
  218. strategies += [remove_outputs]
  219. if use_non_granular:
  220. strategies += [eliminate_dead_code, remove_unused_inputs]
  221. strategies += [remove_suffix, delta_debugging]
  222. for strategy in strategies:
  223. new_state = strategy(failing_state, granularity)
  224. if new_state is not None:
  225. return new_state
  226. return None
  227. while True:
  228. dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
  229. granularity = int(2**(math.floor(math.log2(len(failing_state.graph.nodes)))))
  230. new_state = try_granularity(failing_state, granularity, use_non_granular=True)
  231. if new_state is not None:
  232. failing_state = new_state
  233. continue
  234. granularity //= 2
  235. has_progress = False
  236. while granularity >= 1:
  237. new_state = try_granularity(failing_state, granularity, use_non_granular=False)
  238. if new_state is not None:
  239. failing_state = new_state
  240. has_progress = True
  241. break
  242. granularity //= 2
  243. if has_progress:
  244. continue
  245. new_state = remove_outputs(failing_state, 1)
  246. if new_state is not None:
  247. failing_state = new_state
  248. continue
  249. break
  250. if not graph_fails(failing_state.graph, failing_state.inps):
  251. raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")
  252. print(f"Made {num_queries} queries")
  253. failing_fx = fx.GraphModule(fail_f, failing_state.graph)
  254. dump_state(failing_fx, failing_state.inps)
  255. print("Wrote minimal repro out to repro.py")
  256. return failing_fx, failing_state.inps