pass_manager.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import inspect
  2. import logging
  3. from queue import Queue
  4. from functools import wraps
  5. from typing import Callable, Dict, List
  6. import torch.nn as nn
  7. from torch.fx.graph_module import GraphModule
  8. from torch.fx._compatibility import compatibility
  9. from torch.fx.passes.infra.pass_base import PassResult
  10. logger = logging.getLogger(__name__)
  11. logger.setLevel(logging.WARNING)
  12. __all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager']
  13. @compatibility(is_backward_compatible=False)
  14. def pass_result_wrapper(fn: Callable) -> Callable:
  15. """
  16. Wrapper for passes which currently do not return a PassResult.
  17. This wrapper makes them return a PassResult containing the modified object
  18. and True for the "modified" flag.
  19. Args:
  20. fn (Callable[Module, Any])
  21. Returns:
  22. wrapped_fn (Callable[Module, PassResult])
  23. """
  24. if fn is None:
  25. return None
  26. @wraps(fn)
  27. def wrapped_fn(gm):
  28. res = fn(gm)
  29. if res is None:
  30. return PassResult(gm, True)
  31. if isinstance(res, PassResult):
  32. return res
  33. elif isinstance(res, nn.Module):
  34. return PassResult(res, True)
  35. if not inspect.isfunction(fn):
  36. wrapped_fn.__name__ = type(fn).__name__
  37. return wrapped_fn
  38. def _validate_pass_schedule_constraint(
  39. constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
  40. ) -> None:
  41. for i, a in enumerate(passes):
  42. for j, b in enumerate(passes[i + 1 :]):
  43. if constraint(a, b):
  44. continue
  45. raise RuntimeError(
  46. f"pass schedule constraint violated. Expected {a} before {b}"
  47. f" but found {a} at index {i} and {b} at index{j} in pass"
  48. f" list."
  49. )
  50. def _topological_sort_passes(
  51. passes: List[Callable], constraints: List[Callable]
  52. ) -> List[Callable]:
  53. """
  54. Args
  55. passes: Passes that we are ordering
  56. constraints: Constraints applied on these passes
  57. Returns
  58. A sorted list of callables and a boolean of if a circular dependency
  59. existed
  60. """
  61. if len(constraints) == 0:
  62. return passes
  63. # Contruct a graph mapping nodes to a list of their users
  64. graph: Dict[Callable, List[Callable]] = {p : [] for p in passes}
  65. indegree_map: Dict[Callable, int] = {p : 0 for p in passes}
  66. candidates: Queue = Queue()
  67. for a in passes:
  68. for b in passes:
  69. if a == b:
  70. continue
  71. for constraint in constraints:
  72. if not constraint(a, b):
  73. graph[b].append(a)
  74. indegree_map[a] += 1
  75. if indegree_map[a] == 0:
  76. candidates.put(a)
  77. visited: Dict[Callable, bool] = {p : False for p in passes}
  78. sorted_passes: List[Callable] = []
  79. while not candidates.empty():
  80. p = candidates.get()
  81. sorted_passes.append(p)
  82. visited[p] = True
  83. for n in graph[p]:
  84. if not visited[n]:
  85. indegree_map[n] -= 1
  86. if indegree_map[n] == 0:
  87. candidates.put(n)
  88. # Check if there are unvisited nodes (aka cycles in the graph)
  89. cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys()))
  90. if len(cycle_passes) != 0:
  91. error = f"Circular dependency detected within the following passes: {cycle_passes}"
  92. raise RuntimeError(error)
  93. return sorted_passes
  94. @compatibility(is_backward_compatible=False)
  95. def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable:
  96. """
  97. Defines a partial order ('depends on' function) where `this` must occur
  98. before `that`.
  99. For example, the following pass list and constraint list would be invalid.
  100. ```
  101. passes = [pass_b, pass_a]
  102. constraints = [
  103. this_before_that_pass_constraint(pass_a, pass_b)
  104. ]
  105. ```
  106. Args:
  107. this (Callable): pass which should occur first
  108. that (Callable): pass which should occur later
  109. Returns:
  110. depends_on (Callable[[Object, Object], bool]
  111. """
  112. def depends_on(a: Callable, b: Callable):
  113. if a == that and b == this:
  114. return False
  115. return True
  116. return depends_on
  117. @compatibility(is_backward_compatible=False)
  118. class PassManager:
  119. """
  120. Construct a PassManager.
  121. Collects passes and constraints. This defines the pass schedule, manages
  122. pass constraints and pass execution.
  123. Args:
  124. passes (Optional[List[Callable]]): List of passes. A pass is a
  125. callable which modifies an object and returns a PassResult
  126. constraint (Optional[List[Callable]]): List of constraints. A
  127. constraint is a callable which takes two passes (A, B) and returns
  128. True if A depends on B and False otherwise. See implementation of
  129. `this_before_that_pass_constraint` for example.
  130. steps (int): Max number of times we run the passes (default = 1).
  131. run_checks_after_each_pass (bool): Whether to run checks and linting
  132. after each pass
  133. suppress_check_failures (bool): Whether to raise errors when running
  134. checks
  135. """
  136. passes: List[Callable[[nn.Module], PassResult]]
  137. constraints: List[Callable[[Callable, Callable], bool]]
  138. _validated: bool = False
  139. steps: int = 1
  140. def __init__(
  141. self,
  142. passes=None,
  143. constraints=None,
  144. steps=None,
  145. run_checks_after_each_pass: bool = False,
  146. suppress_check_failures: bool = False,
  147. ):
  148. self.passes = passes or []
  149. self.constraints = constraints or []
  150. if steps:
  151. self.steps = steps
  152. self.run_checks_after_each_pass = run_checks_after_each_pass
  153. self.suppress_check_failures = suppress_check_failures
  154. def add_pass(self, _pass: Callable):
  155. """
  156. Adds a pass into the current list of passes.
  157. """
  158. self.passes.append(_pass)
  159. self._validated = False
  160. def add_constraint(self, constraint: Callable):
  161. """
  162. Adds a constraint into the current list of constraints.
  163. """
  164. self.constraints.append(constraint)
  165. self._validated = False
  166. def validate_constraints(self):
  167. """
  168. Validates that current pass schedule defined by `self.passes` is valid
  169. according to all constraints in `self.constraints`
  170. """
  171. if self._validated:
  172. return
  173. for constraint in self.constraints:
  174. _validate_pass_schedule_constraint(constraint, self.passes)
  175. self._validated = True
  176. def solve_constraints(self):
  177. """
  178. Finds a valid traversal order based on the given constraints and orders
  179. the passes based on this order.
  180. If a circular dependency exists between the constraints and steps = 1,
  181. then we will raise an error because if steps != 1 this means that we
  182. will re-run the passes, allowing for circular dependencies.
  183. """
  184. self.passes = _topological_sort_passes(self.passes, self.constraints)
  185. self._validated = True
  186. def add_checks(self, check: Callable) -> None:
  187. """
  188. Adds a function which takes runs various checks on a given graph module.
  189. This function is run before and after each pass if the
  190. `run_checks_after_each_pass` flag is enabled.
  191. """
  192. sig = inspect.signature(check)
  193. if len(list(sig.parameters.values())) != 1:
  194. raise TypeError("PassManager check function should only take in one variable, a module")
  195. setattr(self, "check", check) # noqa: B010
  196. def check(self, module: nn.Module) -> None:
  197. pass
  198. def __call__(self, module: nn.Module) -> PassResult:
  199. """
  200. Runs a list of passes in the order based on `self.passes` on the given
  201. graph module. Each time a pass is run, checks and linting will be run on
  202. the graph module if `run_checks_after_each_pass` is set.
  203. If the module is a graph module, we will run the list of passes until
  204. the graph stops changing, or until `steps` number of times.
  205. """
  206. # Order the passes based on the constraints
  207. if not self._validated:
  208. self.solve_constraints()
  209. # Check graph invariants
  210. self.check(module)
  211. # Run the set of passes `steps` number of times or until the graph stops
  212. # changing
  213. overall_modified = False
  214. for _ in range(self.steps):
  215. modified = False
  216. # Run the set of passes on the graph module
  217. for i, fn in enumerate(self.passes):
  218. fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__
  219. logger.debug(f"Running pass '{fn_name}'")
  220. try:
  221. res = fn(module)
  222. if not isinstance(res, PassResult) and not hasattr(
  223. res, "graph_module"
  224. ):
  225. raise TypeError(
  226. f"The result of the pass {fn_name} should be type PassResult."
  227. + "Please wrap it with pass_result_wrapper()"
  228. )
  229. module = res.graph_module
  230. modified = modified or res.modified
  231. if isinstance(module, GraphModule):
  232. logger.debug(f"Graph after pass '{fn_name}':", module.graph)
  233. module.recompile()
  234. # Check graph invariants
  235. if self.run_checks_after_each_pass:
  236. self.check(module)
  237. except Exception as e:
  238. prev_pass_names = [
  239. p.__name__ if inspect.isfunction(p) else type(p).__name__
  240. for p in self.passes[:i]
  241. ]
  242. msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}"
  243. raise Exception(msg) from e
  244. # If the graph no longer changes, then we can stop running these passes
  245. overall_modified = overall_modified or modified
  246. if not modified:
  247. break
  248. return PassResult(module, overall_modified)