|
- from functools import wraps
- from inspect import unwrap
- from typing import Callable, List
- import logging
- logger = logging.getLogger(__name__)
# For callables which modify an object in place and return something other
# than the object on which they act.
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Args:
        fn (Callable[Object, Any]): pass that mutates its single argument;
            whatever it returns is discarded.

    Returns:
        wrapped_fn (Callable[Object, Object]): callable that runs `fn` on its
            argument and then returns that same argument.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Run purely for the side effect; the pass's own return value is
        # intentionally dropped so the object can flow to the next pass.
        fn(gm)
        return gm

    return wrapped_fn
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose return value is logged
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): callable that invokes `fn`, logs
            the result on the module logger and returns it unchanged.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        val = fn(gm)
        # Pass the values as logging arguments so the message is only
        # formatted when a handler actually emits a record at this level.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
        return val

    return wrapped_fn
def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None):
    """
    Build a pass that applies `base_pass` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass; must be
            positive, otherwise the returned pass raises when it is called
        predicate (Callable[Object, bool], optional): evaluated on the current
            object before every iteration; looping continues while it is truthy

    Returns:
        new_pass (Callable[Object, Object]): the looping pass.

    Raises:
        AssertionError: if both or neither of `n_iter` / `predicate` are given.
    """
    assert (n_iter is None) != (
        predicate is None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        current = source

        # Fixed-count mode: a positive iteration count was supplied.
        if n_iter is not None and n_iter > 0:
            remaining = n_iter
            while remaining > 0:
                current = base_pass(current)
                remaining -= 1
            return current

        # Neither a usable count nor a predicate: nothing sensible to do.
        if predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )

        # Predicate mode: keep applying the pass while the predicate holds.
        while predicate(current):
            current = base_pass(current)
        return current

    return new_pass
- # Pass Schedule Constraints:
- #
- # Implemented as 'depends on' operators. A constraint is satisfied iff a list
- # has a valid partial ordering according to this comparison operator.
- def _validate_pass_schedule_constraint(
- constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
- ):
- for i, a in enumerate(passes):
- for j, b in enumerate(passes[i + 1 :]):
- if constraint(a, b):
- continue
- raise RuntimeError(
- f"pass schedule constraint violated. Expected {a} before {b}"
- f" but found {a} at index {i} and {b} at index{j} in pass"
- f" list."
- )
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' comparison requiring `this` to occur before `that`.

    The returned function takes a pair `(a, b)` and returns False only when
    `a` equals `that` and `b` equals `this`; every other pair yields True.
    """

    def depends_on(a: Callable, b: Callable):
        # The single forbidden pair is (that, this).
        return not (a == that and b == this)

    return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a 'depends on' comparison requiring `these` to occur before `those`,
    where both compared passes are 'unwrapped' (via `inspect.unwrap`) first.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only for
            a pair whose unwrapped members are `(those, these)`.
    """

    def depends_on(a: Callable, b: Callable):
        # Only the arguments are unwrapped; `these` / `those` are compared
        # as given.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A is
            scheduled before B, and returns False if that ordering violates
            the constraint and True otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cleared whenever the schedule or the constraints change so that the next
    # call re-runs validation.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes."""
        # Instantiate through `cls` so subclasses get an instance of their
        # own type rather than a plain PassManager.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append `_pass` to the schedule and invalidate prior validation."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register `constraint` and invalidate prior validation."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Drop every scheduled pass whose `__name__` is listed in `_passes`.

        Args:
            _passes: names of the passes to remove. `None` is a no-op.
        """
        if _passes is None:
            return
        # Passes without a `__name__` (e.g. callable instances) cannot match
        # a name, so they are kept instead of raising AttributeError.
        self.passes = [
            ps
            for ps in self.passes
            if getattr(ps, "__name__", None) not in _passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then thread `source` through every pass."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
|