_equalize.py

import torch
import copy
from typing import Dict, Any

__all__ = [
    "set_module_weight",
    "set_module_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "cross_layer_equalization",
    "equalize",
    "converged",
]

_supported_types = {torch.nn.Conv2d, torch.nn.Linear}
_supported_intrinsic_types = {torch.ao.nn.intrinsic.ConvReLU2d, torch.ao.nn.intrinsic.LinearReLU}
_all_supported_types = _supported_types.union(_supported_intrinsic_types)


# For fused intrinsic modules (e.g. ConvReLU2d, LinearReLU), the wrapped
# Conv/Linear is stored at index 0, so its weight and bias are accessed
# through module[0].
def set_module_weight(module, weight) -> None:
    if type(module) in _supported_types:
        module.weight = torch.nn.Parameter(weight)
    else:
        module[0].weight = torch.nn.Parameter(weight)


def set_module_bias(module, bias) -> None:
    if type(module) in _supported_types:
        module.bias = torch.nn.Parameter(bias)
    else:
        module[0].bias = torch.nn.Parameter(bias)


def get_module_weight(module):
    if type(module) in _supported_types:
        return module.weight
    else:
        return module[0].weight


def get_module_bias(module):
    if type(module) in _supported_types:
        return module.bias
    else:
        return module[0].bias


def max_over_ndim(input, axis_list, keepdim=False):
    ''' Applies 'torch.max' over the given axes
    '''
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.max(axis, keepdim)
    return input


def min_over_ndim(input, axis_list, keepdim=False):
    ''' Applies 'torch.min' over the given axes
    '''
    axis_list.sort(reverse=True)
    for axis in axis_list:
        input, _ = input.min(axis, keepdim)
    return input
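
# Example (illustrative): max_over_ndim(t, [1, 2]) on a tensor of shape (C, H, W)
# reduces over the H and W axes and returns a tensor of shape (C,) holding the
# maximum of each channel.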


def channel_range(input, axis=0):
    ''' Finds the range of weights associated with a specific channel
    '''
    size_of_tensor_dim = input.ndim
    axis_list = list(range(size_of_tensor_dim))
    axis_list.remove(axis)

    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range do not match size of requested axis"
    return maxs - mins
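
# Example (illustrative): for a Conv2d weight of shape (C_out, C_in, kH, kW),
# channel_range(weight, axis=0) reduces over the (C_in, kH, kW) axes and
# returns a 1-D tensor of length C_out containing max - min of the weights in
# each output channel.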


def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    ''' Given two adjacent modules, scales their weights such that the range of
    each output channel of the first module's weight equals the range of the
    corresponding input channel of the second module's weight
    '''
    if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
        raise ValueError("module type not supported:", type(module1), " ", type(module2))

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        raise TypeError("Number of output channels of first arg do not match "
                        "number of input channels of second arg")

    bias = get_module_bias(module1)

    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)

    # producing the scaling factors to be applied
    weight2_range += 1e-9
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    bias = bias * inverse_scaling_factors

    # reshaping the (1D) scaling tensors, padding them with singleton axes so
    # they broadcast against the given weight tensors
    size1 = [1] * weight1.ndim
    size1[output_axis] = weight1.size(output_axis)
    size2 = [1] * weight2.ndim
    size2[input_axis] = weight2.size(input_axis)

    scaling_factors = torch.reshape(scaling_factors, size2)
    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

    weight1 = weight1 * inverse_scaling_factors
    weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    set_module_bias(module1, bias)
    set_module_weight(module2, weight2)
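
# Worked example of the scaling in cross_layer_equalization (illustrative): if an
# output channel of weight1 has range 4.0 and the matching input channel of
# weight2 has range 1.0, then s = sqrt(4.0 / 1.0) = 2.0, so that channel of
# weight1 is multiplied by 1/2 and the matching channel of weight2 by 2,
# leaving both with range 2.0 (the geometric mean of the original ranges).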


def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    ''' Given a list of adjacent modules within a model, equalization will
    be applied between each pair; this is repeated until convergence is achieved

    Keeps a copy of the changing modules from the previous iteration; if the copies
    are not that different from the current modules (determined by converged), then
    the modules have converged enough that further equalizing is not necessary

    The implementation references section 4.1 of https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.Module) that equalization is to be applied on
        paired_modules_list: a list of lists where each sublist is a pair of names of
            two submodules found in the model; for each pair the two submodules generally
            have to be adjacent in the model to get expected/reasonable results
        threshold: a number used by the converged function to determine what degree of
            similarity between models is necessary for them to be called equivalent
        inplace: determines if the function is applied in place or not
    '''
    if not inplace:
        model = copy.deepcopy(model)

    name_to_module: Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            previous_name_to_module[name] = None

    while not converged(name_to_module, previous_name_to_module, threshold):
        for pair in paired_modules_list:
            previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
            previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])

            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])

    return model


def converged(curr_modules, prev_modules, threshold=1e-4):
    ''' Tests whether the summed norm of the differences between each pair of modules
    is less than the given threshold

    Takes two dictionaries mapping names to modules; the set of names for each dictionary
    should be the same. Looping over the set of names, the difference is taken between
    the modules associated with each name in the two dictionaries
    '''
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError("The keys to the given mappings must have the same set of names of modules")

    summed_norms = torch.tensor(0.)
    if None in prev_modules.values():
        return False
    for name in curr_modules.keys():
        curr_weight = get_module_weight(curr_modules[name])
        prev_weight = get_module_weight(prev_modules[name])

        difference = curr_weight.sub(prev_weight)
        summed_norms += torch.norm(difference)
    return bool(summed_norms < threshold)
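

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the toy model and the module
    # names "fc1"/"fc2" are hypothetical): equalize two adjacent Linear layers
    # and check that their per-channel weight ranges now line up.
    import torch.nn as nn

    toy_model = nn.Sequential()
    toy_model.add_module("fc1", nn.Linear(8, 16))
    toy_model.add_module("fc2", nn.Linear(16, 4))

    equalized = equalize(toy_model, [["fc1", "fc2"]], threshold=1e-4, inplace=False)

    # Output-channel ranges of fc1 should now match input-channel ranges of fc2.
    print(channel_range(get_module_weight(equalized.fc1), axis=0))
    print(channel_range(get_module_weight(equalized.fc2), axis=1))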