import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node

from .utils import (
    get_target_type_str,
    getattr_from_fqn,
    return_first_non_observer_node,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSingleResultType,
)

from typing import List, Optional, Dict, Callable

# Small weight extraction helpers, used as values in the dispatch table
# returned by get_op_to_type_to_weight_extraction_fn below.

def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod.weight.detach()  # type: ignore[operator]

def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    return mod[0].weight.detach()  # type: ignore[index]

def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    return mod._weight_bias()[0]  # type: ignore[operator]

def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
            res.append(param_value)
    return res

def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
    return res

def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]

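# Illustrative usage sketch (an assumption, not part of this module's tests):
# for a plain float conv module, get_conv_mod_weight returns the same values
# as reading .weight directly.
#
#   conv = nn.Conv2d(3, 8, kernel_size=3)
#   assert torch.equal(get_conv_mod_weight(conv), conv.weight.detach())
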
def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    elif isinstance(mod, nni.LinearReLU):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]

def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        res = []
        for idx, param_name in enumerate(mod._flat_weights_names):
            if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
                param_value = mod._flat_weights[idx].detach()
                res.append(param_value)
        return res
    else:
        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
        res = []
        for weight_value in mod._all_weight_values:
            res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
            res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
        return res

def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == 'get_attr'
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()

def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == 'get_attr'
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()

def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    # supported patterns:
    #   weight -> obs -> linear
    #   weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)

    if linear_second_arg.op == 'call_module':
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == 'call_method':
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == 'get_attr'
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == 'get_attr'
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()

def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == 'get_attr'
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight

def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
        'call_module': {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        'call_function': {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn

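# Illustrative extension sketch (an assumption, not part of the original API
# surface): a user-defined module type (here the hypothetical `MyConv2d`) can
# be mapped to one of the extraction helpers above and passed to
# extract_weight_from_node via its optional argument, given a matched `node`
# in GraphModule `gm`.
#
#   custom_map = get_op_to_type_to_weight_extraction_fn()
#   custom_map['call_module'][MyConv2d] = mod_weight_detach
#   result = extract_weight_from_node(node, gm, custom_map)
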
def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> Optional[NSSingleResultType]:
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    if node.op == 'call_function':
        function_mapping = op_to_type_to_weight_extraction_fn['call_function']
        for target_fn_type, weight_extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                weight = weight_extraction_fn(node, gm)
                return {
                    'type': res_type,
                    'values': [weight],
                    'prev_node_name': node.name,
                    'prev_node_target_type': prev_node_type,
                    'ref_node_name': node.name,
                    'ref_node_target_type': ref_node_type,
                    'index_within_arg': 0,
                    'index_of_arg': 0,
                    'fqn': fqn,
                }

    elif node.op == 'call_module':
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn['call_module']
        for target_mod_type, weight_extraction_fn in module_mapping.items():
            if type(mod) == target_mod_type:
                weight = weight_extraction_fn(mod)
                return {
                    'type': res_type,
                    'values': [weight],
                    'prev_node_name': node.name,
                    'prev_node_target_type': prev_node_type,
                    'ref_node_name': node.name,
                    'ref_node_target_type': ref_node_type,
                    'index_within_arg': 0,
                    'index_of_arg': 0,
                    'fqn': fqn,
                }

    return None
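
# Usage sketch (an assumption about calling code, not part of this module):
# collect the weight of every recognized node in an FX GraphModule `gm`,
# keyed by node name.
#
#   weights = {}
#   for node in gm.graph.nodes:
#       res = extract_weight_from_node(node, gm)
#       if res is not None:
#           weights[node.name] = res['values']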