nearly_diagonal_sparsifier.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import torch
  2. from . import base_sparsifier
  3. class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
  4. r"""Nearly Diagonal Sparsifier
  5. This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
  6. Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero.
  7. An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively.
  8. 1 1 0 0 1 1 1 0
  9. 1 1 1 0 1 1 1 1
  10. 0 1 1 1 1 1 1 1
  11. 0 0 1 1 0 1 1 1
  12. Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated
  13. This sparsifier is controlled by one variable:
  14. 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal.
  15. Currently - supports only odd number
  16. Note:
  17. This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix
  18. feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy
  19. Args:
  20. nearliness: The degree of nearliness (default = 1)
  21. """
  22. def __init__(self, nearliness: int = 1):
  23. defaults = {'nearliness': nearliness}
  24. super().__init__(defaults=defaults)
  25. def update_mask(self, module, tensor_name, nearliness,
  26. **kwargs):
  27. mask = getattr(module.parametrizations, tensor_name)[0].mask
  28. mask.data = torch.zeros_like(mask)
  29. if nearliness <= 0:
  30. return
  31. tensor = getattr(module, tensor_name)
  32. height, width = tensor.shape
  33. if nearliness % 2 == 0:
  34. raise ValueError("nearliness can only be an odd number")
  35. dist_to_diagonal = nearliness // 2
  36. # check
  37. if dist_to_diagonal >= min(height, width):
  38. raise ValueError("nearliness cannot be larger than the dimensions of tensor.")
  39. for row in range(0, height):
  40. # Bounds of entries that needs to be set to 1
  41. low = max(0, row - dist_to_diagonal)
  42. high = min(width, row + dist_to_diagonal + 1)
  43. mask[row, low:high].fill_(1)