import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
toq = torch.ops.quantized

from torch.fx import GraphModule
from torch.fx.graph import Node

from .utils import (
    get_target_type_str,
    getattr_from_fqn,
    return_first_non_observer_node,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSingleResultType,
)

from typing import List, Optional, Dict, Callable

def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod.weight.detach()  # type: ignore[operator]

def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod[0].weight.detach()  # type: ignore[index]

def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    return mod._weight_bias()[0]  # type: ignore[operator]

def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
            res.append(param_value)
    return res

def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    # each weight_value holds packed per-layer params; the nested
    # __getstate__ indexing pulls out the ih and hh weight tensors
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
    return res

def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]

def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    elif isinstance(mod, nni.LinearReLU):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]

def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        res = []
        for idx, param_name in enumerate(mod._flat_weights_names):
            if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
                param_value = mod._flat_weights[idx].detach()
                res.append(param_value)
        return res
    else:
        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
        res = []
        for weight_value in mod._all_weight_values:
            res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
            res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
        return res

def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == 'get_attr'
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()

def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == 'get_attr'
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()
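# A minimal sketch of how the functional extractors above are typically
# reached (hedged: the toy module below is illustrative only and not part
# of this file's API):
#
#     class _ToyConv(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.w = torch.nn.Parameter(torch.randn(8, 3, 3, 3))
#
#         def forward(self, x):
#             return F.conv2d(x, self.w)
#
#     gm = torch.fx.symbolic_trace(_ToyConv())
#     for n in gm.graph.nodes:
#         if n.op == 'call_function' and n.target == F.conv2d:
#             w = get_conv_fun_weight(n, gm)  # detached copy of _ToyConv.w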
def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    # supported patterns:
    # weight -> obs -> linear
    # weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)

    if linear_second_arg.op == 'call_module':
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == 'call_method':
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == 'get_attr'
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()

def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == 'get_attr'
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight
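# For reference, the two observed-weight patterns handled by
# get_linear_fun_weight look roughly like this in a traced graph
# (sketch, simplified node spellings):
#
#     get_attr(weight) -> call_module(obs) -> call_function(F.linear)
#     get_attr(weight) -> call_method(to, float16) -> call_method(dequantize)
#                      -> call_function(F.linear)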
def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:

    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
        'call_module': {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        'call_function': {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn
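# Callers may fetch and extend the returned mapping before passing it to
# extract_weight_from_node, e.g. to handle a custom module type (sketch;
# MyConv and its extractor are hypothetical):
#
#     fn_map = get_op_to_type_to_weight_extraction_fn()
#     fn_map['call_module'][MyConv] = lambda mod: mod.weight.detach()
#     result = extract_weight_from_node(node, gm, fn_map)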
def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> Optional[NSSingleResultType]:
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    if node.op == 'call_function':
        function_mapping = op_to_type_to_weight_extraction_fn['call_function']
        for target_fn_type, weight_extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                weight = weight_extraction_fn(node, gm)
                return {
                    'type': res_type,
                    'values': [weight],
                    'prev_node_name': node.name,
                    'prev_node_target_type': prev_node_type,
                    'ref_node_name': node.name,
                    'ref_node_target_type': ref_node_type,
                    'index_within_arg': 0,
                    'index_of_arg': 0,
                    'fqn': fqn,
                }

    elif node.op == 'call_module':
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn['call_module']
        for target_mod_type, weight_extraction_fn in module_mapping.items():
            if type(mod) == target_mod_type:
                weight = weight_extraction_fn(mod)
                return {
                    'type': res_type,
                    'values': [weight],
                    'prev_node_name': node.name,
                    'prev_node_target_type': prev_node_type,
                    'ref_node_name': node.name,
                    'ref_node_target_type': ref_node_type,
                    'index_within_arg': 0,
                    'index_of_arg': 0,
                    'fqn': fqn,
                }

    return None
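# End-to-end usage sketch (hedged: the toy model below is illustrative
# only):
#
#     m = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)))
#     for n in m.graph.nodes:
#         res = extract_weight_from_node(n, m)
#         if res is not None:
#             print(n.name, res['values'][0].shape)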