# utils.py
import operator
from collections import OrderedDict

import torch
from torch.fx import GraphModule
from torch.nn.utils.fusion import fuse_conv_bn_weights

# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
    _is_activation_post_process_node,
)
  10. # TODO[qihan]: longer term, this should happen in the dynamo stack as well
  11. def _get_renamed_nn_module_stack(nn_module_stack):
  12. # initialize with top level parent scope
  13. nn_module_stack_renamed = OrderedDict([("", None)])
  14. if nn_module_stack:
  15. # Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
  16. prev_key = ""
  17. for key, value in nn_module_stack.items():
  18. if not prev_key:
  19. if key.startswith("self_"):
  20. new_key = key[5:]
  21. prev_key = new_key
  22. else:
  23. new_key = prev_key + "." + key[len(prev_key) + 6 :]
  24. nn_module_stack_renamed[new_key] = value
  25. prev_key = new_key
  26. return nn_module_stack_renamed
  27. def _get_tensor_constant_from_node(node, m):
  28. if node is None:
  29. return None
  30. assert node.op == "get_attr"
  31. return getattr(m, node.target)
  32. # fuse conv bn weights, inplace modification of the graph_module and graph
  33. def _fuse_conv_bn_(m: GraphModule) -> None:
  34. for n in m.graph.nodes:
  35. if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default:
  36. continue
  37. bn_op = n
  38. n = bn_op.args[0]
  39. if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
  40. continue
  41. conv_op = n
  42. # conv weight
  43. conv_w = _get_tensor_constant_from_node(conv_op.args[1], m)
  44. # conv bias
  45. conv_b = _get_tensor_constant_from_node(conv_op.args[2], m)
  46. transpose = conv_op.args[6]
  47. # bn weight
  48. bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
  49. # bn bias
  50. bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
  51. # bn running mean
  52. bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
  53. # bn running variance
  54. bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
  55. bn_eps = bn_op.args[7]
  56. fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False)
  57. # update the weight and bias for conv
  58. conv_args = list(conv_op.args)
  59. # calling data since the fused_weight and fused_bias are nn.Parameter
  60. weight_attr_name = conv_args[1].target
  61. setattr(m, weight_attr_name, fused_weight)
  62. if conv_args[2] is not None:
  63. bias_attr_name = conv_args[2].target
  64. else:
  65. bias_attr_name = weight_attr_name + "_bias"
  66. with m.graph.inserting_before(conv_op):
  67. get_bias_node = m.graph.get_attr(bias_attr_name)
  68. conv_args[2] = get_bias_node
  69. setattr(m, bias_attr_name, fused_bias)
  70. conv_op.args = tuple(conv_args)
  71. # native_batch_norm has 3 outputs, we expect getitem calls on the output
  72. # and we want to replace the uses of getitem 0 with the output of conv
  73. #
  74. # Before:
  75. # conv -> bn - (first output) -> users1
  76. # \ - (second output) -> users2
  77. # \ - (third output) -> users3
  78. # After:
  79. # conv -> (first output) -> users1
  80. # bn -
  81. # \ - (second output) -> users2
  82. # \ - (third output) -> users3
  83. # if users2 and users3 are empty then bn will be removed through dead code elimination
  84. for user in bn_op.users:
  85. if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
  86. continue
  87. user.replace_all_uses_with(conv_op)
  88. m.graph.eliminate_dead_code()
  89. m.recompile()
def _rearrange_weight_observer_for_addmm(
    model: GraphModule,
) -> None:
    """
    Move the weight observer so it observes the weight *before* the transpose
    feeding an ``aten.addmm`` node, rewriting the graph in place:

    before:
      weight - t - observer \
        input - observer - addmm
    after:
      weight - observer - t \
        input - observer - addmm
    """
    # remove_duplicate=False keeps aliased submodules so every observer
    # qualified name referenced by the graph resolves to a module
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        if node.target != torch.ops.aten.addmm.default:
            continue
        addmm = node
        # args[2] is addmm's second matrix operand — per the diagram above,
        # the (observed, transposed) weight in the linear decomposition
        maybe_weight_obs = addmm.args[2]
        if not _is_activation_post_process_node(maybe_weight_obs, named_modules):
            continue
        transpose_node = maybe_weight_obs.args[0]
        if transpose_node.target != torch.ops.aten.t.default:
            continue
        # swap the order of transpose and observation:
        # point the observer at the raw weight instead of the transpose output
        maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0])
        # remove the transpose node by building a fresh transpose that consumes
        # the observer's output (the old one dies in eliminate_dead_code below)
        with model.graph.inserting_after(maybe_weight_obs):
            args = list(transpose_node.args)
            args[0] = maybe_weight_obs
            new_transpose_node = model.graph.create_node(
                "call_function",
                torch.ops.aten.t.default,
                tuple(args),
                transpose_node.kwargs
            )
        addmm.replace_input_with(maybe_weight_obs, new_transpose_node)

    model.graph.eliminate_dead_code()
    model.graph.lint()