pass_manager.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  import logging
  from functools import wraps
  from inspect import unwrap
  from typing import Callable, List, Optional, Union
  5. logger = logging.getLogger(__name__)
  # For callables which modify an object inplace and return something other
  # than the object on which they act.
  def inplace_wrapper(fn: Callable) -> Callable:
      """
      Convenience wrapper for passes which modify an object inplace. This
      wrapper makes them return the modified object instead.

      The wrapped pass is still invoked for its side effect on its argument;
      whatever it returns is discarded.

      Args:
          fn (Callable[Object, Any]): pass that mutates its single argument.

      Returns:
          wrapped_fn (Callable[Object, Object]): calls ``fn(gm)`` and returns
          ``gm`` itself. Carries ``fn``'s metadata via ``functools.wraps``.
      """

      @wraps(fn)
      def wrapped_fn(gm):
          # Only the in-place mutation matters; the pass's own return value is
          # intentionally dropped.
          fn(gm)
          return gm

      return wrapped_fn
  def log_hook(fn: Callable, level=logging.INFO) -> Callable:
      """
      Logs callable output.

      This is useful for logging output of passes. Note inplace_wrapper replaces
      the pass output with the modified object. If we want to log the original
      output, apply this wrapper before inplace_wrapper.

      ```
      def my_pass(d: Dict) -> bool:
          changed = False
          if 'foo' in d:
              d['foo'] = 'bar'
              changed = True
          return changed

      pm = PassManager(
          passes=[
              inplace_wrapper(log_hook(my_pass))
          ]
      )
      ```

      Args:
          fn (Callable[Type1, Type2]): pass whose return value is logged.
          level: logging level (e.g. logging.INFO)

      Returns:
          wrapped_fn (Callable[Type1, Type2]): calls ``fn``, logs its return
          value on the module logger at ``level``, and returns it unchanged.
      """

      @wraps(fn)
      def wrapped_fn(gm):
          val = fn(gm)
          # Lazy %-style arguments: the message is only rendered when `level`
          # is enabled on the logger, unlike an eagerly built f-string.
          logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
          return val

      return wrapped_fn
  def loop_pass(
      base_pass: Callable,
      n_iter: Optional[int] = None,
      predicate: Optional[Callable] = None,
  ) -> Callable:
      """
      Convenience wrapper for passes which need to be applied multiple times.

      Exactly one of `n_iter` or `predicate` must be specified.

      Args:
          base_pass (Callable[Object, Object]): pass to be applied in loop
          n_iter (int, optional): number of times to loop pass. Must be
              positive; otherwise calling the returned pass raises
              ``RuntimeError``.
          predicate (Callable[Object, bool], optional): the pass is re-applied
              while ``predicate(output)`` is truthy. It is checked before every
              application, so the pass may run zero times.

      Returns:
          Callable[Object, Object]: the looping pass, carrying ``base_pass``'s
          metadata via ``functools.wraps``.

      Raises:
          AssertionError: if both or neither of `n_iter` / `predicate` are
              given (only while assertions are enabled).
      """
      assert (n_iter is not None) ^ (
          predicate is not None
      ), "Exactly one of `n_iter`or `predicate` must be specified."

      @wraps(base_pass)
      def new_pass(source):
          output = source
          if n_iter is not None and n_iter > 0:
              for _ in range(n_iter):
                  output = base_pass(output)
          elif predicate is not None:
              while predicate(output):
                  output = base_pass(output)
          else:
              # Reached for a non-positive n_iter (or if the assertion above was
              # stripped by `python -O` and neither argument was given).
              raise RuntimeError(
                  f"loop_pass must be given positive int n_iter (given "
                  f"{n_iter}) xor predicate (given {predicate})"
              )
          return output

      return new_pass
  81. # Pass Schedule Constraints:
  82. #
  83. # Implemented as 'depends on' operators. A constraint is satisfied iff a list
  84. # has a valid partial ordering according to this comparison operator.
  85. def _validate_pass_schedule_constraint(
  86. constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
  87. ):
  88. for i, a in enumerate(passes):
  89. for j, b in enumerate(passes[i + 1 :]):
  90. if constraint(a, b):
  91. continue
  92. raise RuntimeError(
  93. f"pass schedule constraint violated. Expected {a} before {b}"
  94. f" but found {a} at index {i} and {b} at index{j} in pass"
  95. f" list."
  96. )
  def this_before_that_pass_constraint(this: Callable, that: Callable):
      """
      Build a 'depends on' constraint requiring `this` to run before `that`.

      The returned ``depends_on(a, b)`` reports False (a violation) only for
      the exact pair ``(that, this)``, i.e. `that` given as the earlier pass
      and `this` as the later one. Every other pair is accepted with True.
      Passes are compared with ``==`` as given, without unwrapping.
      """

      def depends_on(a: Callable, b: Callable):
          violated = a == that and b == this
          return not violated

      return depends_on
  def these_before_those_pass_constraint(these: Callable, those: Callable):
      """
      Defines a partial order ('depends on' function) where `these` must occur
      before `those`.

      Both arguments of the returned function are passed through
      ``inspect.unwrap`` before comparison. Passes wrapped with
      ``functools.wraps`` (as ``loop_pass`` does) therefore still match.

      For example, the following pass list and constraint list would be invalid.

      ```
      passes = [
          loop_pass(pass_b, 3),
          loop_pass(pass_a, 5),
      ]

      constraints = [
          these_before_those_pass_constraint(pass_a, pass_b)
      ]
      ```

      Args:
          these (Callable): pass which should occur first
          those (Callable): pass which should occur later

      Returns:
          depends_on (Callable[[Callable, Callable], bool]): returns False only
          when its unwrapped first argument equals `those` and its unwrapped
          second argument equals `these`; True otherwise.
      """

      def depends_on(a: Callable, b: Callable):
          # Compare the innermost callables so wrappers don't hide identity.
          first_is_those = unwrap(a) == those
          if not first_is_those:
              return True
          return not (unwrap(b) == these)

      return depends_on
  class PassManager:
      """
      Construct a PassManager.

      Collects passes and constraints. This defines the pass schedule, manages
      pass constraints and pass execution.

      Args:
          passes (Optional[List[Callable]]): list of passes. A pass is a
              callable which modifies an object and returns modified object.
              The list is copied, so later ``add_pass``/``remove_pass`` calls
              never mutate the caller's list.
          constraints (Optional[List[Callable]]): list of constraints. A
              constraint is called as ``constraint(a, b)`` for every pair of
              passes where ``a`` is scheduled before ``b``. A falsy result
              marks the schedule as invalid. See implementation of
              `this_before_that_pass_constraint` for example.
      """

      passes: List[Callable]
      constraints: List[Callable]
      _validated: bool = False

      def __init__(
          self,
          passes: Optional[List[Callable]] = None,
          constraints: Optional[List[Callable]] = None,
      ):
          # Copy so that mutating this manager never mutates the caller's lists.
          self.passes = list(passes) if passes else []
          self.constraints = list(constraints) if constraints else []
          self._validated = False

      @classmethod
      def build_from_passlist(cls, passes):
          """Alternate constructor from a list of passes (no constraints)."""
          # `cls` rather than `PassManager` so subclasses build their own type.
          pm = cls(passes)
          # TODO(alexbeloi): add constraint management/validation
          return pm

      def add_pass(self, _pass: Callable):
          """Append a pass to the end of the schedule."""
          self.passes.append(_pass)
          self._validated = False

      def add_constraint(self, constraint: Callable):
          """Register a scheduling constraint; it is checked on the next run."""
          self.constraints.append(constraint)
          self._validated = False

      def remove_pass(self, _passes: Optional[List[Union[str, Callable]]]):
          """
          Remove passes from the schedule.

          Args:
              _passes: passes to drop, each given either by its ``__name__``
                  (str) or as the pass callable itself. ``None`` is a no-op.
          """
          if _passes is None:
              return

          def _is_removed(ps) -> bool:
              # getattr: passes without __name__ (e.g. functools.partial or
              # callable instances) can still be matched by equality.
              return getattr(ps, "__name__", None) in _passes or ps in _passes

          self.passes = [ps for ps in self.passes if not _is_removed(ps)]
          self._validated = False

      def validate(self):
          """
          Validates that current pass schedule defined by `self.passes` is valid
          according to all constraints in `self.constraints`.

          The result is cached until ``add_pass``, ``add_constraint`` or
          ``remove_pass`` is called.

          Raises:
              RuntimeError: if any constraint is violated.
          """
          if self._validated:
              return
          for constraint in self.constraints:
              _validate_pass_schedule_constraint(constraint, self.passes)
          self._validated = True

      def __call__(self, source):
          """Validate the schedule, then feed `source` through each pass in order."""
          self.validate()
          out = source
          for _pass in self.passes:
              out = _pass(out)
          return out