lr_scheduler.py 73 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728
  1. import types
  2. import math
  3. from torch import inf
  4. from functools import wraps
  5. import warnings
  6. import weakref
  7. from collections import Counter
  8. from bisect import bisect_right
  9. from .optimizer import Optimizer
  10. __all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
  11. 'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
  12. 'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
  13. EPOCH_DEPRECATION_WARNING = (
  14. "The epoch parameter in `scheduler.step()` was not necessary and is being "
  15. "deprecated where possible. Please use `scheduler.step()` to step the "
  16. "scheduler. During the deprecation, if epoch is different from None, the "
  17. "closed form is used instead of the new chainable form, where available. "
  18. "Please open an issue if you are unable to replicate your use case: "
  19. "https://github.com/pytorch/pytorch/issues/new/choose."
  20. )
  21. class LRScheduler:
  22. def __init__(self, optimizer, last_epoch=-1, verbose=False):
  23. # Attach optimizer
  24. if not isinstance(optimizer, Optimizer):
  25. raise TypeError('{} is not an Optimizer'.format(
  26. type(optimizer).__name__))
  27. self.optimizer = optimizer
  28. # Initialize epoch and base learning rates
  29. if last_epoch == -1:
  30. for group in optimizer.param_groups:
  31. group.setdefault('initial_lr', group['lr'])
  32. else:
  33. for i, group in enumerate(optimizer.param_groups):
  34. if 'initial_lr' not in group:
  35. raise KeyError("param 'initial_lr' is not specified "
  36. "in param_groups[{}] when resuming an optimizer".format(i))
  37. self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
  38. self.last_epoch = last_epoch
  39. # Following https://github.com/pytorch/pytorch/issues/20124
  40. # We would like to ensure that `lr_scheduler.step()` is called after
  41. # `optimizer.step()`
  42. def with_counter(method):
  43. if getattr(method, '_with_counter', False):
  44. # `optimizer.step()` has already been replaced, return.
  45. return method
  46. # Keep a weak reference to the optimizer instance to prevent
  47. # cyclic references.
  48. instance_ref = weakref.ref(method.__self__)
  49. # Get the unbound method for the same purpose.
  50. func = method.__func__
  51. cls = instance_ref().__class__
  52. del method
  53. @wraps(func)
  54. def wrapper(*args, **kwargs):
  55. instance = instance_ref()
  56. instance._step_count += 1
  57. wrapped = func.__get__(instance, cls)
  58. return wrapped(*args, **kwargs)
  59. # Note that the returned function here is no longer a bound method,
  60. # so attributes like `__func__` and `__self__` no longer exist.
  61. wrapper._with_counter = True
  62. return wrapper
  63. self.optimizer.step = with_counter(self.optimizer.step)
  64. self.verbose = verbose
  65. self._initial_step()
  66. def _initial_step(self):
  67. """Initialize step counts and performs a step"""
  68. self.optimizer._step_count = 0
  69. self._step_count = 0
  70. self.step()
  71. def state_dict(self):
  72. """Returns the state of the scheduler as a :class:`dict`.
  73. It contains an entry for every variable in self.__dict__ which
  74. is not the optimizer.
  75. """
  76. return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
  77. def load_state_dict(self, state_dict):
  78. """Loads the schedulers state.
  79. Args:
  80. state_dict (dict): scheduler state. Should be an object returned
  81. from a call to :meth:`state_dict`.
  82. """
  83. self.__dict__.update(state_dict)
  84. def get_last_lr(self):
  85. """ Return last computed learning rate by current scheduler.
  86. """
  87. return self._last_lr
  88. def get_lr(self):
  89. # Compute learning rate using chainable form of the scheduler
  90. raise NotImplementedError
  91. def print_lr(self, is_verbose, group, lr, epoch=None):
  92. """Display the current learning rate.
  93. """
  94. if is_verbose:
  95. if epoch is None:
  96. print('Adjusting learning rate'
  97. ' of group {} to {:.4e}.'.format(group, lr))
  98. else:
  99. epoch_str = ("%.2f" if isinstance(epoch, float) else
  100. "%.5d") % epoch
  101. print('Epoch {}: adjusting learning rate'
  102. ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
  103. def step(self, epoch=None):
  104. # Raise a warning if old pattern is detected
  105. # https://github.com/pytorch/pytorch/issues/20124
  106. if self._step_count == 1:
  107. if not hasattr(self.optimizer.step, "_with_counter"):
  108. warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
  109. "initialization. Please, make sure to call `optimizer.step()` before "
  110. "`lr_scheduler.step()`. See more details at "
  111. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
  112. # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
  113. elif self.optimizer._step_count < 1:
  114. warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
  115. "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
  116. "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
  117. "will result in PyTorch skipping the first value of the learning rate schedule. "
  118. "See more details at "
  119. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
  120. self._step_count += 1
  121. with _enable_get_lr_call(self):
  122. if epoch is None:
  123. self.last_epoch += 1
  124. values = self.get_lr()
  125. else:
  126. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  127. self.last_epoch = epoch
  128. if hasattr(self, "_get_closed_form_lr"):
  129. values = self._get_closed_form_lr()
  130. else:
  131. values = self.get_lr()
  132. for i, data in enumerate(zip(self.optimizer.param_groups, values)):
  133. param_group, lr = data
  134. param_group['lr'] = lr
  135. self.print_lr(self.verbose, i, lr, epoch)
  136. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  137. # Including _LRScheduler for backwards compatibility
  138. # Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
  139. class _LRScheduler(LRScheduler):
  140. pass
  141. class _enable_get_lr_call:
  142. def __init__(self, o):
  143. self.o = o
  144. def __enter__(self):
  145. self.o._get_lr_called_within_step = True
  146. return self
  147. def __exit__(self, type, value, traceback):
  148. self.o._get_lr_called_within_step = False
  149. class LambdaLR(LRScheduler):
  150. """Sets the learning rate of each parameter group to the initial lr
  151. times a given function. When last_epoch=-1, sets initial lr as lr.
  152. Args:
  153. optimizer (Optimizer): Wrapped optimizer.
  154. lr_lambda (function or list): A function which computes a multiplicative
  155. factor given an integer parameter epoch, or a list of such
  156. functions, one for each group in optimizer.param_groups.
  157. last_epoch (int): The index of last epoch. Default: -1.
  158. verbose (bool): If ``True``, prints a message to stdout for
  159. each update. Default: ``False``.
  160. Example:
  161. >>> # xdoctest: +SKIP
  162. >>> # Assuming optimizer has two groups.
  163. >>> lambda1 = lambda epoch: epoch // 30
  164. >>> lambda2 = lambda epoch: 0.95 ** epoch
  165. >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
  166. >>> for epoch in range(100):
  167. >>> train(...)
  168. >>> validate(...)
  169. >>> scheduler.step()
  170. """
  171. def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
  172. self.optimizer = optimizer
  173. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  174. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  175. else:
  176. if len(lr_lambda) != len(optimizer.param_groups):
  177. raise ValueError("Expected {} lr_lambdas, but got {}".format(
  178. len(optimizer.param_groups), len(lr_lambda)))
  179. self.lr_lambdas = list(lr_lambda)
  180. super().__init__(optimizer, last_epoch, verbose)
  181. def state_dict(self):
  182. """Returns the state of the scheduler as a :class:`dict`.
  183. It contains an entry for every variable in self.__dict__ which
  184. is not the optimizer.
  185. The learning rate lambda functions will only be saved if they are callable objects
  186. and not if they are functions or lambdas.
  187. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  188. """
  189. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
  190. state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
  191. for idx, fn in enumerate(self.lr_lambdas):
  192. if not isinstance(fn, types.FunctionType):
  193. state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
  194. return state_dict
  195. def load_state_dict(self, state_dict):
  196. """Loads the schedulers state.
  197. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  198. Args:
  199. state_dict (dict): scheduler state. Should be an object returned
  200. from a call to :meth:`state_dict`.
  201. """
  202. lr_lambdas = state_dict.pop('lr_lambdas')
  203. self.__dict__.update(state_dict)
  204. # Restore state_dict keys in order to prevent side effects
  205. # https://github.com/pytorch/pytorch/issues/32756
  206. state_dict['lr_lambdas'] = lr_lambdas
  207. for idx, fn in enumerate(lr_lambdas):
  208. if fn is not None:
  209. self.lr_lambdas[idx].__dict__.update(fn)
  210. def get_lr(self):
  211. if not self._get_lr_called_within_step:
  212. warnings.warn("To get the last learning rate computed by the scheduler, "
  213. "please use `get_last_lr()`.")
  214. return [base_lr * lmbda(self.last_epoch)
  215. for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
  216. class MultiplicativeLR(LRScheduler):
  217. """Multiply the learning rate of each parameter group by the factor given
  218. in the specified function. When last_epoch=-1, sets initial lr as lr.
  219. Args:
  220. optimizer (Optimizer): Wrapped optimizer.
  221. lr_lambda (function or list): A function which computes a multiplicative
  222. factor given an integer parameter epoch, or a list of such
  223. functions, one for each group in optimizer.param_groups.
  224. last_epoch (int): The index of last epoch. Default: -1.
  225. verbose (bool): If ``True``, prints a message to stdout for
  226. each update. Default: ``False``.
  227. Example:
  228. >>> # xdoctest: +SKIP
  229. >>> lmbda = lambda epoch: 0.95
  230. >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
  231. >>> for epoch in range(100):
  232. >>> train(...)
  233. >>> validate(...)
  234. >>> scheduler.step()
  235. """
  236. def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
  237. self.optimizer = optimizer
  238. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  239. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  240. else:
  241. if len(lr_lambda) != len(optimizer.param_groups):
  242. raise ValueError("Expected {} lr_lambdas, but got {}".format(
  243. len(optimizer.param_groups), len(lr_lambda)))
  244. self.lr_lambdas = list(lr_lambda)
  245. super().__init__(optimizer, last_epoch, verbose)
  246. def state_dict(self):
  247. """Returns the state of the scheduler as a :class:`dict`.
  248. It contains an entry for every variable in self.__dict__ which
  249. is not the optimizer.
  250. The learning rate lambda functions will only be saved if they are callable objects
  251. and not if they are functions or lambdas.
  252. """
  253. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
  254. state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
  255. for idx, fn in enumerate(self.lr_lambdas):
  256. if not isinstance(fn, types.FunctionType):
  257. state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
  258. return state_dict
  259. def load_state_dict(self, state_dict):
  260. """Loads the schedulers state.
  261. Args:
  262. state_dict (dict): scheduler state. Should be an object returned
  263. from a call to :meth:`state_dict`.
  264. """
  265. lr_lambdas = state_dict.pop('lr_lambdas')
  266. self.__dict__.update(state_dict)
  267. # Restore state_dict keys in order to prevent side effects
  268. # https://github.com/pytorch/pytorch/issues/32756
  269. state_dict['lr_lambdas'] = lr_lambdas
  270. for idx, fn in enumerate(lr_lambdas):
  271. if fn is not None:
  272. self.lr_lambdas[idx].__dict__.update(fn)
  273. def get_lr(self):
  274. if not self._get_lr_called_within_step:
  275. warnings.warn("To get the last learning rate computed by the scheduler, "
  276. "please use `get_last_lr()`.", UserWarning)
  277. if self.last_epoch > 0:
  278. return [group['lr'] * lmbda(self.last_epoch)
  279. for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
  280. else:
  281. return [group['lr'] for group in self.optimizer.param_groups]
  282. class StepLR(LRScheduler):
  283. """Decays the learning rate of each parameter group by gamma every
  284. step_size epochs. Notice that such decay can happen simultaneously with
  285. other changes to the learning rate from outside this scheduler. When
  286. last_epoch=-1, sets initial lr as lr.
  287. Args:
  288. optimizer (Optimizer): Wrapped optimizer.
  289. step_size (int): Period of learning rate decay.
  290. gamma (float): Multiplicative factor of learning rate decay.
  291. Default: 0.1.
  292. last_epoch (int): The index of last epoch. Default: -1.
  293. verbose (bool): If ``True``, prints a message to stdout for
  294. each update. Default: ``False``.
  295. Example:
  296. >>> # xdoctest: +SKIP
  297. >>> # Assuming optimizer uses lr = 0.05 for all groups
  298. >>> # lr = 0.05 if epoch < 30
  299. >>> # lr = 0.005 if 30 <= epoch < 60
  300. >>> # lr = 0.0005 if 60 <= epoch < 90
  301. >>> # ...
  302. >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
  303. >>> for epoch in range(100):
  304. >>> train(...)
  305. >>> validate(...)
  306. >>> scheduler.step()
  307. """
  308. def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
  309. self.step_size = step_size
  310. self.gamma = gamma
  311. super().__init__(optimizer, last_epoch, verbose)
  312. def get_lr(self):
  313. if not self._get_lr_called_within_step:
  314. warnings.warn("To get the last learning rate computed by the scheduler, "
  315. "please use `get_last_lr()`.", UserWarning)
  316. if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
  317. return [group['lr'] for group in self.optimizer.param_groups]
  318. return [group['lr'] * self.gamma
  319. for group in self.optimizer.param_groups]
  320. def _get_closed_form_lr(self):
  321. return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
  322. for base_lr in self.base_lrs]
  323. class MultiStepLR(LRScheduler):
  324. """Decays the learning rate of each parameter group by gamma once the
  325. number of epoch reaches one of the milestones. Notice that such decay can
  326. happen simultaneously with other changes to the learning rate from outside
  327. this scheduler. When last_epoch=-1, sets initial lr as lr.
  328. Args:
  329. optimizer (Optimizer): Wrapped optimizer.
  330. milestones (list): List of epoch indices. Must be increasing.
  331. gamma (float): Multiplicative factor of learning rate decay.
  332. Default: 0.1.
  333. last_epoch (int): The index of last epoch. Default: -1.
  334. verbose (bool): If ``True``, prints a message to stdout for
  335. each update. Default: ``False``.
  336. Example:
  337. >>> # xdoctest: +SKIP
  338. >>> # Assuming optimizer uses lr = 0.05 for all groups
  339. >>> # lr = 0.05 if epoch < 30
  340. >>> # lr = 0.005 if 30 <= epoch < 80
  341. >>> # lr = 0.0005 if epoch >= 80
  342. >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
  343. >>> for epoch in range(100):
  344. >>> train(...)
  345. >>> validate(...)
  346. >>> scheduler.step()
  347. """
  348. def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
  349. self.milestones = Counter(milestones)
  350. self.gamma = gamma
  351. super().__init__(optimizer, last_epoch, verbose)
  352. def get_lr(self):
  353. if not self._get_lr_called_within_step:
  354. warnings.warn("To get the last learning rate computed by the scheduler, "
  355. "please use `get_last_lr()`.", UserWarning)
  356. if self.last_epoch not in self.milestones:
  357. return [group['lr'] for group in self.optimizer.param_groups]
  358. return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
  359. for group in self.optimizer.param_groups]
  360. def _get_closed_form_lr(self):
  361. milestones = sorted(self.milestones.elements())
  362. return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
  363. for base_lr in self.base_lrs]
  364. class ConstantLR(LRScheduler):
  365. """Decays the learning rate of each parameter group by a small constant factor until the
  366. number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
  367. happen simultaneously with other changes to the learning rate from outside this scheduler.
  368. When last_epoch=-1, sets initial lr as lr.
  369. Args:
  370. optimizer (Optimizer): Wrapped optimizer.
  371. factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
  372. total_iters (int): The number of steps that the scheduler decays the learning rate.
  373. Default: 5.
  374. last_epoch (int): The index of the last epoch. Default: -1.
  375. verbose (bool): If ``True``, prints a message to stdout for
  376. each update. Default: ``False``.
  377. Example:
  378. >>> # xdoctest: +SKIP
  379. >>> # Assuming optimizer uses lr = 0.05 for all groups
  380. >>> # lr = 0.025 if epoch == 0
  381. >>> # lr = 0.025 if epoch == 1
  382. >>> # lr = 0.025 if epoch == 2
  383. >>> # lr = 0.025 if epoch == 3
  384. >>> # lr = 0.05 if epoch >= 4
  385. >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
  386. >>> for epoch in range(100):
  387. >>> train(...)
  388. >>> validate(...)
  389. >>> scheduler.step()
  390. """
  391. def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
  392. if factor > 1.0 or factor < 0:
  393. raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
  394. self.factor = factor
  395. self.total_iters = total_iters
  396. super().__init__(optimizer, last_epoch, verbose)
  397. def get_lr(self):
  398. if not self._get_lr_called_within_step:
  399. warnings.warn("To get the last learning rate computed by the scheduler, "
  400. "please use `get_last_lr()`.", UserWarning)
  401. if self.last_epoch == 0:
  402. return [group['lr'] * self.factor for group in self.optimizer.param_groups]
  403. if (self.last_epoch > self.total_iters or
  404. (self.last_epoch != self.total_iters)):
  405. return [group['lr'] for group in self.optimizer.param_groups]
  406. if (self.last_epoch == self.total_iters):
  407. return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
  408. def _get_closed_form_lr(self):
  409. return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
  410. for base_lr in self.base_lrs]
  411. class LinearLR(LRScheduler):
  412. """Decays the learning rate of each parameter group by linearly changing small
  413. multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
  414. Notice that such decay can happen simultaneously with other changes to the learning rate
  415. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  416. Args:
  417. optimizer (Optimizer): Wrapped optimizer.
  418. start_factor (float): The number we multiply learning rate in the first epoch.
  419. The multiplication factor changes towards end_factor in the following epochs.
  420. Default: 1./3.
  421. end_factor (float): The number we multiply learning rate at the end of linear changing
  422. process. Default: 1.0.
  423. total_iters (int): The number of iterations that multiplicative factor reaches to 1.
  424. Default: 5.
  425. last_epoch (int): The index of the last epoch. Default: -1.
  426. verbose (bool): If ``True``, prints a message to stdout for
  427. each update. Default: ``False``.
  428. Example:
  429. >>> # xdoctest: +SKIP
  430. >>> # Assuming optimizer uses lr = 0.05 for all groups
  431. >>> # lr = 0.025 if epoch == 0
  432. >>> # lr = 0.03125 if epoch == 1
  433. >>> # lr = 0.0375 if epoch == 2
  434. >>> # lr = 0.04375 if epoch == 3
  435. >>> # lr = 0.05 if epoch >= 4
  436. >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
  437. >>> for epoch in range(100):
  438. >>> train(...)
  439. >>> validate(...)
  440. >>> scheduler.step()
  441. """
  442. def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
  443. verbose=False):
  444. if start_factor > 1.0 or start_factor <= 0:
  445. raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
  446. if end_factor > 1.0 or end_factor < 0:
  447. raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
  448. self.start_factor = start_factor
  449. self.end_factor = end_factor
  450. self.total_iters = total_iters
  451. super().__init__(optimizer, last_epoch, verbose)
  452. def get_lr(self):
  453. if not self._get_lr_called_within_step:
  454. warnings.warn("To get the last learning rate computed by the scheduler, "
  455. "please use `get_last_lr()`.", UserWarning)
  456. if self.last_epoch == 0:
  457. return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
  458. if self.last_epoch > self.total_iters:
  459. return [group['lr'] for group in self.optimizer.param_groups]
  460. return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
  461. (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
  462. for group in self.optimizer.param_groups]
  463. def _get_closed_form_lr(self):
  464. return [base_lr * (self.start_factor +
  465. (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
  466. for base_lr in self.base_lrs]
  467. class ExponentialLR(LRScheduler):
  468. """Decays the learning rate of each parameter group by gamma every epoch.
  469. When last_epoch=-1, sets initial lr as lr.
  470. Args:
  471. optimizer (Optimizer): Wrapped optimizer.
  472. gamma (float): Multiplicative factor of learning rate decay.
  473. last_epoch (int): The index of last epoch. Default: -1.
  474. verbose (bool): If ``True``, prints a message to stdout for
  475. each update. Default: ``False``.
  476. """
  477. def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
  478. self.gamma = gamma
  479. super().__init__(optimizer, last_epoch, verbose)
  480. def get_lr(self):
  481. if not self._get_lr_called_within_step:
  482. warnings.warn("To get the last learning rate computed by the scheduler, "
  483. "please use `get_last_lr()`.", UserWarning)
  484. if self.last_epoch == 0:
  485. return [group['lr'] for group in self.optimizer.param_groups]
  486. return [group['lr'] * self.gamma
  487. for group in self.optimizer.param_groups]
  488. def _get_closed_form_lr(self):
  489. return [base_lr * self.gamma ** self.last_epoch
  490. for base_lr in self.base_lrs]
  491. class SequentialLR(LRScheduler):
  492. """Receives the list of schedulers that is expected to be called sequentially during
  493. optimization process and milestone points that provides exact intervals to reflect
  494. which scheduler is supposed to be called at a given epoch.
  495. Args:
  496. optimizer (Optimizer): Wrapped optimizer.
  497. schedulers (list): List of chained schedulers.
  498. milestones (list): List of integers that reflects milestone points.
  499. last_epoch (int): The index of last epoch. Default: -1.
  500. verbose (bool): Does nothing.
  501. Example:
  502. >>> # xdoctest: +SKIP
  503. >>> # Assuming optimizer uses lr = 1. for all groups
  504. >>> # lr = 0.1 if epoch == 0
  505. >>> # lr = 0.1 if epoch == 1
  506. >>> # lr = 0.9 if epoch == 2
  507. >>> # lr = 0.81 if epoch == 3
  508. >>> # lr = 0.729 if epoch == 4
  509. >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
  510. >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
  511. >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
  512. >>> for epoch in range(100):
  513. >>> train(...)
  514. >>> validate(...)
  515. >>> scheduler.step()
  516. """
  517. def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
  518. for scheduler_idx in range(len(schedulers)):
  519. if schedulers[scheduler_idx].optimizer != optimizer:
  520. raise ValueError(
  521. "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
  522. f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
  523. )
  524. if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
  525. raise ValueError(
  526. "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
  527. f"got schedulers at index {0} and {scheduler_idx} to be different."
  528. )
  529. if (len(milestones) != len(schedulers) - 1):
  530. raise ValueError(
  531. "Sequential Schedulers expects number of schedulers provided to be one more "
  532. "than the number of milestone points, but got number of schedulers {} and the "
  533. "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
  534. )
  535. self._schedulers = schedulers
  536. self._milestones = milestones
  537. self.last_epoch = last_epoch + 1
  538. self.optimizer = optimizer
  539. # Reset learning rates back to initial values
  540. for group in self.optimizer.param_groups:
  541. group["lr"] = group["initial_lr"]
  542. # "Undo" the step performed by other schedulers
  543. for scheduler in self._schedulers:
  544. scheduler.last_epoch -= 1
  545. # Perform the initial step for only the first scheduler
  546. self._schedulers[0]._initial_step()
  547. self._last_lr = schedulers[0].get_last_lr()
  548. def step(self):
  549. self.last_epoch += 1
  550. idx = bisect_right(self._milestones, self.last_epoch)
  551. scheduler = self._schedulers[idx]
  552. if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
  553. scheduler.step(0)
  554. else:
  555. scheduler.step()
  556. self._last_lr = scheduler.get_last_lr()
  557. def state_dict(self):
  558. """Returns the state of the scheduler as a :class:`dict`.
  559. It contains an entry for every variable in self.__dict__ which
  560. is not the optimizer.
  561. The wrapped scheduler states will also be saved.
  562. """
  563. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
  564. state_dict['_schedulers'] = [None] * len(self._schedulers)
  565. for idx, s in enumerate(self._schedulers):
  566. state_dict['_schedulers'][idx] = s.state_dict()
  567. return state_dict
  568. def load_state_dict(self, state_dict):
  569. """Loads the schedulers state.
  570. Args:
  571. state_dict (dict): scheduler state. Should be an object returned
  572. from a call to :meth:`state_dict`.
  573. """
  574. _schedulers = state_dict.pop('_schedulers')
  575. self.__dict__.update(state_dict)
  576. # Restore state_dict keys in order to prevent side effects
  577. # https://github.com/pytorch/pytorch/issues/32756
  578. state_dict['_schedulers'] = _schedulers
  579. for idx, s in enumerate(_schedulers):
  580. self._schedulers[idx].load_state_dict(s)
  581. class PolynomialLR(LRScheduler):
  582. """Decays the learning rate of each parameter group using a polynomial function
  583. in the given total_iters. When last_epoch=-1, sets initial lr as lr.
  584. Args:
  585. optimizer (Optimizer): Wrapped optimizer.
  586. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
  587. power (int): The power of the polynomial. Default: 1.0.
  588. verbose (bool): If ``True``, prints a message to stdout for
  589. each update. Default: ``False``.
  590. Example:
  591. >>> # xdoctest: +SKIP("undefined vars")
  592. >>> # Assuming optimizer uses lr = 0.001 for all groups
  593. >>> # lr = 0.001 if epoch == 0
  594. >>> # lr = 0.00075 if epoch == 1
  595. >>> # lr = 0.00050 if epoch == 2
  596. >>> # lr = 0.00025 if epoch == 3
  597. >>> # lr = 0.0 if epoch >= 4
  598. >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0)
  599. >>> for epoch in range(100):
  600. >>> train(...)
  601. >>> validate(...)
  602. >>> scheduler.step()
  603. """
  604. def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
  605. self.total_iters = total_iters
  606. self.power = power
  607. super().__init__(optimizer, last_epoch, verbose)
  608. def get_lr(self):
  609. if not self._get_lr_called_within_step:
  610. warnings.warn("To get the last learning rate computed by the scheduler, "
  611. "please use `get_last_lr()`.", UserWarning)
  612. if self.last_epoch == 0 or self.last_epoch > self.total_iters:
  613. return [group["lr"] for group in self.optimizer.param_groups]
  614. decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
  615. return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
  616. def _get_closed_form_lr(self):
  617. return [
  618. (
  619. base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
  620. )
  621. for base_lr in self.base_lrs
  622. ]
  623. class CosineAnnealingLR(LRScheduler):
  624. r"""Set the learning rate of each parameter group using a cosine annealing
  625. schedule, where :math:`\eta_{max}` is set to the initial lr and
  626. :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
  627. .. math::
  628. \begin{aligned}
  629. \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
  630. + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
  631. & T_{cur} \neq (2k+1)T_{max}; \\
  632. \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
  633. \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
  634. & T_{cur} = (2k+1)T_{max}.
  635. \end{aligned}
  636. When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
  637. is defined recursively, the learning rate can be simultaneously modified
  638. outside this scheduler by other operators. If the learning rate is set
  639. solely by this scheduler, the learning rate at each step becomes:
  640. .. math::
  641. \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
  642. \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
  643. It has been proposed in
  644. `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
  645. implements the cosine annealing part of SGDR, and not the restarts.
  646. Args:
  647. optimizer (Optimizer): Wrapped optimizer.
  648. T_max (int): Maximum number of iterations.
  649. eta_min (float): Minimum learning rate. Default: 0.
  650. last_epoch (int): The index of last epoch. Default: -1.
  651. verbose (bool): If ``True``, prints a message to stdout for
  652. each update. Default: ``False``.
  653. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
  654. https://arxiv.org/abs/1608.03983
  655. """
  656. def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
  657. self.T_max = T_max
  658. self.eta_min = eta_min
  659. super().__init__(optimizer, last_epoch, verbose)
  660. def get_lr(self):
  661. if not self._get_lr_called_within_step:
  662. warnings.warn("To get the last learning rate computed by the scheduler, "
  663. "please use `get_last_lr()`.", UserWarning)
  664. if self.last_epoch == 0:
  665. return [group['lr'] for group in self.optimizer.param_groups]
  666. elif self._step_count == 1 and self.last_epoch > 0:
  667. return [self.eta_min + (base_lr - self.eta_min) *
  668. (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
  669. for base_lr, group in
  670. zip(self.base_lrs, self.optimizer.param_groups)]
  671. elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
  672. return [group['lr'] + (base_lr - self.eta_min) *
  673. (1 - math.cos(math.pi / self.T_max)) / 2
  674. for base_lr, group in
  675. zip(self.base_lrs, self.optimizer.param_groups)]
  676. return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
  677. (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
  678. (group['lr'] - self.eta_min) + self.eta_min
  679. for group in self.optimizer.param_groups]
  680. def _get_closed_form_lr(self):
  681. return [self.eta_min + (base_lr - self.eta_min) *
  682. (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
  683. for base_lr in self.base_lrs]
  684. class ChainedScheduler(LRScheduler):
  685. """Chains list of learning rate schedulers. It takes a list of chainable learning
  686. rate schedulers and performs consecutive step() functions belonging to them by just
  687. one call.
  688. Args:
  689. schedulers (list): List of chained schedulers.
  690. Example:
  691. >>> # xdoctest: +SKIP
  692. >>> # Assuming optimizer uses lr = 1. for all groups
  693. >>> # lr = 0.09 if epoch == 0
  694. >>> # lr = 0.081 if epoch == 1
  695. >>> # lr = 0.729 if epoch == 2
  696. >>> # lr = 0.6561 if epoch == 3
  697. >>> # lr = 0.59049 if epoch >= 4
  698. >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
  699. >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
  700. >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
  701. >>> for epoch in range(100):
  702. >>> train(...)
  703. >>> validate(...)
  704. >>> scheduler.step()
  705. """
  706. def __init__(self, schedulers):
  707. for scheduler_idx in range(1, len(schedulers)):
  708. if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
  709. raise ValueError(
  710. "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
  711. "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
  712. )
  713. self._schedulers = list(schedulers)
  714. self.optimizer = schedulers[0].optimizer
  715. self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
  716. def step(self):
  717. for scheduler in self._schedulers:
  718. scheduler.step()
  719. self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
  720. def state_dict(self):
  721. """Returns the state of the scheduler as a :class:`dict`.
  722. It contains an entry for every variable in self.__dict__ which
  723. is not the optimizer.
  724. The wrapped scheduler states will also be saved.
  725. """
  726. state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
  727. state_dict['_schedulers'] = [None] * len(self._schedulers)
  728. for idx, s in enumerate(self._schedulers):
  729. state_dict['_schedulers'][idx] = s.state_dict()
  730. return state_dict
  731. def load_state_dict(self, state_dict):
  732. """Loads the schedulers state.
  733. Args:
  734. state_dict (dict): scheduler state. Should be an object returned
  735. from a call to :meth:`state_dict`.
  736. """
  737. _schedulers = state_dict.pop('_schedulers')
  738. self.__dict__.update(state_dict)
  739. # Restore state_dict keys in order to prevent side effects
  740. # https://github.com/pytorch/pytorch/issues/32756
  741. state_dict['_schedulers'] = _schedulers
  742. for idx, s in enumerate(_schedulers):
  743. self._schedulers[idx].load_state_dict(s)
  744. class ReduceLROnPlateau:
  745. """Reduce learning rate when a metric has stopped improving.
  746. Models often benefit from reducing the learning rate by a factor
  747. of 2-10 once learning stagnates. This scheduler reads a metrics
  748. quantity and if no improvement is seen for a 'patience' number
  749. of epochs, the learning rate is reduced.
  750. Args:
  751. optimizer (Optimizer): Wrapped optimizer.
  752. mode (str): One of `min`, `max`. In `min` mode, lr will
  753. be reduced when the quantity monitored has stopped
  754. decreasing; in `max` mode it will be reduced when the
  755. quantity monitored has stopped increasing. Default: 'min'.
  756. factor (float): Factor by which the learning rate will be
  757. reduced. new_lr = lr * factor. Default: 0.1.
  758. patience (int): Number of epochs with no improvement after
  759. which learning rate will be reduced. For example, if
  760. `patience = 2`, then we will ignore the first 2 epochs
  761. with no improvement, and will only decrease the LR after the
  762. 3rd epoch if the loss still hasn't improved then.
  763. Default: 10.
  764. threshold (float): Threshold for measuring the new optimum,
  765. to only focus on significant changes. Default: 1e-4.
  766. threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
  767. dynamic_threshold = best * ( 1 + threshold ) in 'max'
  768. mode or best * ( 1 - threshold ) in `min` mode.
  769. In `abs` mode, dynamic_threshold = best + threshold in
  770. `max` mode or best - threshold in `min` mode. Default: 'rel'.
  771. cooldown (int): Number of epochs to wait before resuming
  772. normal operation after lr has been reduced. Default: 0.
  773. min_lr (float or list): A scalar or a list of scalars. A
  774. lower bound on the learning rate of all param groups
  775. or each group respectively. Default: 0.
  776. eps (float): Minimal decay applied to lr. If the difference
  777. between new and old lr is smaller than eps, the update is
  778. ignored. Default: 1e-8.
  779. verbose (bool): If ``True``, prints a message to stdout for
  780. each update. Default: ``False``.
  781. Example:
  782. >>> # xdoctest: +SKIP
  783. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  784. >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
  785. >>> for epoch in range(10):
  786. >>> train(...)
  787. >>> val_loss = validate(...)
  788. >>> # Note that step should be called after validate()
  789. >>> scheduler.step(val_loss)
  790. """
  791. def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
  792. threshold=1e-4, threshold_mode='rel', cooldown=0,
  793. min_lr=0, eps=1e-8, verbose=False):
  794. if factor >= 1.0:
  795. raise ValueError('Factor should be < 1.0.')
  796. self.factor = factor
  797. # Attach optimizer
  798. if not isinstance(optimizer, Optimizer):
  799. raise TypeError('{} is not an Optimizer'.format(
  800. type(optimizer).__name__))
  801. self.optimizer = optimizer
  802. if isinstance(min_lr, (list, tuple)):
  803. if len(min_lr) != len(optimizer.param_groups):
  804. raise ValueError("expected {} min_lrs, got {}".format(
  805. len(optimizer.param_groups), len(min_lr)))
  806. self.min_lrs = list(min_lr)
  807. else:
  808. self.min_lrs = [min_lr] * len(optimizer.param_groups)
  809. self.patience = patience
  810. self.verbose = verbose
  811. self.cooldown = cooldown
  812. self.cooldown_counter = 0
  813. self.mode = mode
  814. self.threshold = threshold
  815. self.threshold_mode = threshold_mode
  816. self.best = None
  817. self.num_bad_epochs = None
  818. self.mode_worse = None # the worse value for the chosen mode
  819. self.eps = eps
  820. self.last_epoch = 0
  821. self._init_is_better(mode=mode, threshold=threshold,
  822. threshold_mode=threshold_mode)
  823. self._reset()
  824. def _reset(self):
  825. """Resets num_bad_epochs counter and cooldown counter."""
  826. self.best = self.mode_worse
  827. self.cooldown_counter = 0
  828. self.num_bad_epochs = 0
  829. def step(self, metrics, epoch=None):
  830. # convert `metrics` to float, in case it's a zero-dim Tensor
  831. current = float(metrics)
  832. if epoch is None:
  833. epoch = self.last_epoch + 1
  834. else:
  835. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  836. self.last_epoch = epoch
  837. if self.is_better(current, self.best):
  838. self.best = current
  839. self.num_bad_epochs = 0
  840. else:
  841. self.num_bad_epochs += 1
  842. if self.in_cooldown:
  843. self.cooldown_counter -= 1
  844. self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
  845. if self.num_bad_epochs > self.patience:
  846. self._reduce_lr(epoch)
  847. self.cooldown_counter = self.cooldown
  848. self.num_bad_epochs = 0
  849. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  850. def _reduce_lr(self, epoch):
  851. for i, param_group in enumerate(self.optimizer.param_groups):
  852. old_lr = float(param_group['lr'])
  853. new_lr = max(old_lr * self.factor, self.min_lrs[i])
  854. if old_lr - new_lr > self.eps:
  855. param_group['lr'] = new_lr
  856. if self.verbose:
  857. epoch_str = ("%.2f" if isinstance(epoch, float) else
  858. "%.5d") % epoch
  859. print('Epoch {}: reducing learning rate'
  860. ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
  861. @property
  862. def in_cooldown(self):
  863. return self.cooldown_counter > 0
  864. def is_better(self, a, best):
  865. if self.mode == 'min' and self.threshold_mode == 'rel':
  866. rel_epsilon = 1. - self.threshold
  867. return a < best * rel_epsilon
  868. elif self.mode == 'min' and self.threshold_mode == 'abs':
  869. return a < best - self.threshold
  870. elif self.mode == 'max' and self.threshold_mode == 'rel':
  871. rel_epsilon = self.threshold + 1.
  872. return a > best * rel_epsilon
  873. else: # mode == 'max' and epsilon_mode == 'abs':
  874. return a > best + self.threshold
  875. def _init_is_better(self, mode, threshold, threshold_mode):
  876. if mode not in {'min', 'max'}:
  877. raise ValueError('mode ' + mode + ' is unknown!')
  878. if threshold_mode not in {'rel', 'abs'}:
  879. raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
  880. if mode == 'min':
  881. self.mode_worse = inf
  882. else: # mode == 'max':
  883. self.mode_worse = -inf
  884. self.mode = mode
  885. self.threshold = threshold
  886. self.threshold_mode = threshold_mode
  887. def state_dict(self):
  888. return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
  889. def load_state_dict(self, state_dict):
  890. self.__dict__.update(state_dict)
  891. self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
  892. class CyclicLR(LRScheduler):
  893. r"""Sets the learning rate of each parameter group according to
  894. cyclical learning rate policy (CLR). The policy cycles the learning
  895. rate between two boundaries with a constant frequency, as detailed in
  896. the paper `Cyclical Learning Rates for Training Neural Networks`_.
  897. The distance between the two boundaries can be scaled on a per-iteration
  898. or per-cycle basis.
  899. Cyclical learning rate policy changes the learning rate after every batch.
  900. `step` should be called after a batch has been used for training.
  901. This class has three built-in policies, as put forth in the paper:
  902. * "triangular": A basic triangular cycle without amplitude scaling.
  903. * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
  904. * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
  905. at each cycle iteration.
  906. This implementation was adapted from the github repo: `bckenstler/CLR`_
  907. Args:
  908. optimizer (Optimizer): Wrapped optimizer.
  909. base_lr (float or list): Initial learning rate which is the
  910. lower boundary in the cycle for each parameter group.
  911. max_lr (float or list): Upper learning rate boundaries in the cycle
  912. for each parameter group. Functionally,
  913. it defines the cycle amplitude (max_lr - base_lr).
  914. The lr at any cycle is the sum of base_lr
  915. and some scaling of the amplitude; therefore
  916. max_lr may not actually be reached depending on
  917. scaling function.
  918. step_size_up (int): Number of training iterations in the
  919. increasing half of a cycle. Default: 2000
  920. step_size_down (int): Number of training iterations in the
  921. decreasing half of a cycle. If step_size_down is None,
  922. it is set to step_size_up. Default: None
  923. mode (str): One of {triangular, triangular2, exp_range}.
  924. Values correspond to policies detailed above.
  925. If scale_fn is not None, this argument is ignored.
  926. Default: 'triangular'
  927. gamma (float): Constant in 'exp_range' scaling function:
  928. gamma**(cycle iterations)
  929. Default: 1.0
  930. scale_fn (function): Custom scaling policy defined by a single
  931. argument lambda function, where
  932. 0 <= scale_fn(x) <= 1 for all x >= 0.
  933. If specified, then 'mode' is ignored.
  934. Default: None
  935. scale_mode (str): {'cycle', 'iterations'}.
  936. Defines whether scale_fn is evaluated on
  937. cycle number or cycle iterations (training
  938. iterations since start of cycle).
  939. Default: 'cycle'
  940. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  941. to learning rate between 'base_momentum' and 'max_momentum'.
  942. Default: True
  943. base_momentum (float or list): Lower momentum boundaries in the cycle
  944. for each parameter group. Note that momentum is cycled inversely
  945. to learning rate; at the peak of a cycle, momentum is
  946. 'base_momentum' and learning rate is 'max_lr'.
  947. Default: 0.8
  948. max_momentum (float or list): Upper momentum boundaries in the cycle
  949. for each parameter group. Functionally,
  950. it defines the cycle amplitude (max_momentum - base_momentum).
  951. The momentum at any cycle is the difference of max_momentum
  952. and some scaling of the amplitude; therefore
  953. base_momentum may not actually be reached depending on
  954. scaling function. Note that momentum is cycled inversely
  955. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  956. and learning rate is 'base_lr'
  957. Default: 0.9
  958. last_epoch (int): The index of the last batch. This parameter is used when
  959. resuming a training job. Since `step()` should be invoked after each
  960. batch instead of after each epoch, this number represents the total
  961. number of *batches* computed, not the total number of epochs computed.
  962. When last_epoch=-1, the schedule is started from the beginning.
  963. Default: -1
  964. verbose (bool): If ``True``, prints a message to stdout for
  965. each update. Default: ``False``.
  966. Example:
  967. >>> # xdoctest: +SKIP
  968. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  969. >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
  970. >>> data_loader = torch.utils.data.DataLoader(...)
  971. >>> for epoch in range(10):
  972. >>> for batch in data_loader:
  973. >>> train_batch(...)
  974. >>> scheduler.step()
  975. .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
  976. .. _bckenstler/CLR: https://github.com/bckenstler/CLR
  977. """
  978. def __init__(self,
  979. optimizer,
  980. base_lr,
  981. max_lr,
  982. step_size_up=2000,
  983. step_size_down=None,
  984. mode='triangular',
  985. gamma=1.,
  986. scale_fn=None,
  987. scale_mode='cycle',
  988. cycle_momentum=True,
  989. base_momentum=0.8,
  990. max_momentum=0.9,
  991. last_epoch=-1,
  992. verbose=False):
  993. # Attach optimizer
  994. if not isinstance(optimizer, Optimizer):
  995. raise TypeError('{} is not an Optimizer'.format(
  996. type(optimizer).__name__))
  997. self.optimizer = optimizer
  998. base_lrs = self._format_param('base_lr', optimizer, base_lr)
  999. if last_epoch == -1:
  1000. for lr, group in zip(base_lrs, optimizer.param_groups):
  1001. group['lr'] = lr
  1002. self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
  1003. step_size_up = float(step_size_up)
  1004. step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
  1005. self.total_size = step_size_up + step_size_down
  1006. self.step_ratio = step_size_up / self.total_size
  1007. if mode not in ['triangular', 'triangular2', 'exp_range'] \
  1008. and scale_fn is None:
  1009. raise ValueError('mode is invalid and scale_fn is None')
  1010. self.mode = mode
  1011. self.gamma = gamma
  1012. self._scale_fn_ref = None
  1013. self._scale_fn_custom = scale_fn
  1014. self.scale_mode = scale_mode
  1015. self._init_scale_fn()
  1016. self.cycle_momentum = cycle_momentum
  1017. if cycle_momentum:
  1018. if 'momentum' not in optimizer.defaults:
  1019. raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
  1020. base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
  1021. if last_epoch == -1:
  1022. for momentum, group in zip(base_momentums, optimizer.param_groups):
  1023. group['momentum'] = momentum
  1024. self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
  1025. self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
  1026. super().__init__(optimizer, last_epoch, verbose)
  1027. self.base_lrs = base_lrs
  1028. def _init_scale_fn(self):
  1029. if self._scale_fn_custom is not None:
  1030. return
  1031. if self.mode == 'triangular':
  1032. self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn)
  1033. self.scale_mode = 'cycle'
  1034. elif self.mode == 'triangular2':
  1035. self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn)
  1036. self.scale_mode = 'cycle'
  1037. elif self.mode == 'exp_range':
  1038. self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn)
  1039. self.scale_mode = 'iterations'
  1040. def _format_param(self, name, optimizer, param):
  1041. """Return correctly formatted lr/momentum for each param group."""
  1042. if isinstance(param, (list, tuple)):
  1043. if len(param) != len(optimizer.param_groups):
  1044. raise ValueError("expected {} values for {}, got {}".format(
  1045. len(optimizer.param_groups), name, len(param)))
  1046. return param
  1047. else:
  1048. return [param] * len(optimizer.param_groups)
  1049. def scale_fn(self, x):
  1050. if self._scale_fn_custom is not None:
  1051. return self._scale_fn_custom(x)
  1052. else:
  1053. return self._scale_fn_ref()(x)
  1054. def _triangular_scale_fn(self, x):
  1055. return 1.
  1056. def _triangular2_scale_fn(self, x):
  1057. return 1 / (2. ** (x - 1))
  1058. def _exp_range_scale_fn(self, x):
  1059. return self.gamma**(x)
  1060. def get_lr(self):
  1061. """Calculates the learning rate at batch index. This function treats
  1062. `self.last_epoch` as the last batch index.
  1063. If `self.cycle_momentum` is ``True``, this function has a side effect of
  1064. updating the optimizer's momentum.
  1065. """
  1066. if not self._get_lr_called_within_step:
  1067. warnings.warn("To get the last learning rate computed by the scheduler, "
  1068. "please use `get_last_lr()`.", UserWarning)
  1069. cycle = math.floor(1 + self.last_epoch / self.total_size)
  1070. x = 1. + self.last_epoch / self.total_size - cycle
  1071. if x <= self.step_ratio:
  1072. scale_factor = x / self.step_ratio
  1073. else:
  1074. scale_factor = (x - 1) / (self.step_ratio - 1)
  1075. lrs = []
  1076. for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
  1077. base_height = (max_lr - base_lr) * scale_factor
  1078. if self.scale_mode == 'cycle':
  1079. lr = base_lr + base_height * self.scale_fn(cycle)
  1080. else:
  1081. lr = base_lr + base_height * self.scale_fn(self.last_epoch)
  1082. lrs.append(lr)
  1083. if self.cycle_momentum:
  1084. momentums = []
  1085. for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
  1086. base_height = (max_momentum - base_momentum) * scale_factor
  1087. if self.scale_mode == 'cycle':
  1088. momentum = max_momentum - base_height * self.scale_fn(cycle)
  1089. else:
  1090. momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
  1091. momentums.append(momentum)
  1092. for param_group, momentum in zip(self.optimizer.param_groups, momentums):
  1093. param_group['momentum'] = momentum
  1094. return lrs
  1095. def state_dict(self):
  1096. state = super().state_dict()
  1097. # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled
  1098. state.pop("_scale_fn_ref")
  1099. return state
  1100. def load_state_dict(self, state_dict):
  1101. super().load_state_dict(state_dict)
  1102. self._init_scale_fn()
  1103. class CosineAnnealingWarmRestarts(LRScheduler):
  1104. r"""Set the learning rate of each parameter group using a cosine annealing
  1105. schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
  1106. is the number of epochs since the last restart and :math:`T_{i}` is the number
  1107. of epochs between two warm restarts in SGDR:
  1108. .. math::
  1109. \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
  1110. \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
  1111. When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
  1112. When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
  1113. It has been proposed in
  1114. `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
  1115. Args:
  1116. optimizer (Optimizer): Wrapped optimizer.
  1117. T_0 (int): Number of iterations for the first restart.
  1118. T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
  1119. eta_min (float, optional): Minimum learning rate. Default: 0.
  1120. last_epoch (int, optional): The index of last epoch. Default: -1.
  1121. verbose (bool): If ``True``, prints a message to stdout for
  1122. each update. Default: ``False``.
  1123. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
  1124. https://arxiv.org/abs/1608.03983
  1125. """
  1126. def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
  1127. if T_0 <= 0 or not isinstance(T_0, int):
  1128. raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
  1129. if T_mult < 1 or not isinstance(T_mult, int):
  1130. raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
  1131. self.T_0 = T_0
  1132. self.T_i = T_0
  1133. self.T_mult = T_mult
  1134. self.eta_min = eta_min
  1135. self.T_cur = last_epoch
  1136. super().__init__(optimizer, last_epoch, verbose)
  1137. def get_lr(self):
  1138. if not self._get_lr_called_within_step:
  1139. warnings.warn("To get the last learning rate computed by the scheduler, "
  1140. "please use `get_last_lr()`.", UserWarning)
  1141. return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
  1142. for base_lr in self.base_lrs]
  1143. def step(self, epoch=None):
  1144. """Step could be called after every batch update
  1145. Example:
  1146. >>> # xdoctest: +SKIP("Undefined vars")
  1147. >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
  1148. >>> iters = len(dataloader)
  1149. >>> for epoch in range(20):
  1150. >>> for i, sample in enumerate(dataloader):
  1151. >>> inputs, labels = sample['inputs'], sample['labels']
  1152. >>> optimizer.zero_grad()
  1153. >>> outputs = net(inputs)
  1154. >>> loss = criterion(outputs, labels)
  1155. >>> loss.backward()
  1156. >>> optimizer.step()
  1157. >>> scheduler.step(epoch + i / iters)
  1158. This function can be called in an interleaved way.
  1159. Example:
  1160. >>> # xdoctest: +SKIP("Undefined vars")
  1161. >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
  1162. >>> for epoch in range(20):
  1163. >>> scheduler.step()
  1164. >>> scheduler.step(26)
  1165. >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
  1166. """
  1167. if epoch is None and self.last_epoch < 0:
  1168. epoch = 0
  1169. if epoch is None:
  1170. epoch = self.last_epoch + 1
  1171. self.T_cur = self.T_cur + 1
  1172. if self.T_cur >= self.T_i:
  1173. self.T_cur = self.T_cur - self.T_i
  1174. self.T_i = self.T_i * self.T_mult
  1175. else:
  1176. if epoch < 0:
  1177. raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
  1178. if epoch >= self.T_0:
  1179. if self.T_mult == 1:
  1180. self.T_cur = epoch % self.T_0
  1181. else:
  1182. n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
  1183. self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
  1184. self.T_i = self.T_0 * self.T_mult ** (n)
  1185. else:
  1186. self.T_i = self.T_0
  1187. self.T_cur = epoch
  1188. self.last_epoch = math.floor(epoch)
  1189. class _enable_get_lr_call:
  1190. def __init__(self, o):
  1191. self.o = o
  1192. def __enter__(self):
  1193. self.o._get_lr_called_within_step = True
  1194. return self
  1195. def __exit__(self, type, value, traceback):
  1196. self.o._get_lr_called_within_step = False
  1197. return self
  1198. with _enable_get_lr_call(self):
  1199. for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
  1200. param_group, lr = data
  1201. param_group['lr'] = lr
  1202. self.print_lr(self.verbose, i, lr, epoch)
  1203. self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
  1204. class OneCycleLR(LRScheduler):
  1205. r"""Sets the learning rate of each parameter group according to the
  1206. 1cycle learning rate policy. The 1cycle policy anneals the learning
  1207. rate from an initial learning rate to some maximum learning rate and then
  1208. from that maximum learning rate to some minimum learning rate much lower
  1209. than the initial learning rate.
  1210. This policy was initially described in the paper `Super-Convergence:
  1211. Very Fast Training of Neural Networks Using Large Learning Rates`_.
  1212. The 1cycle learning rate policy changes the learning rate after every batch.
  1213. `step` should be called after a batch has been used for training.
  1214. This scheduler is not chainable.
  1215. Note also that the total number of steps in the cycle can be determined in one
  1216. of two ways (listed in order of precedence):
  1217. #. A value for total_steps is explicitly provided.
  1218. #. A number of epochs (epochs) and a number of steps per epoch
  1219. (steps_per_epoch) are provided.
  1220. In this case, the number of total steps is inferred by
  1221. total_steps = epochs * steps_per_epoch
  1222. You must either provide a value for total_steps or provide a value for both
  1223. epochs and steps_per_epoch.
  1224. The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
  1225. claims that "unpublished work has shown even better results by using only two phases". To
  1226. mimic the behaviour of the original paper instead, set ``three_phase=True``.
  1227. Args:
  1228. optimizer (Optimizer): Wrapped optimizer.
  1229. max_lr (float or list): Upper learning rate boundaries in the cycle
  1230. for each parameter group.
  1231. total_steps (int): The total number of steps in the cycle. Note that
  1232. if a value is not provided here, then it must be inferred by providing
  1233. a value for epochs and steps_per_epoch.
  1234. Default: None
  1235. epochs (int): The number of epochs to train for. This is used along
  1236. with steps_per_epoch in order to infer the total number of steps in the cycle
  1237. if a value for total_steps is not provided.
  1238. Default: None
  1239. steps_per_epoch (int): The number of steps per epoch to train for. This is
  1240. used along with epochs in order to infer the total number of steps in the
  1241. cycle if a value for total_steps is not provided.
  1242. Default: None
  1243. pct_start (float): The percentage of the cycle (in number of steps) spent
  1244. increasing the learning rate.
  1245. Default: 0.3
  1246. anneal_strategy (str): {'cos', 'linear'}
  1247. Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
  1248. linear annealing.
  1249. Default: 'cos'
  1250. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  1251. to learning rate between 'base_momentum' and 'max_momentum'.
  1252. Default: True
  1253. base_momentum (float or list): Lower momentum boundaries in the cycle
  1254. for each parameter group. Note that momentum is cycled inversely
  1255. to learning rate; at the peak of a cycle, momentum is
  1256. 'base_momentum' and learning rate is 'max_lr'.
  1257. Default: 0.85
  1258. max_momentum (float or list): Upper momentum boundaries in the cycle
  1259. for each parameter group. Functionally,
  1260. it defines the cycle amplitude (max_momentum - base_momentum).
  1261. Note that momentum is cycled inversely
  1262. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  1263. and learning rate is 'base_lr'
  1264. Default: 0.95
  1265. div_factor (float): Determines the initial learning rate via
  1266. initial_lr = max_lr/div_factor
  1267. Default: 25
  1268. final_div_factor (float): Determines the minimum learning rate via
  1269. min_lr = initial_lr/final_div_factor
  1270. Default: 1e4
  1271. three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
  1272. learning rate according to 'final_div_factor' instead of modifying the second
  1273. phase (the first two phases will be symmetrical about the step indicated by
  1274. 'pct_start').
  1275. last_epoch (int): The index of the last batch. This parameter is used when
  1276. resuming a training job. Since `step()` should be invoked after each
  1277. batch instead of after each epoch, this number represents the total
  1278. number of *batches* computed, not the total number of epochs computed.
  1279. When last_epoch=-1, the schedule is started from the beginning.
  1280. Default: -1
  1281. verbose (bool): If ``True``, prints a message to stdout for
  1282. each update. Default: ``False``.
  1283. Example:
  1284. >>> # xdoctest: +SKIP
  1285. >>> data_loader = torch.utils.data.DataLoader(...)
  1286. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1287. >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
  1288. >>> for epoch in range(10):
  1289. >>> for batch in data_loader:
  1290. >>> train_batch(...)
  1291. >>> scheduler.step()
  1292. .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
  1293. https://arxiv.org/abs/1708.07120
  1294. """
  1295. def __init__(self,
  1296. optimizer,
  1297. max_lr,
  1298. total_steps=None,
  1299. epochs=None,
  1300. steps_per_epoch=None,
  1301. pct_start=0.3,
  1302. anneal_strategy='cos',
  1303. cycle_momentum=True,
  1304. base_momentum=0.85,
  1305. max_momentum=0.95,
  1306. div_factor=25.,
  1307. final_div_factor=1e4,
  1308. three_phase=False,
  1309. last_epoch=-1,
  1310. verbose=False):
  1311. # Validate optimizer
  1312. if not isinstance(optimizer, Optimizer):
  1313. raise TypeError('{} is not an Optimizer'.format(
  1314. type(optimizer).__name__))
  1315. self.optimizer = optimizer
  1316. # Validate total_steps
  1317. if total_steps is None and epochs is None and steps_per_epoch is None:
  1318. raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
  1319. elif total_steps is not None:
  1320. if total_steps <= 0 or not isinstance(total_steps, int):
  1321. raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
  1322. self.total_steps = total_steps
  1323. else:
  1324. if epochs <= 0 or not isinstance(epochs, int):
  1325. raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
  1326. if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
  1327. raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
  1328. self.total_steps = epochs * steps_per_epoch
  1329. if three_phase:
  1330. self._schedule_phases = [
  1331. {
  1332. 'end_step': float(pct_start * self.total_steps) - 1,
  1333. 'start_lr': 'initial_lr',
  1334. 'end_lr': 'max_lr',
  1335. 'start_momentum': 'max_momentum',
  1336. 'end_momentum': 'base_momentum',
  1337. },
  1338. {
  1339. 'end_step': float(2 * pct_start * self.total_steps) - 2,
  1340. 'start_lr': 'max_lr',
  1341. 'end_lr': 'initial_lr',
  1342. 'start_momentum': 'base_momentum',
  1343. 'end_momentum': 'max_momentum',
  1344. },
  1345. {
  1346. 'end_step': self.total_steps - 1,
  1347. 'start_lr': 'initial_lr',
  1348. 'end_lr': 'min_lr',
  1349. 'start_momentum': 'max_momentum',
  1350. 'end_momentum': 'max_momentum',
  1351. },
  1352. ]
  1353. else:
  1354. self._schedule_phases = [
  1355. {
  1356. 'end_step': float(pct_start * self.total_steps) - 1,
  1357. 'start_lr': 'initial_lr',
  1358. 'end_lr': 'max_lr',
  1359. 'start_momentum': 'max_momentum',
  1360. 'end_momentum': 'base_momentum',
  1361. },
  1362. {
  1363. 'end_step': self.total_steps - 1,
  1364. 'start_lr': 'max_lr',
  1365. 'end_lr': 'min_lr',
  1366. 'start_momentum': 'base_momentum',
  1367. 'end_momentum': 'max_momentum',
  1368. },
  1369. ]
  1370. # Validate pct_start
  1371. if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
  1372. raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
  1373. # Validate anneal_strategy
  1374. if anneal_strategy not in ['cos', 'linear']:
  1375. raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
  1376. elif anneal_strategy == 'cos':
  1377. self.anneal_func = self._annealing_cos
  1378. elif anneal_strategy == 'linear':
  1379. self.anneal_func = self._annealing_linear
  1380. # Initialize learning rate variables
  1381. max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
  1382. if last_epoch == -1:
  1383. for idx, group in enumerate(self.optimizer.param_groups):
  1384. group['initial_lr'] = max_lrs[idx] / div_factor
  1385. group['max_lr'] = max_lrs[idx]
  1386. group['min_lr'] = group['initial_lr'] / final_div_factor
  1387. # Initialize momentum variables
  1388. self.cycle_momentum = cycle_momentum
  1389. if self.cycle_momentum:
  1390. if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
  1391. raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
  1392. self.use_beta1 = 'betas' in self.optimizer.defaults
  1393. max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
  1394. base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
  1395. if last_epoch == -1:
  1396. for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
  1397. if self.use_beta1:
  1398. group['betas'] = (m_momentum, *group['betas'][1:])
  1399. else:
  1400. group['momentum'] = m_momentum
  1401. group['max_momentum'] = m_momentum
  1402. group['base_momentum'] = b_momentum
  1403. super().__init__(optimizer, last_epoch, verbose)
  1404. def _format_param(self, name, optimizer, param):
  1405. """Return correctly formatted lr/momentum for each param group."""
  1406. if isinstance(param, (list, tuple)):
  1407. if len(param) != len(optimizer.param_groups):
  1408. raise ValueError("expected {} values for {}, got {}".format(
  1409. len(optimizer.param_groups), name, len(param)))
  1410. return param
  1411. else:
  1412. return [param] * len(optimizer.param_groups)
  1413. def _annealing_cos(self, start, end, pct):
  1414. "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  1415. cos_out = math.cos(math.pi * pct) + 1
  1416. return end + (start - end) / 2.0 * cos_out
  1417. def _annealing_linear(self, start, end, pct):
  1418. "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  1419. return (end - start) * pct + start
  1420. def get_lr(self):
  1421. if not self._get_lr_called_within_step:
  1422. warnings.warn("To get the last learning rate computed by the scheduler, "
  1423. "please use `get_last_lr()`.", UserWarning)
  1424. lrs = []
  1425. step_num = self.last_epoch
  1426. if step_num > self.total_steps:
  1427. raise ValueError("Tried to step {} times. The specified number of total steps is {}"
  1428. .format(step_num, self.total_steps))
  1429. for group in self.optimizer.param_groups:
  1430. start_step = 0
  1431. for i, phase in enumerate(self._schedule_phases):
  1432. end_step = phase['end_step']
  1433. if step_num <= end_step or i == len(self._schedule_phases) - 1:
  1434. pct = (step_num - start_step) / (end_step - start_step)
  1435. computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
  1436. if self.cycle_momentum:
  1437. computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
  1438. break
  1439. start_step = phase['end_step']
  1440. lrs.append(computed_lr)
  1441. if self.cycle_momentum:
  1442. if self.use_beta1:
  1443. group['betas'] = (computed_momentum, *group['betas'][1:])
  1444. else:
  1445. group['momentum'] = computed_momentum
  1446. return lrs