- import torch.fx as fx
- import copy
- import torch
- import math
- from typing import Callable, List
- from functools import wraps, partial
- from dataclasses import dataclass
- from .compile_utils import get_placeholders, get_outputs
class ConcreteProp(torch.fx.Interpreter):
    """Interpreter pass that runs the module on real inputs and records each
    node's concrete result in ``node.meta['concrete_value']``.

    Only nodes whose result contains at least one ``torch.Tensor`` (possibly
    nested inside a tuple/list/dict aggregate) get the meta entry; non-tensor
    results are left untouched.
    """

    def run_node(self, n):
        from torch.fx.node import map_aggregate

        # Execute the node normally first.
        result = super().run_node(n)

        seen_tensor = False

        def _note_tensor(leaf):
            # Leaves pass through unchanged; we only record whether any
            # tensor exists anywhere in the (possibly nested) result.
            nonlocal seen_tensor
            if isinstance(leaf, torch.Tensor):
                seen_tensor = True
            return leaf

        materialized = map_aggregate(result, _note_tensor)
        if seen_tensor:
            # Stash the materialized value so later passes can substitute this
            # node with a placeholder fed by a realistic input.
            n.meta['concrete_value'] = materialized
        return result

    def propagate(self, *args):
        """Run the whole module on ``*args``, populating node metas."""
        return super().run(*args)
# inplace modifies node/inps
def _convert_node_to_placeholder(node, inps):
    """Rewrite ``node`` in place into a placeholder and append a matching
    input tensor to ``inps``.

    Uses the concrete value recorded by ConcreteProp when it is a single
    tensor; otherwise appends a scalar zeros placeholder and recursively
    converts the node's users as well (the node's result was presumably an
    aggregate, e.g. a tuple consumed via getitem — TODO confirm against
    ConcreteProp's meta format).
    """
    # Outputs and existing placeholders are left alone.
    if node.op == 'output' or node.op == "placeholder":
        return
    node.op = 'placeholder'
    node.args = ()
    node.kwargs = {}
    # Placeholders use their target as the input name.
    node.target = node.name
    concrete_val = node.meta.get('concrete_value', None)
    if isinstance(concrete_val, torch.Tensor):
        # We know exactly what fed this node at trace time; reuse it.
        inps.append(concrete_val)
    else:
        # No single-tensor value available: feed a dummy scalar and convert
        # every user too, since they expected a non-tensor aggregate here.
        inps.append(torch.zeros(()))
        for tuple_user in list(node.users):
            _convert_node_to_placeholder(tuple_user, inps)
def dump_state(fx_g, inps):
    """Print a self-contained textual repro of the current failing graph.

    Emits a comment with the node count, a literal describing the input
    specs (shape, dtype, device), code that reconstructs equivalent inputs,
    and the generated module code itself.

    NOTE(review): the emitted input-reconstruction line prepends an extra
    ``torch.zeros(())`` element that has no counterpart in the printed spec
    list — confirm whether downstream repro consumers expect this leading
    dummy input.
    """
    print(f"""
# Working Repro with {len(fx_g.graph.nodes)} nodes
inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
""")
@dataclass
class ReproState:
    """Snapshot of one minification candidate: a graph plus the inputs that
    feed its placeholders (same order, same count)."""
    # Candidate FX graph under test.
    graph: fx.Graph
    # One tensor per placeholder in `graph`, in placeholder order.
    inps: List[torch.Tensor]
def minifier(fail_f: fx.GraphModule, inps, module_fails, dump_state: Callable = dump_state):
    """
    Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Does 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If fails,
       tries replacing quarter of the graph, etc.

    >>> # xdoctest: +SKIP(failing)
    >>> failing_function = fx.symbolic_trace(f)
    >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    # Total number of module_fails invocations, reported at the end.
    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        # A bare fx.Graph deep-copies cleanly only when re-attached to an
        # owning module; reuse fail_f so get_attr targets stay resolvable.
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

    def graph_fails(graph, inps):
        """Return True if the candidate graph still reproduces the failure."""
        nonlocal num_queries
        # Copy first so module_fails cannot mutate the candidate.
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    # Record concrete per-node values so delta debugging can later replace
    # interior nodes with placeholders fed by realistic inputs.
    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def _register_strategy(strategy: Callable, name: str):
        # Wrap a raw strategy with logging, progress accounting, and a
        # re-check that the shrunk candidate still fails before accepting it.
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print()
            print(f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)")
            # Strategies receive copies so a failed attempt leaves old_state intact.
            new_state = strategy(deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity)
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                # "Progress" = fewer nodes, more inputs (graph got cut), or
                # fewer outputs; a strategy returning a state must improve
                # at least one of these.
                if new_nodes < old_nodes:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
                if new_inps > old_inps:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
                if new_outs < old_outs:
                    progress_made = True
                    print(f"SUCCESS: Went from {old_outs} to {new_outs} outputs")

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                # Safety net: only accept minifications that still fail.
                if not graph_fails(new_state.graph, new_state.inps):
                    print("WARNING: Something went wrong, not applying this minification")
                    return None
                return new_state
            else:
                print(f"FAIL: {name}")
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        # Try cutting the graph after some node: copy the prefix and make
        # that node the sole output. Granularity picks which cut points to
        # probe this round.
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ['placeholder', 'output']:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if idx % granularity == 0 and (idx % (granularity * 2) != 0) and idx not in tested:
                    output_node = new_graph.output((new_node,))
                    if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(new_graph, cur_inps):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        # Undo the trial output and keep copying the graph.
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        # Tag nodes with their topological index so outputs can be sorted.
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == 'output':
                output = node
                break

        # Non-Node outputs sort to the end (int(1e9) sentinel).
        output_args = sorted(output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9))
        if len(output_args) == 1:
            return None

        # Drop `granularity`-sized chunks of outputs until one removal
        # still reproduces the failure.
        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] + output_args[idx + granularity:],)
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

    def remove_unused_inputs_unchecked(cur_state: ReproState):
        # Drop placeholders with no users and their corresponding inputs.
        # "Unchecked": does not verify the result still fails.
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph, new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        # Adapter: strategies take (graph, inps, granularity).
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(_remove_unused_wrapper)

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        # eliminate_dead_code() returns True when it removed something.
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    def _consolidate_placeholders(cur_graph):
        # Rebuild the graph with all placeholders moved to the front, as FX
        # requires placeholders to precede other nodes.
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        # Try replacing each granularity-sized window of nodes with
        # placeholders fed by their recorded concrete values.
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if new_node.op not in ['placeholder', 'output']:
                    is_removing = True
                    _convert_node_to_placeholder(new_node, new_inps)
            if not is_removing:
                continue
            new_graph = _consolidate_placeholders(new_graph)
            # Converting nodes may strand old placeholders; prune them, but
            # accept the unpruned graph if pruning made no difference.
            new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)
        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        """Run each applicable strategy at this granularity; return the first
        improved state, or None if none helped."""
        print(f"Trying granularity {granularity}")

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        # Only bother trimming outputs when they dominate the graph.
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        # Granularity-independent cleanups run once per outer round.
        if use_non_granular:
            strategies += [eliminate_dead_code, remove_unused_inputs]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

    # Fixed-point driver: start at the coarsest power-of-two granularity and
    # halve until some strategy makes progress; stop when nothing helps.
    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
        granularity = int(2**(math.floor(math.log2(len(failing_state.graph.nodes)))))
        new_state = try_granularity(failing_state, granularity, use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(failing_state, granularity, use_non_granular=False)
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        # Last resort before giving up: per-output removal.
        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries")
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)
    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py")
    return failing_fx, failing_state.inps