123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- import warnings
- from .base_scheduler import BaseScheduler
- __all__ = ["LambdaSL"]
class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list (or tuple) of
            such functions, one for each group in ``sparsifier.groups``.
            A single function is shared by every group.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``sl_lambda`` is a list or tuple whose length differs
            from ``len(sparsifier.groups)``.

    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier

        if not isinstance(sl_lambda, (list, tuple)):
            # One callable was given: reuse it for every sparsifier group.
            self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
        else:
            if len(sl_lambda) != len(sparsifier.groups):
                raise ValueError("Expected {} sl_lambdas, but got {}".format(
                    len(sparsifier.groups), len(sl_lambda)))
            # Copy into a list so a caller-owned sequence is not aliased.
            self.sl_lambdas = list(sl_lambda)

        # The lambdas are assigned before delegating to the base class.
        # NOTE(review): presumably the base __init__ may already call
        # get_sl(), which needs self.sl_lambdas -- confirm in BaseScheduler.
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        """Return one sparsity level per group for the current epoch.

        Each value is the group's entry in ``self.base_sl`` multiplied by the
        matching lambda evaluated at ``self.last_epoch``. A warning is issued
        when ``self._get_sl_called_within_step`` is falsy.

        NOTE(review): ``base_sl``, ``last_epoch`` and
        ``_get_sl_called_within_step`` are not set in this class; presumably
        they are maintained by ``BaseScheduler`` -- confirm there.
        """
        if not self._get_sl_called_within_step:
            # stacklevel=2 attributes the warning to the caller of get_sl().
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.",
                stacklevel=2)
        return [base_sl * lmbda(self.last_epoch)
                for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]
|