import warnings

from .base_scheduler import BaseScheduler

__all__ = ["LambdaSL"]


class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list/tuple of such
            functions, one for each group in ``sparsifier.groups``.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``sl_lambda`` is a list or tuple whose length does not
            equal ``len(sparsifier.groups)``.

    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier
        num_groups = len(sparsifier.groups)

        if isinstance(sl_lambda, (list, tuple)):
            # One lambda per group: the counts must line up exactly.
            if len(sl_lambda) != num_groups:
                raise ValueError(
                    f"Expected {num_groups} sl_lambdas, but got {len(sl_lambda)}"
                )
            self.sl_lambdas = list(sl_lambda)
        else:
            # A single callable is shared by every group.
            self.sl_lambdas = [sl_lambda] * num_groups

        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        """Return the sparsity level for each group at ``self.last_epoch``.

        Each entry is ``base_sl * lmbda(self.last_epoch)``, pairing the
        group's lambda with its base sparsity level. ``self.base_sl`` and
        ``self._get_sl_called_within_step`` are presumably maintained by
        ``BaseScheduler`` (not visible here; confirm there).

        Returns:
            list: One sparsity level per group.
        """
        if not self._get_sl_called_within_step:
            # Calling this outside of ``step()`` recomputes rather than
            # reporting the last value; point users to ``get_last_sl()``.
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.",
                stacklevel=2,
            )
        return [
            base_sl * lmbda(self.last_epoch)
            for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)
        ]