123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- import abc
- import typing as t
- import torch
- import torch.fx
- from torch.fx._compatibility import compatibility
- from .shape_prop import TensorMetadata
- from .tools_common import get_node_target, CALLABLE_NODE_OPS
# Public API of this module.
__all__ = [
    'OperatorSupportBase',
    'OperatorSupport',
    'create_op_support',
    'chain',
    'OpSupports',
    'any_chain',
]
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str

# Arguments' dtypes for a given node, see `OperatorSupport`.
# Either None (any dtype is accepted) or a pair of:
#   - a sequence aligned with node.args giving the accepted dtypes for each
#     positional input; a None entry means that input is not checked;
#   - a dict from kwarg name to the accepted dtypes for that kwarg.
SupportedArgumentDTypes = t.Optional[
    t.Tuple[
        t.Sequence[t.Optional[t.Sequence[torch.dtype]]],
        t.Dict[str, t.Sequence[torch.dtype]],
    ]
]

# Maps a target typename to the dtype rule that applies to it.
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
    """Interface for determining if a fx.Node is supported by a backend.

    Concrete subclasses answer the question for one node at a time via
    `is_node_supported`, given the mapping of submodule names to modules.
    """

    @abc.abstractmethod
    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Return True if `node` is supported; subclasses must override this."""
        raise NotImplementedError
@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
    """
    `_support_dict` maps node.target typename to supported inputs dtypes.

    node.target typename is retrieved using helper function `get_node_target()`

    If supported inputs dtypes is None, it means any dtype is supported, else
    we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).

    The first tuple ([dtypes], ...) indicates what dtypes are supported for
    inputs in node.args and the second dict {"name": [dtypes], ...} indicates
    what dtypes are supported for inputs in node.kwargs.

    For inputs in args, if we don't want to check it, we can put None there,
    e.g. (None, [torch.float]) indicates that we don't care about the type of
    the first input in args. For inputs in kwargs, a name that is not listed,
    or that is listed with None, is not checked.

    Inputs that are not fx.Nodes (e.g. Python constants) are never checked.
    """

    _support_dict: SupportDict

    def __init__(
        self,
        support_dict: t.Optional[SupportDict] = None
    ):
        """
        Args:
            `support_dict`: rules keyed by target typename; None or an empty
                mapping means no callable op is supported.
        """
        self._support_dict = support_dict or {}

    def is_node_supported(
        self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """
        Args:
            `submodules`: mapping from module name to the module. This can be
                retrieved by calling model.named_modules().

            `node`: a Fx node that we want to determine whether it's supported.

        Returns:
            `is_supported`: whether the arg `node` is supported.
        """
        # Nodes whose op is not in CALLABLE_NODE_OPS are not operators and are
        # always accepted.
        if node.op not in CALLABLE_NODE_OPS:
            return True

        target = get_node_target(submodules, node)

        # Target not found in _support_dict means we don't support this op at all
        if target not in self._support_dict:
            return False

        rule = self._support_dict[target]
        # A rule of None means we accept any dtype
        if rule is None:
            return True

        args_dtypes, kwargs_dtypes = rule  # type: ignore[misc]

        # Check args dtypes. zip() stops at the shorter side, so rules for
        # positions beyond len(node.args) are ignored.
        for arg, dtypes in zip(node.args, args_dtypes):
            # None means "don't care"; non-Node args carry no dtype to check
            if dtypes is None or not isinstance(arg, torch.fx.Node):
                continue
            if _get_arg_dtype(arg) not in dtypes:
                return False

        # Check kwargs dtypes; kwargs absent from node.kwargs are skipped, and
        # a None rule is treated like a None entry in the args rule instead of
        # crashing on `x in None`.
        for name, dtypes in kwargs_dtypes.items():
            kwarg = node.kwargs.get(name)
            if dtypes is None or not isinstance(kwarg, torch.fx.Node):
                continue
            if _get_arg_dtype(kwarg) not in dtypes:
                return False

        return True
# ======================================================================
# Functional interfaces and utils for defining basic operator support logic
# and composing them into more complex ones
# ======================================================================

# A plain predicate with the same call signature as
# `OperatorSupportBase.is_node_supported`: it receives the submodule mapping
# and the node, and returns whether the node is supported.
IsNodeSupported = t.Callable[
    [t.Mapping[str, torch.nn.Module], torch.fx.Node],
    bool,
]
@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
    """Adapt a plain `IsNodeSupported` callable into an `OperatorSupportBase`.

    The callable must accept the same arguments as
    `OperatorSupportBase.is_node_supported` (submodule mapping, node). Every
    call on the returned instance is forwarded to it, and its result is
    returned unchanged.
    """
    predicate = is_node_supported

    class FunctionalOperatorSupport(OperatorSupportBase):
        def is_node_supported(
            self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
        ) -> bool:
            return predicate(submodules, node)

    return FunctionalOperatorSupport()
@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combine several `OperatorSupportBase` instances with logical AND.

    The returned instance asks each input in order and reports False as soon
    as one of them reports False; the remaining inputs are not consulted.
    With no inputs it always reports True.
    """
    def _chain(submods, node) -> bool:
        for support in op_support:
            if not support.is_node_supported(submods, node):
                return False
        return True

    return create_op_support(_chain)
@compatibility(is_backward_compatible=False)
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
    """Combine several `OperatorSupportBase` instances with logical OR.

    The returned instance asks each input in order and reports True as soon
    as one of them reports True; the remaining inputs are not consulted.
    With no inputs it always reports False.
    """
    def _any_chain(submods, node) -> bool:
        for support in op_support:
            if support.is_node_supported(submods, node):
                return True
        return False

    return create_op_support(_any_chain)
@compatibility(is_backward_compatible=False)
class OpSupports:
    """Factories for small, single-purpose `OperatorSupportBase` instances.

    Each classmethod returns one atomic rule; rules can be combined (e.g. with
    `chain` or `any_chain`) to build more complex operator support logic.
    """

    @classmethod
    def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
        """Report a node as non-supported if any of its input nodes has `dtype`.

        Inputs produced by `get_attr` nodes are exempt from the check.
        """
        def _decline_if_input_dtype(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            # Skip get_attr inputs; stop at the first input whose dtype matches.
            checked_inputs = (
                inp for inp in node.all_input_nodes if inp.op != "get_attr"
            )
            return not any(_get_arg_dtype(inp) == dtype for inp in checked_inputs)

        return create_op_support(_decline_if_input_dtype)

    @classmethod
    def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
        """
        Report a node as non-supported if its name is in the disallow set.
        """
        def _decline_if_node_in_names(
            submodules: t.Mapping[str, torch.nn.Module],
            node: torch.fx.Node,
        ) -> bool:
            return node.name not in disallow_set

        return create_op_support(_decline_if_node_in_names)
def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
    """Return the dtype recorded in `arg.meta`.

    Prefers `meta["tensor_meta"].dtype` when that entry is a `TensorMetadata`;
    otherwise falls back to `meta["type"]`, raising KeyError if that key is
    absent.
    """
    assert isinstance(arg, torch.fx.Node)
    meta = arg.meta
    tensor_meta = meta.get("tensor_meta")  # type: ignore[union-attr]
    if isinstance(tensor_meta, TensorMetadata):
        return tensor_meta.dtype
    return meta["type"]
|