# base_scheduler.py
  1. from torch.ao.pruning import BaseSparsifier
  2. from functools import wraps
  3. import warnings
  4. import weakref
  5. __all__ = ["BaseScheduler"]
  6. class BaseScheduler:
  7. def __init__(self, sparsifier, last_epoch=-1, verbose=False):
  8. # Attach sparsifier
  9. if not isinstance(sparsifier, BaseSparsifier):
  10. raise TypeError('{} is not an instance of torch.ao.pruning.BaseSparsifier'.format(
  11. type(sparsifier).__name__))
  12. self.sparsifier = sparsifier
  13. # Initialize epoch and base sparsity levels
  14. self.base_sl = [group['sparsity_level'] for group in sparsifier.groups]
  15. self.last_epoch = last_epoch
  16. # Following https://github.com/pytorch/pytorch/issues/20124
  17. # We would like to ensure that `scheduler.step()` is called after
  18. # `sparsifier.step()`
  19. def with_counter(method):
  20. if getattr(method, '_with_counter', False):
  21. # `sparsifier.step()` has already been replaced, return.
  22. return method
  23. # Keep a weak reference to the sparsifier instance to prevent
  24. # cyclic references.
  25. instance_ref = weakref.ref(method.__self__)
  26. # Get the unbound method for the same purpose.
  27. func = method.__func__
  28. cls = instance_ref().__class__
  29. del method
  30. @wraps(func)
  31. def wrapper(*args, **kwargs):
  32. instance = instance_ref()
  33. instance._step_count += 1 # type: ignore[union-attr]
  34. wrapped = func.__get__(instance, cls)
  35. return wrapped(*args, **kwargs)
  36. # Note that the returned function here is no longer a bound method,
  37. # so attributes like `__func__` and `__self__` no longer exist.
  38. wrapper._with_counter = True # type: ignore[attr-defined]
  39. return wrapper
  40. self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment]
  41. self.sparsifier._step_count = 0 # type: ignore[attr-defined]
  42. self._step_count: int = 0
  43. self.verbose = verbose
  44. # Housekeeping
  45. self._get_sl_called_within_step: bool = False
  46. self.step()
  47. def state_dict(self):
  48. """Returns the state of the scheduler as a :class:`dict`.
  49. It contains an entry for every variable in self.__dict__ which
  50. is not the sparsifier.
  51. """
  52. return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'}
  53. def load_state_dict(self, state_dict):
  54. """Loads the schedulers state.
  55. Args:
  56. state_dict (dict): scheduler state. Should be an object returned
  57. from a call to :meth:`state_dict`.
  58. """
  59. self.__dict__.update(state_dict)
  60. def get_last_sl(self):
  61. """ Return last computed sparsity level by current scheduler.
  62. """
  63. return self._last_sl
  64. def get_sl(self):
  65. # Compute sparsity level using chainable form of the scheduler
  66. # Note: This method is not intended to be called directly, and is only
  67. # used by the ".step" method. Use .get_last_sl() instead.
  68. if not self._get_sl_called_within_step:
  69. warnings.warn(
  70. "To get the last sparsity level computed by the scheduler, "
  71. "please use `get_last_sl()`.")
  72. raise NotImplementedError
  73. def print_sl(self, is_verbose, group, sl, epoch=None):
  74. """Display the current sparsity level.
  75. """
  76. if is_verbose:
  77. if epoch is None:
  78. print('Adjusting sparsity level'
  79. ' of group {} to {:.4e}.'.format(group, sl))
  80. else:
  81. print('Epoch {:5d}: adjusting sparsity level'
  82. ' of group {} to {:.4e}.'.format(epoch, group, sl))
  83. def __repr__(self):
  84. format_string = self.__class__.__name__ + ' ('
  85. format_string += '\n'
  86. format_string += 'Sparsifier {0}\n'.format(self.sparsifier)
  87. format_string += ' {0}: {1}\n'.format('base_sl', self.base_sl)
  88. format_string += ')'
  89. return format_string
  90. def step(self, epoch=None):
  91. # Raise warning if trying to call scheduler step before the sparsifier.
  92. # https://github.com/pytorch/pytorch/issues/20124
  93. if self._step_count == 1:
  94. if not hasattr(self.sparsifier.step, "_with_counter"):
  95. warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
  96. "initialization. Please, make sure to call `sparsifier.step()` before "
  97. "`scheduler.step()`.", UserWarning)
  98. # Just check if there were two first scheduler.step() calls before sparsifier.step()
  99. elif self.sparsifier._step_count < 1: # type: ignore[attr-defined]
  100. warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
  101. "You have to make sure you run the sparsifier.step() BEFORE any "
  102. "calls to the scheduer.step().", UserWarning)
  103. self._step_count += 1
  104. class _enable_get_sl_call:
  105. def __init__(self, o):
  106. self.o = o
  107. def __enter__(self):
  108. self.o._get_sl_called_within_step = True
  109. return self
  110. def __exit__(self, type, value, traceback):
  111. self.o._get_sl_called_within_step = False
  112. with _enable_get_sl_call(self):
  113. self.last_epoch += 1
  114. values = self.get_sl()
  115. for i, data in enumerate(zip(self.sparsifier.groups, values)):
  116. param_group, sl = data
  117. param_group['sparsity_level'] = sl
  118. self.print_sl(self.verbose, i, sl, epoch)
  119. self._last_sl = [group['sparsity_level'] for group in self.sparsifier.groups]
  120. self.sparsifier.enable_mask_update = True
  121. def _make_sure_a_list(self, var):
  122. r"""Utility that extends it to the same length as the .groups, ensuring it is a list"""
  123. n = len(self.sparsifier.groups)
  124. if not isinstance(var, (list, tuple)):
  125. return [var] * n
  126. else:
  127. if len(var) != n:
  128. raise ValueError("Expected variable of length {n}, but got {got}".format(
  129. n=n, got=len(var)
  130. ))
  131. return list(var) # We want the result to be in a list, not tuple