pass_base.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import abc
  2. from collections import namedtuple
  3. from typing import Optional
  4. from torch.fx.graph_module import GraphModule
  5. from torch.fx._compatibility import compatibility
  6. __all__ = ['PassResult', 'PassBase']
  7. @compatibility(is_backward_compatible=False)
  8. class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
  9. """
  10. Result of a pass:
  11. graph_module: The modified graph module
  12. modified: A flag for if the pass has modified the graph module
  13. """
  14. def __new__(cls, graph_module, modified):
  15. return super().__new__(cls, graph_module, modified)
  16. @compatibility(is_backward_compatible=False)
  17. class PassBase(abc.ABC):
  18. """
  19. Base interface for implementing passes.
  20. It is required to implement the `call` function so that we can directly
  21. pass instances of the Pass directly to the PassManager and call them as a
  22. function.
  23. We can directly pass an instance of a class implementing this interface into
  24. the PassManager's `passes` attribute.
  25. """
  26. def __init__(self) -> None:
  27. pass
  28. def __call__(self, graph_module: GraphModule) -> Optional[PassResult]:
  29. """
  30. Runs the precondition check, the pass itself, and the postcondition check.
  31. """
  32. self.requires(graph_module)
  33. res = self.call(graph_module)
  34. self.ensures(graph_module)
  35. return res
  36. @abc.abstractmethod
  37. def call(self, graph_module: GraphModule) -> Optional[PassResult]:
  38. """
  39. The pass that is run through the given graph module. To implement a
  40. pass, it is required to implement this function.
  41. Args:
  42. graph_module: The graph module we will run a pass on
  43. """
  44. pass
  45. def requires(self, graph_module: GraphModule) -> None:
  46. """
  47. This function will be called before the pass is run and will check that
  48. the given graph module contains the preconditions needed to run the
  49. pass. It is not required to implement this function.
  50. Args:
  51. graph_module: The graph module we will run checks on
  52. """
  53. pass
  54. def ensures(self, graph_module: GraphModule) -> None:
  55. """
  56. This function will be called after the pass is run and will check that
  57. the given graph module contains the postconditions needed to run the
  58. pass. It is not required to implement this function.
  59. Args:
  60. graph_module: The graph module we will run checks on
  61. """
  62. pass