# -*- coding: utf-8 -*-
import warnings

from .base_scheduler import BaseScheduler

__all__ = ["CubicSL"]


def _clamp(x, lo, hi):
    """Restrict ``x`` to the closed interval ``[lo, hi]``.

    The result is only meaningful when ``lo <= hi``; callers are responsible
    for passing the bounds in that order.
    """
    return max(lo, min(hi, x))


class CubicSL(BaseScheduler):
    r"""Sets the sparsity level of each parameter group to the final sl
    plus a given cubic function.

    .. math::

        s_t = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3

    where :math:`s_t` is the sparsity at step :math:`t`, :math:`s_f` is the
    final sparsity level, :math:`s_0` is the initial sparsity level,
    :math:`t_0` is the initial step, :math:`\Delta t` is the pruning
    frequency and :math:`n` is the total number of pruning steps. The
    schedule therefore spans :math:`n\Delta t` steps starting at :math:`t_0`;
    outside of that window the value is clamped to the range between
    :math:`s_0` and :math:`s_f`.

    The final sparsity level of each group is read from ``self.base_sl``
    (NOTE(review): presumably populated by ``BaseScheduler`` from the wrapped
    sparsifier -- confirm against the base class).

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        init_sl (float, list): Initial level of sparsity
        init_t (int, list): Initial step, when pruning starts
        delta_t (int, list): Pruning frequency
        total_t (int, list): Total number of pruning steps
        initially_zero (bool, list): If True, sets the level of sparsity to 0
            before init_t (:math:`t_0`). Otherwise, the sparsity level before
            init_t (:math:`t_0`) is set to init_sl (:math:`s_0`)
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self,
                 sparsifier,
                 init_sl=0.0,
                 init_t=0,
                 delta_t=10,
                 total_t=100,
                 initially_zero=False,
                 last_epoch=-1,
                 verbose=False
                 ):
        # Stored before the per-group settings are normalised and before the
        # base-class constructor runs.
        # NOTE(review): presumably `_make_sure_a_list` (defined outside this
        # file) needs `self.sparsifier` to size the lists -- confirm.
        self.sparsifier = sparsifier

        self.init_sl = self._make_sure_a_list(init_sl)
        self.init_t = self._make_sure_a_list(init_t)
        self.delta_t = self._make_sure_a_list(delta_t)
        self.total_t = self._make_sure_a_list(total_t)

        self.initially_zero = self._make_sure_a_list(initially_zero)

        super().__init__(sparsifier, last_epoch, verbose)

    @staticmethod
    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
        r"""Computes the current level of sparsity.

        Based on https://arxiv.org/pdf/1710.01878.pdf

        Args:
            s_0: Initial level of sparsity, :math:`s_0`
            s_f: Target level of sparsity, :math:`s_f`
            t: Current step, :math:`t`
            t_0: Initial step, :math:`t_0`
            dt: Pruning frequency, :math:`\Delta T`
            n: Pruning steps, :math:`n`
            initially_zero: Sets the level of sparsity to 0 before t_0.
                If False, sets to s_0

        Returns:
            The sparsity level :math:`s_t` at the current step :math:`t`,
            clamped to the closed range spanned by ``s_0`` and ``s_f``.

        Raises:
            ZeroDivisionError: If ``dt * n`` is zero.
        """
        if initially_zero and t < t_0:
            return 0
        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
        # The cubic overshoots s_0 before t_0 and s_f after t_0 + dt * n, so
        # the value has to be clamped. Order the bounds explicitly: passing
        # (s_0, s_f) directly would pin the result to s_0 whenever s_0 > s_f.
        lower, upper = min(s_0, s_f), max(s_0, s_f)
        return _clamp(s_t, lower, upper)

    def get_sl(self):
        """Return the sparsity level of every group for ``self.last_epoch``.

        Emits a warning when ``self._get_sl_called_within_step`` is falsy,
        pointing callers to ``get_last_sl()`` instead.

        Returns:
            list: One sparsity level per group, computed with
            :meth:`sparsity_compute_fn` from the per-group settings.
        """
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        # zip() walks the per-group settings in lockstep; `self.base_sl`
        # supplies the target sparsity s_f of each group.
        return [
            self.sparsity_compute_fn(
                s_0=initial_sparsity,
                s_f=final_sparsity,
                t=self.last_epoch,
                t_0=initial_epoch,
                dt=delta_epoch,
                n=interval_epochs,
                initially_zero=initially_zero
            ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in
            zip(
                self.init_sl,
                self.base_sl,
                self.init_t,
                self.delta_t,
                self.total_t,
                self.initially_zero
            )
        ]