import enum
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSNodeTargetType, NSResultsType

toq = torch.ops.quantized

# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We will likely need some kind of dtype
    # propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc

def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """
    Returns the dtypes of the first input and of the output of `node`,
    to the best of our ability to infer them from the graph.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            # This function passes its input dtype through, so look up the
            # output dtype of the node which produced the first input.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # `to` is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of `to` and return the correct output type.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
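
# Example usage of `get_node_first_input_and_output_type` (a minimal sketch;
# assumes `get_node_type_to_io_type_map` from torch.ao.ns.fx.mappings and
# `OutputLogger` from torch.ao.ns._numeric_suite_fx, the usual sources of
# these inputs):
#
#   import torch.fx
#   from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
#   from torch.ao.ns._numeric_suite_fx import OutputLogger
#
#   gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(1, 1)))
#   io_map = get_node_type_to_io_type_map()
#   for node in gm.graph.nodes:
#       if node.op == "call_module":
#           # a plain fp32 nn.Linear is expected to report (FP32, FP32)
#           print(get_node_first_input_and_output_type(node, gm, OutputLogger, io_map))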

def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
    """
    prev_node = get_normalized_nth_input(node, gm, 0)

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
        scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
        zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale_obj = getattr_from_fqn(gm, scale_node.target)
        zp_obj = getattr_from_fqn(gm, zp_node.target)
        return (scale_obj, zp_obj)

    if prev_node.op == "call_function":
        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
        elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)

        return None
        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input

    elif prev_node.op == "call_module":
        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None
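
# Example usage of `get_node_input_qparams` (a sketch; assumes `node` and `gm`
# come from a model quantized with FX graph mode quantization, where the node
# consuming a quantized op can inherit that op's scale and zero_point):
#
#   from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
#
#   io_map = get_node_type_to_io_type_map()
#   qparams = get_node_input_qparams(node, gm, io_map)
#   if qparams is not None:
#       scale, zero_point = qparams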

def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. For example,

      graph: (node_non_obs), node = node_non_obs : returns node_non_obs
      graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
      graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if _is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node

def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node. For example, for

      F.linear(x, weight, bias)

    returns 1, because `x` is a non-param arg and `weight` and `bias` are params.
    For

      lstm_mod(x, hid)

    returns 2, because both `x` and `hid` are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if isinstance(node_obj, nn.LSTM):
            return 2

    # default is 1
    return 1

def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        result = []
        for i in range(2):
            if isinstance(node.args[i], Node):
                result.append(i)
        return result
    return [0]
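
# Example usage of `get_arg_indices_of_inputs_to_log` (a sketch):
#
#   import torch.fx
#
#   def f(x, y):
#       return x + y
#
#   gm = torch.fx.symbolic_trace(f)
#   add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
#   get_arg_indices_of_inputs_to_log(add_node)  # -> [0, 1]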

def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type
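
# Example usage of `get_target_type_str` (a sketch):
#
#   gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(1, 1)))
#   for node in gm.graph.nodes:
#       if node.op == "call_module":
#           # expected to be something like 'torch.nn.modules.linear.Linear'
#           print(get_target_type_str(node, gm))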

def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

      {'base_op_1_0': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    into

      {'linear1': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    new_results = {}
    for old_layer_name, result_type_to_results in results.items():
        new_layer_name = None
        for _result_type, model_name_to_results in result_type_to_results.items():
            for cur_model_name, list_of_results in model_name_to_results.items():
                if cur_model_name == model_name:
                    assert len(list_of_results)
                    new_layer_name = list_of_results[0]["ref_node_name"]
                else:
                    continue
        if new_layer_name is not None:
            new_results[new_layer_name] = result_type_to_results
        else:
            new_results[old_layer_name] = result_type_to_results
    return new_results
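
# Example usage of `rekey_logger_info_on_node_name_of_model` (a sketch, using
# the toy results layout from the docstring):
#
#   results = {
#       "base_op_1_0": {
#           "node_output": {
#               "model_a": [{"ref_node_name": "linear1", "values": [1.0]}],
#           },
#       },
#   }
#   rekey_logger_info_on_node_name_of_model(results, "model_a")
#   # -> {"linear1": {"node_output": {"model_a": [...]}}}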

def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefiting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
    """
    # Check in the first result to find any model with fqn entries defined.
    model_name_with_fqns = None
    for layer_name, result_type_to_results in results.items():
        for result_type, model_name_to_results in result_type_to_results.items():
            for model_name, model_results in model_name_to_results.items():
                if len(model_results) > 0:
                    if model_results[0]["fqn"] is not None:
                        model_name_with_fqns = model_name
                        break
            break
        break

    if model_name_with_fqns:
        for layer_name, result_type_to_results in results.items():
            for result_type, model_name_to_results in result_type_to_results.items():
                ref_model_results = model_name_to_results[model_name_with_fqns]
                for model_name, model_results in model_name_to_results.items():
                    if model_name == model_name_with_fqns:
                        continue
                    for i in range(len(model_results)):
                        fqn = ref_model_results[i]["fqn"]
                        model_results[i]["fqn"] = fqn
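
# Example usage of `maybe_add_missing_fqns` (a sketch): if entries for
# "model_a" carry an `fqn` and entries for "model_b" do not, the "model_b"
# entries are filled in, in place:
#
#   results = {
#       "layer0": {
#           "node_output": {
#               "model_a": [{"fqn": "fc1", "values": [1.0]}],
#               "model_b": [{"fqn": None, "values": [1.5]}],
#           },
#       },
#   }
#   maybe_add_missing_fqns(results)
#   # results["layer0"]["node_output"]["model_b"][0]["fqn"] is now "fc1"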

def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            results = []
            for el0, el1 in zip(a0, a1):
                new_args = (el0, el1, *a_other)
                results.append(inner(*new_args, **kwargs))
            return results

        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

            # for the purposes of this util, only handle floats
            if a0.dtype != torch.float or a1.dtype != torch.float:
                return None

            new_args = (a0, a1, *a_other)
            return f(*new_args, **kwargs)

    return inner

@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`, in dB.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        scalar tensor, or list of scalar tensors for tuple inputs
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
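
# Worked example for `compute_sqnr` (a sketch): for x = torch.ones(4) and
# y = x + 0.1, the signal norm is ||x|| = 2 and the error norm is
# ||x - y|| = 0.2, so SQNR = 20 * log10(2 / 0.2) = 20 dB. Quantized inputs
# are dequantized first by the decorator, and tuples/lists are handled
# elementwise.
#
#   x = torch.ones(4)
#   y = x + 0.1
#   compute_sqnr(x, y)  # -> tensor(20.), approximately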

@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`,
    i.e. sqrt(sum((x - y) ** 2) / sum(x ** 2)).

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        scalar tensor, or list of scalar tensors for tuple inputs
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
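
# Worked example for `compute_normalized_l2_error` (a sketch): for
# x = torch.ones(4) and y = x + 0.1, the squared error sum is 4 * 0.01 = 0.04
# and the squared signal sum is 4, so the result is sqrt(0.04 / 4) = 0.1.
#
#   x = torch.ones(4)
#   y = x + 0.1
#   compute_normalized_l2_error(x, y)  # -> tensor(0.1000), approximately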

@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        single-element tensor, or list of such tensors for tuple inputs
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)
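
# Worked example for `compute_cosine_similarity` (a sketch): tensors which
# differ only by a positive scale factor have cosine similarity 1. The
# reshape to (1, -1) means the result is a tensor with a single element.
#
#   x = torch.ones(4)
#   y = 2 * torch.ones(4)
#   compute_cosine_similarity(x, y)  # -> tensor([1.0000]), approximately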

def op_type_supports_shadowing(node: Node) -> bool:
    if node.op == "call_function":
        if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True

def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True
        )
        if norm_args_and_kwargs is not None:
            norm_args, norm_kwargs = norm_args_and_kwargs
            assert len(norm_args) + len(norm_kwargs) > idx
            if idx < len(norm_args):
                return norm_args[idx]
            else:
                # since everything was normalized to kwargs, norm_args is
                # empty and idx indexes kwargs directly
                # note: in Python 3.7+ dicts are ordered
                return list(norm_kwargs.values())[idx]
        else:
            assert len(node.args) + len(node.kwargs) > idx
            if idx < len(node.args):
                return node.args[idx]  # type: ignore[return-value]
            else:
                kwargs_idx = idx - len(node.args)
                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        assert len(node.args) + len(node.kwargs) > idx
        if idx < len(node.args):
            return node.args[idx]  # type: ignore[return-value]
        else:
            kwargs_idx = idx - len(node.args)
            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
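
# Example usage of `get_normalized_nth_input` (a sketch):
#
#   import torch.fx
#
#   def f(x):
#       return torch.add(x, 1)
#
#   gm = torch.fx.symbolic_trace(f)
#   add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
#   # returns the placeholder node for `x`, whether it was passed
#   # positionally or as a keyword argument
#   get_normalized_nth_input(add_node, gm, 0)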