base_sparsifier.py

import abc
import copy
from collections import defaultdict
from typing import Any, Dict, Optional, Set, Tuple, List, Type

import torch
from torch import nn
from torch.nn.utils import parametrize

from .utils import (
    FakeSparsity,
    get_arg_info_from_tensor_fqn,
    module_to_fqn,
)

__all__ = ["BaseSparsifier"]

SUPPORTED_MODULES = {
    nn.Linear
}

KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]


# TODO update desc with new config args
class BaseSparsifier(abc.ABC):
    r"""Base class for all sparsifiers.

    Abstract methods that need to be implemented:

    - update_mask: Function to compute a new mask for all keys in the
        `groups`.

    Args:
        - model [nn.Module]: model to configure. The model itself is not saved
            but used for the state_dict saving / loading.
        - config [list]: configuration elements should be a dict map that includes
            `tensor_fqn` of tensors to sparsify
        - defaults [dict]: default configurations will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.

    Example::

        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
        >>> config = [{'tensor_fqn': 'layer1.weight'}, {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
        >>> defaults = {'sparsity_level': 0.7}
        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
        >>> sparsifier = BaseSparsifier(defaults=defaults)
        >>> sparsifier.prepare(model, config)
    """

    def __init__(self, defaults: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.defaults: Dict[str, Any] = defaults or {}
        self.state: Dict[str, Dict] = defaultdict(dict)
        self.groups: List[Dict[str, Any]] = []
        self.enable_mask_update = True

    def __getstate__(self) -> Dict[str, Any]:
        return {
            'defaults': self.defaults,
            'state': self.state,
            'groups': self.groups,
        }

    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, sparse_args in enumerate(self.groups):
            module = sparse_args['module']
            format_string += '\n'
            format_string += f'\tGroup {i}\n'
            format_string += f'\t module: {module}\n'
            for key in sorted(sparse_args.keys()):
                if key == "module":
                    continue
                format_string += f"\t {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def state_dict(self) -> Dict[str, Any]:
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - current state of the sparsification.
        * groups - a list containing all sparsity configuration groups
            with the key 'tensor_fqn' specifying the path to the sparsified
            tensor within a model

        TODO: Need a clean way of loading the state of the "prepared" module
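
        Example (illustrative sketch only: it assumes a hypothetical
        ``sparsifier`` that has already been prepared on a ``model``, and shows
        a plain ``torch.save`` / ``load_state_dict`` round trip)::

            >>> # xdoctest: +SKIP("illustrative only; `sparsifier` and `model` are not defined")
            >>> torch.save(sparsifier.state_dict(), 'sparsifier_state.pt')
            >>> # ... later, after re-creating the sparsifier and calling
            >>> # sparsifier.prepare(model, config) on the same architecture:
            >>> sparsifier.load_state_dict(torch.load('sparsifier_state.pt'))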
  74. """
  75. groups: List[Dict[str, Any]] = [
  76. dict(filter(lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT , mg.items()))
  77. for mg in self.groups
  78. ]
  79. return {
  80. 'state': self.state,
  81. 'groups': groups,
  82. }

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
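        r"""Restores the state produced by :meth:`state_dict` onto ``self.model``.

        Note: ``self.model`` must already be set (it is assigned in ``prepare``),
        and a ``FakeSparsity`` parametrization is registered for a tensor if one
        is not already present before its stored mask is applied.
        """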
        groups = copy.deepcopy(state_dict['groups'])
        states = state_dict['state']
        for tensor_fqn, s in states.items():
            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
            module = arg_info["module"]
            tensor_name = arg_info["tensor_name"]
            if strict and module is None:
                raise RuntimeError(f"Error loading {tensor_fqn} into the model")

            found = False
            for p in module.parametrizations[tensor_name]:
                if isinstance(p, FakeSparsity):
                    found = True
                    break
            if not found:
                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
                parametrize.register_parametrization(module, tensor_name, p)
            if s.get("mask", None) is not None:
                mask = s.pop("mask")
                p.mask = mask

            for mg in groups:
                if mg["tensor_fqn"] == tensor_fqn:
                    mg.update(arg_info)
        self.__setstate__({"state": states, "groups": groups})

    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES,
    ) -> None:
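        r"""Builds a default ``self.config`` from the model.

        Walks the module tree and appends a ``{"tensor_fqn": "<module_fqn>.weight"}``
        entry for every child whose type is in ``SUPPORTED_MODULES``; children of
        supported modules are not traversed further.
        """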
        self.config = []
        stack = [model]
        while stack:
            module = stack.pop()
            for name, child in module.named_children():
                if type(child) in SUPPORTED_MODULES:
                    module_fqn = module_to_fqn(model, child)
                    assert isinstance(module_fqn, str)  # for mypy
                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
                else:
                    stack.append(child)

    def prepare(self, model, config):
        r"""Prepares a model by adding the parametrizations.

        Note::
            The model is modified inplace. If you need to preserve the original
            model, use copy.deepcopy.
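
        Example (illustrative sketch with a hypothetical ``MySparsifier``
        subclass and a toy model; the exact config keys depend on the
        subclass's ``update_mask`` implementation)::

            >>> # xdoctest: +SKIP("`MySparsifier` is a hypothetical subclass")
            >>> model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
            >>> sparsifier = MySparsifier(defaults={'sparsity_level': 0.5})
            >>> # Passing config=None sparsifies every supported layer's weight
            >>> sparsifier.prepare(model, config=None)
            >>> # or target specific tensors explicitly:
            >>> # sparsifier.prepare(model, config=[{'tensor_fqn': '0.weight'}])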
  130. """
  131. self.model = model # TODO: Need to figure out how to load without this.
  132. self.config = config
  133. # If no config -- try getting all the supported layers
  134. if self.config is None:
  135. self.make_config_from_model(model)
  136. # TODO: Remove the configuration by reference ('module')
  137. for module_config in self.config:
  138. assert isinstance(module_config, dict), (
  139. "config elements should be dicts not modules i.e.:"
  140. "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
  141. )
  142. assert isinstance(self.defaults, Dict) # for mypy
  143. local_args = copy.deepcopy(self.defaults)
  144. local_args.update(module_config)
  145. tensor_fqn = local_args.get("tensor_fqn", None)
  146. assert tensor_fqn is not None, (
  147. "tensor_fqn is a required argument in the sparsity config which"
  148. "replaces previous `module` and [module]`fqn` arguments"
  149. )
  150. # populate all information from tensor_fqn
  151. info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
  152. # check that whatever was put into local_args agrees with what was obtained
  153. # from tensor_fqn
  154. for key in info_from_tensor_fqn.keys():
  155. if key in local_args:
  156. assert (
  157. info_from_tensor_fqn[key] == local_args[key]
  158. or (
  159. key == "tensor_fqn"
  160. and "." + info_from_tensor_fqn[key] == local_args[key]
  161. )
  162. # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
  163. ), (
  164. "Given both `{}` and `tensor_fqn` in the config, it is expected them to "
  165. "agree!".format(key)
  166. )
  167. local_args.update(info_from_tensor_fqn)
  168. self.groups.append(local_args)
  169. self._prepare()

    def _prepare(self, *args, **kwargs):
        r"""Adds mask parametrization to the layer weight"""
        for config in self.groups:
            module = config['module']
            tensor_name = config['tensor_name']
            parametrization = config.get('parametrization', FakeSparsity)
            mask = config.get('mask', torch.ones_like(getattr(module, tensor_name)))
            self.state[config['tensor_fqn']]['mask'] = mask
            parametrize.register_parametrization(module, tensor_name, parametrization(mask))

    def squash_mask(self,
                    params_to_keep: Optional[Tuple[str, ...]] = None,
                    params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
                    *args, **kwargs):
        r"""Squashes the sparse masks into the appropriate tensors.

        If either the `params_to_keep` or `params_to_keep_per_layer` is set,
        the module will have a `sparse_params` dict attached to it.

        Args:
            params_to_keep: List of keys to save in the module or a dict
                representing the modules and keys that will have
                sparsity parameters saved
            params_to_keep_per_layer: Dict to specify the params that should be
                saved for specific layers. The keys in the dict
                should be the module fqn, while the values should
                be a list of strings with the names of the variables
                to save in the `sparse_params`

        Examples:
            >>> # xdoctest: +SKIP("locals are undefined")
            >>> # Don't save any sparse params
            >>> sparsifier.squash_mask()
            >>> hasattr(model.submodule1, 'sparse_params')
            False

            >>> # Keep sparse params per layer
            >>> sparsifier.squash_mask(
            ...     params_to_keep_per_layer={
            ...         'submodule1.linear1': ('foo', 'bar'),
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'baz': 0.1}

            >>> # Keep sparse params for all layers
            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24}

            >>> # Keep some sparse params for all layers, and specific ones for
            >>> # some other layers
            >>> sparsifier.squash_mask(
            ...     params_to_keep=('foo', 'bar'),
            ...     params_to_keep_per_layer={
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24, 'baz': 0.1}
        """
        for config in self.groups:
            module = config['module']
            tensor_name = config['tensor_name']
            parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
            sparse_params = {}
            if params_to_keep is not None:
                global_params = {k: config[k] for k in params_to_keep}
                sparse_params.update(global_params)
            if params_to_keep_per_layer is not None:
                params = params_to_keep_per_layer.get(config["module_fqn"], None)
                if params is not None:
                    per_layer_params = {k: config[k] for k in params}
                    sparse_params.update(per_layer_params)
            if sparse_params:
                # TODO handle multiple tensors being sparsified on a single module; where to store sparse_params?
                module.sparse_params = sparse_params

    def convert(self):
        # TODO: Call the torch.ao.utils.convert in here
        raise NotImplementedError(
            "`convert` is not implemented. Please, use "
            "`torch.ao.utils.convert` instead."
        )

    def step(self, use_path: bool = True) -> None:
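        r"""Recomputes the masks by calling ``update_mask`` on every configured
        group. This is a no-op while ``self.enable_mask_update`` is False, and
        runs under ``torch.no_grad()``.
        """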
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for config in self.groups:
                self.update_mask(**config)

    @abc.abstractmethod
    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
        pass
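

# The block below is an illustrative sketch rather than part of the original
# module: it assumes a hypothetical `_DemoMagnitudeSparsifier` subclass to show
# how `update_mask` is typically implemented and how `prepare` / `step` /
# `squash_mask` fit together. It only runs when this file is executed directly.
if __name__ == "__main__":

    class _DemoMagnitudeSparsifier(BaseSparsifier):
        """Zeroes out the lowest-magnitude fraction of each configured weight."""

        def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
            # The mask stored in `self.state` is the same tensor used by the
            # FakeSparsity parametrization registered in `_prepare`.
            mask = self.state[kwargs["tensor_fqn"]]["mask"]
            weight = getattr(module, tensor_name).detach().abs()
            num_zeros = int(sparsity_level * weight.numel())
            if num_zeros > 0:
                threshold = weight.flatten().kthvalue(num_zeros).values
                mask.data = (weight > threshold).to(mask.dtype)

    demo_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    sparsifier = _DemoMagnitudeSparsifier(defaults={"sparsity_level": 0.5})
    sparsifier.prepare(demo_model, config=None)  # attach FakeSparsity to every nn.Linear weight
    sparsifier.step()                            # recompute the masks
    sparsifier.squash_mask()                     # fold the masks into the weights permanently
    print(demo_model[0].weight)                  # roughly half of the entries are now zero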