lambda_scheduler.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import warnings
  2. from .base_scheduler import BaseScheduler
  3. __all__ = ["LambdaSL"]
  4. class LambdaSL(BaseScheduler):
  5. """Sets the sparsity level of each parameter group to the final sl
  6. times a given function. When last_epoch=-1, sets initial sl as zero.
  7. Args:
  8. sparsifier (BaseSparsifier): Wrapped sparsifier.
  9. sl_lambda (function or list): A function which computes a multiplicative
  10. factor given an integer parameter epoch, or a list of such
  11. functions, one for each group in sparsifier.param_groups.
  12. last_epoch (int): The index of last epoch. Default: -1.
  13. verbose (bool): If ``True``, prints a message to stdout for
  14. each update. Default: ``False``.
  15. Example:
  16. >>> # Assuming sparsifier has two groups.
  17. >>> lambda1 = lambda epoch: epoch // 30
  18. >>> lambda2 = lambda epoch: 0.95 ** epoch
  19. >>> # xdoctest: +SKIP
  20. >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
  21. >>> for epoch in range(100):
  22. >>> train(...)
  23. >>> validate(...)
  24. >>> scheduler.step()
  25. """
  26. def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
  27. self.sparsifier = sparsifier
  28. if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
  29. self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
  30. else:
  31. if len(sl_lambda) != len(sparsifier.groups):
  32. raise ValueError("Expected {} lr_lambdas, but got {}".format(
  33. len(sparsifier.groups), len(sl_lambda)))
  34. self.sl_lambdas = list(sl_lambda)
  35. super().__init__(sparsifier, last_epoch, verbose)
  36. def get_sl(self):
  37. if not self._get_sl_called_within_step:
  38. warnings.warn(
  39. "To get the last sparsity level computed by the scheduler, "
  40. "please use `get_last_sl()`.")
  41. return [base_sl * lmbda(self.last_epoch)
  42. for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]