123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import torch
- from torch.ao.quantization.backend_config import BackendConfig
- from torch.fx.graph import Node, Graph
- from ..utils import _parent_name, NodePattern, Pattern
- from ..fuser_method_mappings import get_fuser_method_new
- from abc import ABC, abstractmethod
- from typing import Any, Callable, Dict, List, Union
- from .custom_config import FuseCustomConfig
- from .match_utils import MatchAllNode
- from torch.nn.utils.parametrize import type_before_parametrizations
- __all__ = [
- "DefaultFuseHandler",
- "FuseHandler",
- ]
# ----------------------------
# Fusion Pattern Registrations
# ----------------------------


class FuseHandler(ABC):
    """Abstract base class for handlers that fuse a matched node pattern.

    A handler is constructed with an FX ``Node``. Concrete subclasses must
    implement :meth:`fuse`.
    """

    def __init__(self, node: Node):
        """Accept the FX node; the base class keeps no state."""

    @abstractmethod
    def fuse(self,
             load_arg: Callable,
             named_modules: Dict[str, torch.nn.Module],
             fused_graph: Graph,
             root_node: Node,
             extra_inputs: List[Any],
             matched_node_pattern: NodePattern,
             fuse_custom_config: FuseCustomConfig,
             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
             is_qat: bool) -> Node:
        """Fuse the nodes in ``matched_node_pattern`` and return the resulting node.

        Subclasses must override this method. The base implementation has
        no body and returns ``None``.
        """
class DefaultFuseHandler(FuseHandler):
    """Fuse handler that swaps the root module of a match for a fused module.

    The matched nodes are mapped to modules (or callables). Their types select
    a fuser method from ``fuser_method_mapping``. The fuser's output replaces
    the module at ``root_node.target``. Finally ``root_node`` is copied into
    the fused graph, with any extra inputs appended to its positional args.
    """

    def __init__(
            self,
            node: Node):
        super().__init__(node)

    def fuse(self,
             load_arg: Callable,
             named_modules: Dict[str, torch.nn.Module],
             fused_graph: Graph,
             root_node: Node,
             extra_inputs: List[Any],
             matched_node_pattern: NodePattern,
             fuse_custom_config: FuseCustomConfig,
             fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
             is_qat: bool) -> Node:
        """Fuse the modules matched by ``matched_node_pattern``.

        Args:
            load_arg: Maps a node of the source graph to its counterpart in
                ``fused_graph``. It is applied to each of ``extra_inputs`` and
                (via ``node_copy``) to ``root_node``'s arguments.
            named_modules: Qualified name -> module lookup. The fused module
                is assigned with ``setattr`` onto the parent of the root
                module; the dict itself is not updated.
            fused_graph: Graph that receives a copy of ``root_node``.
            root_node: ``call_module`` node whose module is replaced.
            extra_inputs: Additional inputs appended, after ``load_arg``, to
                the copied node's positional args.
            matched_node_pattern: A node, or nested tuple/list of nodes, that
                matched the fusion pattern.
            fuse_custom_config: Not used by this handler.
            fuser_method_mapping: Maps module-type patterns to fuser methods.
            is_qat: Passed as the first argument to the fuser method.

        Returns:
            The copy of ``root_node`` inserted into ``fused_graph``.

        Raises:
            AssertionError: If ``root_node`` is not a ``call_module`` node.
        """
        assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
        root_module = named_modules[str(root_node.target)]

        def get_modules(pattern):
            """Mirror ``pattern`` with every node replaced by its module/callable.

            e.g. input: (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))
            """
            if isinstance(pattern, (tuple, list)):
                return tuple(get_modules(sub_pattern) for sub_pattern in pattern)
            n = pattern
            if n.op == "call_module":
                return named_modules[n.target]
            if n.op == "call_function" and n.target == torch.nn.functional.relu:
                # Functional relu can appear in many matches, so build a fresh
                # ReLU module per match, using the root module's training flag.
                relu = torch.nn.ReLU()
                relu.training = root_module.training
                return relu
            if n.op in ("call_function", "call_method"):
                return n.target
            # Any other node kind maps to the MatchAllNode marker.
            return MatchAllNode

        def get_matched_types(m):
            """Replace each module in ``m`` by its type (before parametrizations)."""
            if isinstance(m, tuple):
                return tuple(get_matched_types(item) for item in m)
            if isinstance(m, torch.nn.Module):
                return type_before_parametrizations(m)
            return m

        matched_modules = get_modules(matched_node_pattern)
        matched_module_types = get_matched_types(matched_modules)
        module_parent_name, module_name = _parent_name(root_node.target)
        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
        # TODO: change the signature for fuser_method to take matched module patterns
        # as input
        # A single-node pattern yields a bare module/callable, not a tuple. Wrap
        # it so it reaches the fuser as one argument instead of being star-unpacked.
        if isinstance(matched_modules, tuple):
            fuser_args = matched_modules
        else:
            fuser_args = (matched_modules,)
        fused_module = fuser_method(is_qat, *fuser_args)
        setattr(named_modules[module_parent_name], module_name, fused_module)

        extra_args = [load_arg(extra_input) for extra_input in extra_inputs]
        node = fused_graph.node_copy(root_node, load_arg)
        node.args = (*node.args, *extra_args)
        return node
- def _get_fusion_pattern_to_fuse_handler_cls(
- backend_config: BackendConfig) -> Dict[Pattern, Callable]:
- fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
- for pattern, config in backend_config._pattern_complex_format_to_config.items():
- if config.fuser_method is not None:
- # TODO: is this logic right?
- fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
- return fusion_pattern_to_fuse_handlers
|