test_pass_manager.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import unittest
  2. from ..pass_manager import (
  3. inplace_wrapper,
  4. PassManager,
  5. these_before_those_pass_constraint,
  6. this_before_that_pass_constraint,
  7. )
  8. class TestPassManager(unittest.TestCase):
  9. def test_pass_manager_builder(self) -> None:
  10. passes = [lambda x: 2 * x for _ in range(10)]
  11. pm = PassManager(passes)
  12. pm.validate()
  13. def test_this_before_that_pass_constraint(self) -> None:
  14. passes = [lambda x: 2 * x for _ in range(10)]
  15. pm = PassManager(passes)
  16. # add unfulfillable constraint
  17. pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0]))
  18. self.assertRaises(RuntimeError, pm.validate)
  19. def test_these_before_those_pass_constraint(self) -> None:
  20. passes = [lambda x: 2 * x for _ in range(10)]
  21. constraint = these_before_those_pass_constraint(passes[-1], passes[0])
  22. pm = PassManager(
  23. [inplace_wrapper(p) for p in passes]
  24. )
  25. # add unfulfillable constraint
  26. pm.add_constraint(constraint)
  27. self.assertRaises(RuntimeError, pm.validate)
  28. def test_two_pass_managers(self) -> None:
  29. """Make sure we can construct the PassManager twice and not share any
  30. state between them"""
  31. passes = [lambda x: 2 * x for _ in range(3)]
  32. constraint = these_before_those_pass_constraint(passes[0], passes[1])
  33. pm1 = PassManager()
  34. for p in passes:
  35. pm1.add_pass(p)
  36. pm1.add_constraint(constraint)
  37. output1 = pm1(1)
  38. self.assertEqual(output1, 2 ** 3)
  39. passes = [lambda x: 3 * x for _ in range(3)]
  40. constraint = these_before_those_pass_constraint(passes[0], passes[1])
  41. pm2 = PassManager()
  42. for p in passes:
  43. pm2.add_pass(p)
  44. pm2.add_constraint(constraint)
  45. output2 = pm2(1)
  46. self.assertEqual(output2, 3 ** 3)