# operator_support.py
  1. import abc
  2. import typing as t
  3. import torch
  4. import torch.fx
  5. from torch.fx._compatibility import compatibility
  6. from .shape_prop import TensorMetadata
  7. from .tools_common import get_node_target, CALLABLE_NODE_OPS
__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']

# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`.
# None means "any dtype accepted"; otherwise a pair of
#   (per-positional-arg dtype sequences, per-kwarg dtype sequences).
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Sequence[torch.dtype]],
        t.Dict[str, t.Sequence[torch.dtype]],
    ]
]

# Maps a node.target typename to the dtype rules accepted for that target.
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
  19. @compatibility(is_backward_compatible=False)
  20. class OperatorSupportBase(abc.ABC):
  21. """Interface for determining if a fx.Node is supported by a backend"""
  22. @abc.abstractmethod
  23. def is_node_supported(
  24. self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
  25. ) -> bool:
  26. raise NotImplementedError()
  27. @compatibility(is_backward_compatible=False)
  28. class OperatorSupport(OperatorSupportBase):
  29. """
  30. `_support_dict` maps node.target typename to supported inputs dtypes.
  31. node.target typename is retrieved using helper function `get_node_target()`
  32. If supported inputs dtypes is None, it means any dtype is supported, else
  33. we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).
  34. The first tuple ([dtypes], ...) indicates what dtypes are supported for
  35. inputs in node.args and the second dict {"name": [dtypes], ...} indicates
  36. what dtypes are supported for inputs in node.kwargs.
  37. For inputs in args, if we don't want to check it, we can put None there,
  38. e.g. (None, [torch.float]) indicates that we don't care about the type of
  39. the first input in args. And for inputs in kwargs, if not listed, will not
  40. be checked.
  41. """
  42. _support_dict: SupportDict
  43. def __init__(
  44. self,
  45. support_dict: t.Optional[SupportDict] = None
  46. ):
  47. self._support_dict = support_dict or {}
  48. def is_node_supported(
  49. self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
  50. ) -> bool:
  51. """
  52. Args:
  53. `submodules`: mapping from module name to the module. This can be
  54. retrieved by calling model.named_modules().
  55. `node`: a Fx node that we want to determine whether it's supported.
  56. Returns:
  57. `is_supported`: whether the arg `node` is supported.
  58. """
  59. if node.op not in CALLABLE_NODE_OPS:
  60. return True
  61. target = get_node_target(submodules, node)
  62. # Target not found in _support_dict meaning that we don't support this op at all
  63. if target not in self._support_dict:
  64. return False
  65. # The rule for target is None meaning that we accept any dtype
  66. if self._support_dict[target] is None:
  67. return True
  68. args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc]
  69. # Check args dtypes
  70. for i, dtypes in enumerate(args_dtypes):
  71. if len(node.args) <= i:
  72. break
  73. # None indicates we don't care about the dtype of args[i]
  74. if dtypes is None:
  75. continue
  76. # If arg is not a node then we don't check it
  77. if not isinstance(node.args[i], torch.fx.Node):
  78. continue
  79. arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type]
  80. if arg_dtype not in dtypes:
  81. return False
  82. # Check kwargs dtypes
  83. for k, dtypes in kwargs_dtypes.items():
  84. if k not in node.kwargs:
  85. continue
  86. # If arg is not a node then we don't check it
  87. if not isinstance(node.kwargs[k], torch.fx.Node):
  88. continue
  89. kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type]
  90. if kwarg_dtype not in dtypes:
  91. return False
  92. return True
# ======================================================================
# Functional interfaces and utils for defining basic operator support logic
# and composing them into more complex ones
# ======================================================================

# Callable with the same signature as `OperatorSupportBase.is_node_supported`.
IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
  98. @compatibility(is_backward_compatible=False)
  99. def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
  100. """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
  101. `IsNodeSupported` has the same call signature as
  102. `OperatorSupportBase.is_node_supported`
  103. """
  104. class FunctionalOperatorSupport(OperatorSupportBase):
  105. def is_node_supported(
  106. self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
  107. ) -> bool:
  108. return is_node_supported(submodules, node)
  109. return FunctionalOperatorSupport()
  110. @compatibility(is_backward_compatible=False)
  111. def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
  112. """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
  113. instance by evaluating each input `OperatorSupportBase` instance, and returns False if
  114. any of it reports False.
  115. """
  116. def _chain(submods, node) -> bool:
  117. return all(
  118. x.is_node_supported(submods, node)
  119. for x in op_support
  120. )
  121. return create_op_support(_chain)
  122. @compatibility(is_backward_compatible=False)
  123. def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
  124. """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
  125. instance by evaluating each input `OperatorSupportBase` instance, and returns True if
  126. any of it reports True.
  127. """
  128. def _any_chain(submods, node) -> bool:
  129. return any(
  130. x.is_node_supported(submods, node)
  131. for x in op_support
  132. )
  133. return create_op_support(_any_chain)
  134. @compatibility(is_backward_compatible=False)
  135. class OpSupports:
  136. """A set of atomic `OperatorSupportBase` instances that can be combined together
  137. to form more complex operator support logic.
  138. """
  139. @classmethod
  140. def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
  141. """Report a node as non-supported, if any of its arguments is of dtype"""
  142. def _decline_if_input_dtype(
  143. submodules: t.Mapping[str, torch.nn.Module],
  144. node: torch.fx.Node,
  145. ) -> bool:
  146. for arg in node.all_input_nodes:
  147. # escape dtype check for get_attr node
  148. if arg.op == "get_attr":
  149. continue
  150. arg_dtype = _get_arg_dtype(arg)
  151. if arg_dtype == dtype:
  152. return False
  153. return True
  154. return create_op_support(_decline_if_input_dtype)
  155. @classmethod
  156. def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
  157. """
  158. If a node has a name that is in the disallow set, reported it as non-supported.
  159. """
  160. def _decline_if_node_in_names(
  161. submodules: t.Mapping[str, torch.nn.Module],
  162. node: torch.fx.Node,
  163. ) -> bool:
  164. if node.name in disallow_set:
  165. return False
  166. else:
  167. return True
  168. return create_op_support(_decline_if_node_in_names)
  169. def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
  170. assert isinstance(arg, torch.fx.Node)
  171. tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
  172. dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
  173. return dtype