parametrization.py

import torch
from torch import nn
from torch.nn.utils.parametrize import is_parametrized


def module_contains_param(module, parametrization):
    if is_parametrized(module):
        # see if any of the module tensors have a parametrization attached that matches the one passed in
        return any(
            [
                any(isinstance(param, parametrization) for param in param_list)
                for key, param_list in module.parametrizations.items()
            ]
        )
    return False
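

# A minimal usage sketch (not part of the original file): it registers a
# trivial, hypothetical `_Noop` parametrization on a Linear layer's weight via
# torch.nn.utils.parametrize.register_parametrization and checks that
# module_contains_param detects it; the layer sizes are arbitrary.
def _example_module_contains_param():
    from torch.nn.utils import parametrize

    class _Noop(nn.Module):
        def forward(self, x):
            return x

    linear = nn.Linear(4, 3)
    assert not module_contains_param(linear, _Noop)  # nothing registered yet
    parametrize.register_parametrization(linear, "weight", _Noop())
    assert module_contains_param(linear, _Noop)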


# Structured Pruning Parametrizations
class FakeStructuredSparsity(nn.Module):
    r"""
    Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to
    the 'weight' or any other parameter that requires a mask.

    Instead of an element-wise bool mask, this parametrization uses a row-wise bool mask.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert isinstance(self.mask, torch.Tensor)
        assert self.mask.shape[0] == x.shape[0]
        # broadcast the row-wise mask across all remaining dimensions of x
        shape = [1] * len(x.shape)
        shape[0] = -1
        return self.mask.reshape(shape) * x

    def state_dict(self, *args, **kwargs):
        # avoid double saving masks
        return {}
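

# A minimal usage sketch (not part of the original file): it assumes a
# row-wise bool mask for a small nn.Linear layer and attaches
# FakeStructuredSparsity to its 'weight' with
# torch.nn.utils.parametrize.register_parametrization, so rows where the mask
# is False are zeroed out on every forward pass.
def _example_fake_structured_sparsity():
    from torch.nn.utils import parametrize

    linear = nn.Linear(4, 3)
    mask = torch.tensor([True, False, True])  # one entry per output row
    parametrize.register_parametrization(
        linear, "weight", FakeStructuredSparsity(mask)
    )
    assert module_contains_param(linear, FakeStructuredSparsity)
    print(linear.weight)  # row 1 of the effective weight is now all zeros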


class BiasHook:
    """Forward hook that adds back a bias stored on ``module._bias``,
    zeroing the entries of pruned rows first when ``prune_bias`` is True."""

    def __init__(self, parametrization, prune_bias):
        self.param = parametrization
        self.prune_bias = prune_bias

    def __call__(self, module, input, output):
        if getattr(module, "_bias", None) is not None:
            bias = module._bias.data
            if self.prune_bias:
                bias[~self.param.mask] = 0

            # reshape bias to broadcast over output dimensions
            idx = [1] * len(output.shape)
            idx[1] = -1
            bias = bias.reshape(idx)

            output += bias
        return output
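

# A minimal usage sketch (not part of the original file): it assumes the
# sparsifier has already moved the module's original bias to ``module._bias``
# and set ``module.bias = None``, which is the state BiasHook expects; the
# hook is then registered with register_forward_hook so the (optionally
# pruned) bias is re-added after each forward call.
def _example_bias_hook():
    from torch.nn.utils import parametrize

    linear = nn.Linear(4, 3)
    mask = torch.tensor([True, False, True])
    param = FakeStructuredSparsity(mask)
    parametrize.register_parametrization(linear, "weight", param)

    # Move the bias out of the module so the hook controls it instead.
    linear._bias = nn.Parameter(linear.bias.detach().clone())
    linear.bias = None

    linear.register_forward_hook(BiasHook(param, prune_bias=True))
    out = linear(torch.randn(2, 4))  # bias entries of masked-out rows are zero
    print(out)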