cubic_scheduler.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # -*- coding: utf-8 -*-
  2. import warnings
  3. from .base_scheduler import BaseScheduler
  4. __all__ = ["CubicSL"]
  5. def _clamp(x, lo, hi):
  6. return max(lo, min(hi, x))
  7. class CubicSL(BaseScheduler):
  8. r"""Sets the sparsity level of each parameter group to the final sl
  9. plus a given exponential function.
  10. .. math::
  11. s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3
  12. where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final
  13. sparsity level, :math:`f(i)` is the function to be applied to the current epoch
  14. :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`.
  15. :math:`\Delta t` is used to control how often the update of the sparsity level
  16. happens. By default,
  17. Args:
  18. sparsifier (BaseSparsifier): Wrapped sparsifier.
  19. init_sl (int, list): Initial level of sparsity
  20. init_t (int, list): Initial step, when pruning starts
  21. delta_t (int, list): Pruning frequency
  22. total_t (int, list): Total number of pruning steps
  23. initially_zero (bool, list): If True, sets the level of sparsity to 0
  24. before init_t (:math:`t_0`). Otherwise, the sparsity level before
  25. init_t (:math:`t_0`) is set to init_sl(:math:`s_0`)
  26. last_epoch (int): The index of last epoch. Default: -1.
  27. verbose (bool): If ``True``, prints a message to stdout for
  28. each update. Default: ``False``.
  29. """
  30. def __init__(self,
  31. sparsifier,
  32. init_sl=0.0,
  33. init_t=0,
  34. delta_t=10,
  35. total_t=100,
  36. initially_zero=False,
  37. last_epoch=-1,
  38. verbose=False
  39. ):
  40. self.sparsifier = sparsifier
  41. self.init_sl = self._make_sure_a_list(init_sl)
  42. self.init_t = self._make_sure_a_list(init_t)
  43. self.delta_t = self._make_sure_a_list(delta_t)
  44. self.total_t = self._make_sure_a_list(total_t)
  45. self.initially_zero = self._make_sure_a_list(initially_zero)
  46. super().__init__(sparsifier, last_epoch, verbose)
  47. @staticmethod
  48. def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
  49. r""""Computes the current level of sparsity.
  50. Based on https://arxiv.org/pdf/1710.01878.pdf
  51. Args:
  52. s_0: Initial level of sparsity, :math:`s_i`
  53. s_f: Target level of sparsity, :math:`s_f`
  54. t: Current step, :math:`t`
  55. t_0: Initial step, :math:`t_0`
  56. dt: Pruning frequency, :math:`\Delta T`
  57. n: Pruning steps, :math:`n`
  58. initially_zero: Sets the level of sparsity to 0 before t_0.
  59. If False, sets to s_0
  60. Returns:
  61. The sparsity level :math:`s_t` at the current step :math:`t`
  62. """
  63. if initially_zero and t < t_0:
  64. return 0
  65. s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
  66. s_t = _clamp(s_t, s_0, s_f)
  67. return s_t
  68. def get_sl(self):
  69. if not self._get_sl_called_within_step:
  70. warnings.warn(
  71. "To get the last sparsity level computed by the scheduler, "
  72. "please use `get_last_sl()`.")
  73. return [
  74. self.sparsity_compute_fn(
  75. s_0=initial_sparsity,
  76. s_f=final_sparsity,
  77. t=self.last_epoch,
  78. t_0=initial_epoch,
  79. dt=delta_epoch,
  80. n=interval_epochs,
  81. initially_zero=initially_zero
  82. ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in
  83. zip(
  84. self.init_sl,
  85. self.base_sl,
  86. self.init_t,
  87. self.delta_t,
  88. self.total_t,
  89. self.initially_zero
  90. )
  91. ]