123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- import inspect
- import logging
- from queue import Queue
- from functools import wraps
- from typing import Callable, Dict, List
- import torch.nn as nn
- from torch.fx.graph_module import GraphModule
- from torch.fx._compatibility import compatibility
- from torch.fx.passes.infra.pass_base import PassResult
# Module-level logger. Its level is pinned to WARNING, so the debug messages
# emitted by PassManager only appear if this logger's level is lowered
# explicitly (configuring a parent logger is not enough).
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Public API of this module.
__all__ = [
    "pass_result_wrapper",
    "this_before_that_pass_constraint",
    "PassManager",
]
@compatibility(is_backward_compatible=False)
def pass_result_wrapper(fn: Callable) -> Callable:
    """
    Wrapper for passes which currently do not return a PassResult.

    This wrapper makes them return a PassResult containing the modified object
    and True for the "modified" flag.

    Args:
        fn (Callable[Module, Any]): the pass to wrap. It may return ``None``
            (taken to mean it changed its argument in place), an
            ``nn.Module``, or an existing ``PassResult`` (passed through).

    Returns:
        wrapped_fn (Callable[Module, PassResult]), or ``None`` when ``fn`` is
        ``None``. When ``fn`` is not a plain function (e.g. a callable
        instance), ``wrapped_fn.__name__`` is set to ``fn``'s class name.

    Raises:
        TypeError: when ``wrapped_fn`` is called and ``fn`` returns anything
            other than ``None``, an ``nn.Module`` or a ``PassResult``.
    """
    if fn is None:
        return None

    @wraps(fn)
    def wrapped_fn(gm):
        res = fn(gm)
        if res is None:
            # The pass returned nothing: report the input module as modified.
            return PassResult(gm, True)
        if isinstance(res, PassResult):
            return res
        if isinstance(res, nn.Module):
            return PassResult(res, True)
        # Previously this fell through and silently returned None, pushing
        # the failure to a distant caller; fail here with a clear message.
        raise TypeError(
            f"The pass {wrapped_fn.__name__} returned an object of type "
            f"{type(res).__name__}; expected None, an nn.Module or a PassResult."
        )

    if not inspect.isfunction(fn):
        # Callable instances have no __name__ for @wraps to copy.
        wrapped_fn.__name__ = type(fn).__name__

    return wrapped_fn
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
    """
    Check that every ordered pair of passes in ``passes`` satisfies ``constraint``.

    ``constraint(a, b)`` is called for each pair where ``a`` appears earlier in
    ``passes`` than ``b``. A falsy result means ``a`` may not run before ``b``,
    i.e. ``b`` was expected first.

    Args:
        constraint: callable taking two passes ``(a, b)`` and returning whether
            ``a`` may be scheduled before ``b``.
        passes: the pass schedule to check.

    Raises:
        RuntimeError: on the first pair that violates the constraint. The
            message reports the absolute indices of both passes.
    """
    for i, a in enumerate(passes):
        # start=i + 1 keeps j an absolute index into `passes`, not into the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
def _topological_sort_passes(
    passes: List[Callable], constraints: List[Callable]
) -> List[Callable]:
    """
    Order ``passes`` so that every constraint in ``constraints`` is satisfied.

    For every ordered pair ``(a, b)`` of distinct passes and every constraint,
    a falsy ``constraint(a, b)`` means ``a`` must not come before ``b``. This is
    recorded as an edge ``b -> a``. The passes are then emitted with Kahn's
    algorithm. Passes with no incoming edges are queued in their original
    relative order.

    Args:
        passes: Passes that we are ordering
        constraints: Constraints applied on these passes

    Returns:
        The passes in an order satisfying all constraints. When there are no
        constraints, ``passes`` itself is returned unchanged.

    Raises:
        RuntimeError: if the constraints form a circular dependency.
    """
    # deque gives an O(1) FIFO without queue.Queue's thread-locking overhead.
    from collections import deque

    if not constraints:
        return passes

    # Construct a graph mapping each pass to the passes that must run after it.
    graph: Dict[Callable, List[Callable]] = {p: [] for p in passes}
    indegree_map: Dict[Callable, int] = {p: 0 for p in passes}
    candidates = deque()
    for a in passes:
        for b in passes:
            if a == b:
                continue
            for constraint in constraints:
                if not constraint(a, b):
                    # `b` has to run before `a`.
                    graph[b].append(a)
                    indegree_map[a] += 1
        # All of a's incoming edges are known once its outer iteration ends.
        if indegree_map[a] == 0:
            candidates.append(a)

    visited: Dict[Callable, bool] = {p: False for p in passes}
    sorted_passes: List[Callable] = []
    while candidates:
        p = candidates.popleft()
        sorted_passes.append(p)
        visited[p] = True
        for n in graph[p]:
            if not visited[n]:
                indegree_map[n] -= 1
                if indegree_map[n] == 0:
                    candidates.append(n)

    # Passes whose in-degree never reached zero are part of a cycle.
    cycle_passes = [p for p, degree in indegree_map.items() if degree != 0]
    if cycle_passes:
        error = f"Circular dependency detected within the following passes: {cycle_passes}"
        raise RuntimeError(error)
    return sorted_passes
@compatibility(is_backward_compatible=False)
def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
    """
    Build a constraint requiring pass `this` to be scheduled before pass `that`.

    The returned callable takes two passes ``(a, b)``, where ``a`` is the one
    that would be scheduled earlier, and reports whether that ordering is
    allowed. It rejects only the reversed pair, i.e. ``a`` is ``that`` and
    ``b`` is ``this``.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [pass_b, pass_a]
    constraints = [
        this_before_that_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # Only `that` placed ahead of `this` violates the ordering.
        return not (a == that and b == this)

    return depends_on
@compatibility(is_backward_compatible=False)
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): List of passes. A pass is a
            callable which modifies an object and returns a PassResult.
            The list is copied, so later ``add_pass`` calls do not mutate
            the caller's list.
        constraints (Optional[List[Callable]]): List of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            False if A must not be scheduled before B (i.e. B has to run
            first) and True otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
        steps (int): Max number of times we run the passes (default = 1).
            A falsy value (``None`` or ``0``) keeps the default.
        run_checks_after_each_pass (bool): Whether to run checks and linting
            after each pass
        suppress_check_failures (bool): Whether to raise errors when running
            checks. NOTE(review): this flag is stored but not consulted
            anywhere in this class.
    """

    passes: List[Callable[[nn.Module], PassResult]]
    constraints: List[Callable[[Callable, Callable], bool]]
    _validated: bool = False
    steps: int = 1

    def __init__(
        self,
        passes=None,
        constraints=None,
        steps=None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ):
        # Copy so add_pass/add_constraint never mutate a list the caller owns.
        self.passes = list(passes) if passes else []
        self.constraints = list(constraints) if constraints else []
        if steps:
            self.steps = steps
        self.run_checks_after_each_pass = run_checks_after_each_pass
        self.suppress_check_failures = suppress_check_failures

    @staticmethod
    def _pass_name(fn: Callable) -> str:
        """Return a readable name for a pass, used in logs and error messages."""
        return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__

    def add_pass(self, _pass: Callable):
        """
        Adds a pass into the current list of passes.
        """
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint: Callable):
        """
        Adds a constraint into the current list of constraints.
        """
        self.constraints.append(constraint)
        self._validated = False

    def validate_constraints(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        Raises:
            RuntimeError: if any constraint is violated by the current order.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def solve_constraints(self):
        """
        Finds a valid traversal order based on the given constraints and orders
        the passes based on this order.

        Raises:
            RuntimeError: if the constraints contain a circular dependency.
                This is raised regardless of the value of ``steps``.
        """
        self.passes = _topological_sort_passes(self.passes, self.constraints)
        self._validated = True

    def add_checks(self, check: Callable) -> None:
        """
        Adds a function which runs various checks on a given graph module.

        The check is run once before the passes. It is also run after each
        pass if the `run_checks_after_each_pass` flag is enabled.

        Raises:
            TypeError: if ``check`` does not take exactly one parameter.
        """
        sig = inspect.signature(check)

        if len(list(sig.parameters.values())) != 1:
            raise TypeError("PassManager check function should only take in one variable, a module")

        setattr(self, "check", check)  # noqa: B010

    def check(self, module: nn.Module) -> None:
        # Default check is a no-op; replaced per instance by add_checks().
        pass

    def __call__(self, module: nn.Module) -> PassResult:
        """
        Runs a list of passes in the order based on `self.passes` on the given
        graph module. Each time a pass is run, checks and linting will be run on
        the graph module if `run_checks_after_each_pass` is set.

        The whole list of passes is run up to `steps` times. It stops early
        after a round in which no pass reports a modification. GraphModule
        results are recompiled after each pass.

        Returns:
            PassResult(final_module, True if any pass reported a modification).

        Raises:
            RuntimeError: from constraint solving if the constraints are circular.
            Exception: wrapping (as ``__cause__``) any error raised by a pass,
                by the result-type check, or by a post-pass check.
        """
        # Order the passes based on the constraints
        if not self._validated:
            self.solve_constraints()

        # Check graph invariants
        self.check(module)

        # Run the set of passes `steps` number of times or until the graph stops
        # changing
        overall_modified = False
        for _ in range(self.steps):
            modified = False

            # Run the set of passes on the graph module
            for i, fn in enumerate(self.passes):
                fn_name = self._pass_name(fn)
                logger.debug("Running pass '%s'", fn_name)

                try:
                    res = fn(module)

                    if not isinstance(res, PassResult) and not hasattr(
                        res, "graph_module"
                    ):
                        raise TypeError(
                            f"The result of the pass {fn_name} should be type PassResult. "
                            "Please wrap it with pass_result_wrapper()"
                        )
                    module = res.graph_module
                    modified = modified or res.modified

                    if isinstance(module, GraphModule):
                        # Lazy %-args: the graph is only stringified when DEBUG
                        # is enabled (the old call passed an argument with no
                        # matching placeholder, which broke log formatting).
                        logger.debug("Graph after pass '%s': %s", fn_name, module.graph)
                        module.recompile()

                    # Check graph invariants
                    if self.run_checks_after_each_pass:
                        self.check(module)

                except Exception as e:
                    prev_pass_names = [self._pass_name(p) for p in self.passes[:i]]
                    msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
                    raise Exception(msg) from e

            # If the graph no longer changes, then we can stop running these passes
            overall_modified = overall_modified or modified
            if not modified:
                break

        return PassResult(module, overall_modified)
|