- import types
- import math
- from torch import inf
- from functools import wraps
- import warnings
- import weakref
- from collections import Counter
- from bisect import bisect_right
- from .optimizer import Optimizer
- __all__ = ['LambdaLR', 'MultiplicativeLR', 'StepLR', 'MultiStepLR', 'ConstantLR', 'LinearLR',
- 'ExponentialLR', 'SequentialLR', 'CosineAnnealingLR', 'ChainedScheduler', 'ReduceLROnPlateau',
- 'CyclicLR', 'CosineAnnealingWarmRestarts', 'OneCycleLR', 'PolynomialLR', 'LRScheduler']
- EPOCH_DEPRECATION_WARNING = (
- "The epoch parameter in `scheduler.step()` was not necessary and is being "
- "deprecated where possible. Please use `scheduler.step()` to step the "
- "scheduler. During the deprecation, if epoch is different from None, the "
- "closed form is used instead of the new chainable form, where available. "
- "Please open an issue if you are unable to replicate your use case: "
- "https://github.com/pytorch/pytorch/issues/new/choose."
- )
- class LRScheduler:
- def __init__(self, optimizer, last_epoch=-1, verbose=False):
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
- # Initialize epoch and base learning rates
- if last_epoch == -1:
- for group in optimizer.param_groups:
- group.setdefault('initial_lr', group['lr'])
- else:
- for i, group in enumerate(optimizer.param_groups):
- if 'initial_lr' not in group:
- raise KeyError("param 'initial_lr' is not specified "
- "in param_groups[{}] when resuming an optimizer".format(i))
- self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
- self.last_epoch = last_epoch
- # Following https://github.com/pytorch/pytorch/issues/20124
- # We would like to ensure that `lr_scheduler.step()` is called after
- # `optimizer.step()`
- def with_counter(method):
- if getattr(method, '_with_counter', False):
- # `optimizer.step()` has already been replaced, return.
- return method
- # Keep a weak reference to the optimizer instance to prevent
- # cyclic references.
- instance_ref = weakref.ref(method.__self__)
- # Get the unbound method for the same purpose.
- func = method.__func__
- cls = instance_ref().__class__
- del method
- @wraps(func)
- def wrapper(*args, **kwargs):
- instance = instance_ref()
- instance._step_count += 1
- wrapped = func.__get__(instance, cls)
- return wrapped(*args, **kwargs)
- # Note that the returned function here is no longer a bound method,
- # so attributes like `__func__` and `__self__` no longer exist.
- wrapper._with_counter = True
- return wrapper
- self.optimizer.step = with_counter(self.optimizer.step)
- self.verbose = verbose
- self._initial_step()
- def _initial_step(self):
- """Initialize step counts and performs a step"""
- self.optimizer._step_count = 0
- self._step_count = 0
- self.step()
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- """
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.__dict__.update(state_dict)
- def get_last_lr(self):
- """ Return last computed learning rate by current scheduler.
- """
- return self._last_lr
- def get_lr(self):
- # Compute learning rate using chainable form of the scheduler
- raise NotImplementedError
- def print_lr(self, is_verbose, group, lr, epoch=None):
- """Display the current learning rate.
- """
- if is_verbose:
- if epoch is None:
- print('Adjusting learning rate'
- ' of group {} to {:.4e}.'.format(group, lr))
- else:
- epoch_str = ("%.2f" if isinstance(epoch, float) else
- "%.5d") % epoch
- print('Epoch {}: adjusting learning rate'
- ' of group {} to {:.4e}.'.format(epoch_str, group, lr))
- def step(self, epoch=None):
- # Raise a warning if old pattern is detected
- # https://github.com/pytorch/pytorch/issues/20124
- if self._step_count == 1:
- if not hasattr(self.optimizer.step, "_with_counter"):
- warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
- "initialization. Please, make sure to call `optimizer.step()` before "
- "`lr_scheduler.step()`. See more details at "
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
- # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
- elif self.optimizer._step_count < 1:
- warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
- "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
- "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
- "will result in PyTorch skipping the first value of the learning rate schedule. "
- "See more details at "
- "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
- self._step_count += 1
- with _enable_get_lr_call(self):
- if epoch is None:
- self.last_epoch += 1
- values = self.get_lr()
- else:
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
- self.last_epoch = epoch
- if hasattr(self, "_get_closed_form_lr"):
- values = self._get_closed_form_lr()
- else:
- values = self.get_lr()
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
- param_group, lr = data
- param_group['lr'] = lr
- self.print_lr(self.verbose, i, lr, epoch)
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
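- # A minimal usage sketch of the call order that the warnings above enforce
- # (`model`, `loss_fn`, and `data` are hypothetical placeholders, and StepLR is
- # used only as an example scheduler):
- #
- #     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
- #     scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
- #     for epoch in range(100):
- #         for input, target in data:
- #             optimizer.zero_grad()
- #             loss_fn(model(input), target).backward()
- #             optimizer.step()   # optimizer.step() first ...
- #         scheduler.step()       # ... then lr_scheduler.step()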
- # Including _LRScheduler for backwards compatibility
- # Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
- class _LRScheduler(LRScheduler):
- pass
- class _enable_get_lr_call:
- def __init__(self, o):
- self.o = o
- def __enter__(self):
- self.o._get_lr_called_within_step = True
- return self
- def __exit__(self, type, value, traceback):
- self.o._get_lr_called_within_step = False
- class LambdaLR(LRScheduler):
- """Sets the learning rate of each parameter group to the initial lr
- times a given function. When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- lr_lambda (function or list): A function which computes a multiplicative
- factor given an integer parameter epoch, or a list of such
- functions, one for each group in optimizer.param_groups.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer has two groups.
- >>> lambda1 = lambda epoch: epoch // 30
- >>> lambda2 = lambda epoch: 0.95 ** epoch
- >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
- self.optimizer = optimizer
- if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
- self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
- else:
- if len(lr_lambda) != len(optimizer.param_groups):
- raise ValueError("Expected {} lr_lambdas, but got {}".format(
- len(optimizer.param_groups), len(lr_lambda)))
- self.lr_lambdas = list(lr_lambda)
- super().__init__(optimizer, last_epoch, verbose)
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The learning rate lambda functions will only be saved if they are callable objects,
- not if they are plain functions or lambdas.
- When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
- state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
- for idx, fn in enumerate(self.lr_lambdas):
- if not isinstance(fn, types.FunctionType):
- state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
- return state_dict
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
- When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- lr_lambdas = state_dict.pop('lr_lambdas')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['lr_lambdas'] = lr_lambdas
- for idx, fn in enumerate(lr_lambdas):
- if fn is not None:
- self.lr_lambdas[idx].__dict__.update(fn)
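- # A minimal sketch of the behaviour documented above: only callable objects
- # (not plain functions or lambdas) have their state captured by state_dict().
- # The `DecayLambda` class and its `decay` attribute are hypothetical:
- #
- #     class DecayLambda:
- #         def __init__(self, decay):
- #             self.decay = decay
- #         def __call__(self, epoch):
- #             return self.decay ** epoch
- #
- #     scheduler = LambdaLR(optimizer, lr_lambda=DecayLambda(0.95))
- #     state = scheduler.state_dict()     # contains DecayLambda(0.95).__dict__
- #     scheduler.load_state_dict(state)   # restores the `decay` attribute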
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.")
- return [base_lr * lmbda(self.last_epoch)
- for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
- class MultiplicativeLR(LRScheduler):
- """Multiply the learning rate of each parameter group by the factor given
- in the specified function. When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- lr_lambda (function or list): A function which computes a multiplicative
- factor given an integer parameter epoch, or a list of such
- functions, one for each group in optimizer.param_groups.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> lmbda = lambda epoch: 0.95
- >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
- self.optimizer = optimizer
- if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
- self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
- else:
- if len(lr_lambda) != len(optimizer.param_groups):
- raise ValueError("Expected {} lr_lambdas, but got {}".format(
- len(optimizer.param_groups), len(lr_lambda)))
- self.lr_lambdas = list(lr_lambda)
- super().__init__(optimizer, last_epoch, verbose)
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The learning rate lambda functions will only be saved if they are callable objects,
- not if they are plain functions or lambdas.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
- state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
- for idx, fn in enumerate(self.lr_lambdas):
- if not isinstance(fn, types.FunctionType):
- state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
- return state_dict
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- lr_lambdas = state_dict.pop('lr_lambdas')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['lr_lambdas'] = lr_lambdas
- for idx, fn in enumerate(lr_lambdas):
- if fn is not None:
- self.lr_lambdas[idx].__dict__.update(fn)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch > 0:
- return [group['lr'] * lmbda(self.last_epoch)
- for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)]
- else:
- return [group['lr'] for group in self.optimizer.param_groups]
- class StepLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma every
- step_size epochs. Notice that such decay can happen simultaneously with
- other changes to the learning rate from outside this scheduler. When
- last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- step_size (int): Period of learning rate decay.
- gamma (float): Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.05 if epoch < 30
- >>> # lr = 0.005 if 30 <= epoch < 60
- >>> # lr = 0.0005 if 60 <= epoch < 90
- >>> # ...
- >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
- self.step_size = step_size
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma
- for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
- for base_lr in self.base_lrs]
- class MultiStepLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma once the
- number of epochs reaches one of the milestones. Notice that such decay can
- happen simultaneously with other changes to the learning rate from outside
- this scheduler. When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- milestones (list): List of epoch indices. Must be increasing.
- gamma (float): Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.05 if epoch < 30
- >>> # lr = 0.005 if 30 <= epoch < 80
- >>> # lr = 0.0005 if epoch >= 80
- >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
- self.milestones = Counter(milestones)
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch not in self.milestones:
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
- for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- milestones = sorted(self.milestones.elements())
- return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
- for base_lr in self.base_lrs]
- class ConstantLR(LRScheduler):
- """Decays the learning rate of each parameter group by a small constant factor until the
- number of epochs reaches a pre-defined milestone: total_iters. Notice that such decay can
- happen simultaneously with other changes to the learning rate from outside this scheduler.
- When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- factor (float): The factor by which the learning rate is multiplied until the milestone (total_iters). Default: 1./3.
- total_iters (int): The number of steps over which the scheduler decays the learning rate.
- Default: 5.
- last_epoch (int): The index of the last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.025 if epoch == 0
- >>> # lr = 0.025 if epoch == 1
- >>> # lr = 0.025 if epoch == 2
- >>> # lr = 0.025 if epoch == 3
- >>> # lr = 0.05 if epoch >= 4
- >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
- if factor > 1.0 or factor < 0:
- raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
- self.factor = factor
- self.total_iters = total_iters
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch == 0:
- return [group['lr'] * self.factor for group in self.optimizer.param_groups]
- if self.last_epoch != self.total_iters:
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
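- # In the closed form below, the boolean (self.last_epoch >= self.total_iters)
- # evaluates to 0 or 1, so the multiplier is `factor` before total_iters steps
- # and 1.0 from total_iters onwards.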
- def _get_closed_form_lr(self):
- return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
- for base_lr in self.base_lrs]
- class LinearLR(LRScheduler):
- """Decays the learning rate of each parameter group by linearly changing small
- multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
- Notice that such decay can happen simultaneously with other changes to the learning rate
- from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- start_factor (float): The factor by which the learning rate is multiplied in the first epoch.
- The multiplication factor changes towards end_factor in the following epochs.
- Default: 1./3.
- end_factor (float): The factor by which the learning rate is multiplied at the end of the
- linear changing process. Default: 1.0.
- total_iters (int): The number of iterations over which the multiplicative factor changes
- from start_factor to end_factor. Default: 5.
- last_epoch (int): The index of the last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.025 if epoch == 0
- >>> # lr = 0.03125 if epoch == 1
- >>> # lr = 0.0375 if epoch == 2
- >>> # lr = 0.04375 if epoch == 3
- >>> # lr = 0.05 if epoch >= 4
- >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
- verbose=False):
- if start_factor > 1.0 or start_factor <= 0:
- raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
- if end_factor > 1.0 or end_factor < 0:
- raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
- self.start_factor = start_factor
- self.end_factor = end_factor
- self.total_iters = total_iters
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch == 0:
- return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
- if self.last_epoch > self.total_iters:
- return [group['lr'] for group in self.optimizer.param_groups]
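- # Chainable form: multiply the current lr by the ratio of this step's linear
- # interpolation factor to the previous step's factor, i.e.
- # (start_factor + last_epoch * d) / (start_factor + (last_epoch - 1) * d)
- # with d = (end_factor - start_factor) / total_iters.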
- return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
- (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
- for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- return [base_lr * (self.start_factor +
- (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
- for base_lr in self.base_lrs]
- class ExponentialLR(LRScheduler):
- """Decays the learning rate of each parameter group by gamma every epoch.
- When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- gamma (float): Multiplicative factor of learning rate decay.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
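- Example:
- >>> # xdoctest: +SKIP
- >>> # A minimal usage sketch, assuming optimizer uses lr = 0.05 for all groups
- >>> # lr = 0.05 if epoch == 0
- >>> # lr = 0.045 if epoch == 1
- >>> # lr = 0.0405 if epoch == 2
- >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()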
- """
- def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
- self.gamma = gamma
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch == 0:
- return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma
- for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- return [base_lr * self.gamma ** self.last_epoch
- for base_lr in self.base_lrs]
- class SequentialLR(LRScheduler):
- """Receives the list of schedulers that is expected to be called sequentially during
- optimization process and milestone points that provides exact intervals to reflect
- which scheduler is supposed to be called at a given epoch.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- schedulers (list): List of chained schedulers.
- milestones (list): List of integers that reflects milestone points.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): Does nothing.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 1. for all groups
- >>> # lr = 0.1 if epoch == 0
- >>> # lr = 0.1 if epoch == 1
- >>> # lr = 0.9 if epoch == 2
- >>> # lr = 0.81 if epoch == 3
- >>> # lr = 0.729 if epoch == 4
- >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
- >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
- >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
- for scheduler_idx in range(len(schedulers)):
- if schedulers[scheduler_idx].optimizer != optimizer:
- raise ValueError(
- "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
- f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in."
- )
- if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
- raise ValueError(
- "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
- f"got schedulers at index {0} and {scheduler_idx} to be different."
- )
- if (len(milestones) != len(schedulers) - 1):
- raise ValueError(
- "Sequential Schedulers expects number of schedulers provided to be one more "
- "than the number of milestone points, but got number of schedulers {} and the "
- "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
- )
- self._schedulers = schedulers
- self._milestones = milestones
- self.last_epoch = last_epoch + 1
- self.optimizer = optimizer
- # Reset learning rates back to initial values
- for group in self.optimizer.param_groups:
- group["lr"] = group["initial_lr"]
- # "Undo" the step performed by other schedulers
- for scheduler in self._schedulers:
- scheduler.last_epoch -= 1
- # Perform the initial step for only the first scheduler
- self._schedulers[0]._initial_step()
- self._last_lr = schedulers[0].get_last_lr()
- def step(self):
- self.last_epoch += 1
- idx = bisect_right(self._milestones, self.last_epoch)
- scheduler = self._schedulers[idx]
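- # When we have just crossed a milestone, restart the newly selected scheduler
- # from epoch 0; otherwise simply advance the currently active scheduler.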
- if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
- scheduler.step(0)
- else:
- scheduler.step()
- self._last_lr = scheduler.get_last_lr()
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The wrapped scheduler states will also be saved.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
- state_dict['_schedulers'] = [None] * len(self._schedulers)
- for idx, s in enumerate(self._schedulers):
- state_dict['_schedulers'][idx] = s.state_dict()
- return state_dict
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- _schedulers = state_dict.pop('_schedulers')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['_schedulers'] = _schedulers
- for idx, s in enumerate(_schedulers):
- self._schedulers[idx].load_state_dict(s)
- class PolynomialLR(LRScheduler):
- """Decays the learning rate of each parameter group using a polynomial function
- over the given total_iters. When last_epoch=-1, sets initial lr as lr.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- total_iters (int): The number of steps over which the scheduler decays the learning rate. Default: 5.
- power (float): The power of the polynomial. Default: 1.0.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP("undefined vars")
- >>> # Assuming optimizer uses lr = 0.001 for all groups
- >>> # lr = 0.001 if epoch == 0
- >>> # lr = 0.00075 if epoch == 1
- >>> # lr = 0.00050 if epoch == 2
- >>> # lr = 0.00025 if epoch == 3
- >>> # lr = 0.0 if epoch >= 4
- >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
- self.total_iters = total_iters
- self.power = power
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch == 0 or self.last_epoch > self.total_iters:
- return [group["lr"] for group in self.optimizer.param_groups]
- decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power
- return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- return [
- (
- base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
- )
- for base_lr in self.base_lrs
- ]
- class CosineAnnealingLR(LRScheduler):
- r"""Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr and
- :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
- .. math::
- \begin{aligned}
- \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
- + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
- & T_{cur} \neq (2k+1)T_{max}; \\
- \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
- \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
- & T_{cur} = (2k+1)T_{max}.
- \end{aligned}
- When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
- is defined recursively, the learning rate can be simultaneously modified
- outside this scheduler by other operators. If the learning rate is set
- solely by this scheduler, the learning rate at each step becomes:
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
- implements the cosine annealing part of SGDR, and not the restarts.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- T_max (int): Maximum number of iterations.
- eta_min (float): Minimum learning rate. Default: 0.
- last_epoch (int): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
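- Example:
- >>> # xdoctest: +SKIP
- >>> # A minimal usage sketch; T_max is typically set to the total number of
- >>> # epochs (or iterations) over which the lr should anneal from the initial
- >>> # lr down to eta_min.
- >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()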
- """
- def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
- self.T_max = T_max
- self.eta_min = eta_min
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- if self.last_epoch == 0:
- return [group['lr'] for group in self.optimizer.param_groups]
- elif self._step_count == 1 and self.last_epoch > 0:
- return [self.eta_min + (base_lr - self.eta_min) *
- (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
- for base_lr, group in
- zip(self.base_lrs, self.optimizer.param_groups)]
- elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
- return [group['lr'] + (base_lr - self.eta_min) *
- (1 - math.cos(math.pi / self.T_max)) / 2
- for base_lr, group in
- zip(self.base_lrs, self.optimizer.param_groups)]
- return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
- (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
- (group['lr'] - self.eta_min) + self.eta_min
- for group in self.optimizer.param_groups]
- def _get_closed_form_lr(self):
- return [self.eta_min + (base_lr - self.eta_min) *
- (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
- for base_lr in self.base_lrs]
- class ChainedScheduler(LRScheduler):
- """Chains list of learning rate schedulers. It takes a list of chainable learning
- rate schedulers and performs consecutive step() functions belonging to them by just
- one call.
- Args:
- schedulers (list): List of chained schedulers.
- Example:
- >>> # xdoctest: +SKIP
- >>> # Assuming optimizer uses lr = 1. for all groups
- >>> # lr = 0.09 if epoch == 0
- >>> # lr = 0.081 if epoch == 1
- >>> # lr = 0.729 if epoch == 2
- >>> # lr = 0.6561 if epoch == 3
- >>> # lr = 0.59049 if epoch >= 4
- >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
- >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
- >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()
- """
- def __init__(self, schedulers):
- for scheduler_idx in range(1, len(schedulers)):
- if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
- raise ValueError(
- "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
- "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
- )
- self._schedulers = list(schedulers)
- self.optimizer = schedulers[0].optimizer
- self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
- def step(self):
- for scheduler in self._schedulers:
- scheduler.step()
- self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups]
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- The wrapped scheduler states will also be saved.
- """
- state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
- state_dict['_schedulers'] = [None] * len(self._schedulers)
- for idx, s in enumerate(self._schedulers):
- state_dict['_schedulers'][idx] = s.state_dict()
- return state_dict
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- _schedulers = state_dict.pop('_schedulers')
- self.__dict__.update(state_dict)
- # Restore state_dict keys in order to prevent side effects
- # https://github.com/pytorch/pytorch/issues/32756
- state_dict['_schedulers'] = _schedulers
- for idx, s in enumerate(_schedulers):
- self._schedulers[idx].load_state_dict(s)
- class ReduceLROnPlateau:
- """Reduce learning rate when a metric has stopped improving.
- Models often benefit from reducing the learning rate by a factor
- of 2-10 once learning stagnates. This scheduler reads a metric
- quantity and, if no improvement is seen for a 'patience' number
- of epochs, reduces the learning rate.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- mode (str): One of `min`, `max`. In `min` mode, lr will
- be reduced when the quantity monitored has stopped
- decreasing; in `max` mode it will be reduced when the
- quantity monitored has stopped increasing. Default: 'min'.
- factor (float): Factor by which the learning rate will be
- reduced. new_lr = lr * factor. Default: 0.1.
- patience (int): Number of epochs with no improvement after
- which learning rate will be reduced. For example, if
- `patience = 2`, then we will ignore the first 2 epochs
- with no improvement, and will only decrease the LR after the
- 3rd epoch if the loss still hasn't improved then.
- Default: 10.
- threshold (float): Threshold for measuring the new optimum,
- to only focus on significant changes. Default: 1e-4.
- threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
- dynamic_threshold = best * ( 1 + threshold ) in 'max'
- mode or best * ( 1 - threshold ) in `min` mode.
- In `abs` mode, dynamic_threshold = best + threshold in
- `max` mode or best - threshold in `min` mode. Default: 'rel'.
- cooldown (int): Number of epochs to wait before resuming
- normal operation after lr has been reduced. Default: 0.
- min_lr (float or list): A scalar or a list of scalars. A
- lower bound on the learning rate of all param groups
- or each group respectively. Default: 0.
- eps (float): Minimal decay applied to lr. If the difference
- between new and old lr is smaller than eps, the update is
- ignored. Default: 1e-8.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
- >>> for epoch in range(10):
- >>> train(...)
- >>> val_loss = validate(...)
- >>> # Note that step should be called after validate()
- >>> scheduler.step(val_loss)
- """
- def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
- threshold=1e-4, threshold_mode='rel', cooldown=0,
- min_lr=0, eps=1e-8, verbose=False):
- if factor >= 1.0:
- raise ValueError('Factor should be < 1.0.')
- self.factor = factor
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
- if isinstance(min_lr, (list, tuple)):
- if len(min_lr) != len(optimizer.param_groups):
- raise ValueError("expected {} min_lrs, got {}".format(
- len(optimizer.param_groups), len(min_lr)))
- self.min_lrs = list(min_lr)
- else:
- self.min_lrs = [min_lr] * len(optimizer.param_groups)
- self.patience = patience
- self.verbose = verbose
- self.cooldown = cooldown
- self.cooldown_counter = 0
- self.mode = mode
- self.threshold = threshold
- self.threshold_mode = threshold_mode
- self.best = None
- self.num_bad_epochs = None
- self.mode_worse = None # the worst value for the chosen mode
- self.eps = eps
- self.last_epoch = 0
- self._init_is_better(mode=mode, threshold=threshold,
- threshold_mode=threshold_mode)
- self._reset()
- def _reset(self):
- """Resets num_bad_epochs counter and cooldown counter."""
- self.best = self.mode_worse
- self.cooldown_counter = 0
- self.num_bad_epochs = 0
- def step(self, metrics, epoch=None):
- # convert `metrics` to float, in case it's a zero-dim Tensor
- current = float(metrics)
- if epoch is None:
- epoch = self.last_epoch + 1
- else:
- warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
- self.last_epoch = epoch
- if self.is_better(current, self.best):
- self.best = current
- self.num_bad_epochs = 0
- else:
- self.num_bad_epochs += 1
- if self.in_cooldown:
- self.cooldown_counter -= 1
- self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
- if self.num_bad_epochs > self.patience:
- self._reduce_lr(epoch)
- self.cooldown_counter = self.cooldown
- self.num_bad_epochs = 0
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
- def _reduce_lr(self, epoch):
- for i, param_group in enumerate(self.optimizer.param_groups):
- old_lr = float(param_group['lr'])
- new_lr = max(old_lr * self.factor, self.min_lrs[i])
- if old_lr - new_lr > self.eps:
- param_group['lr'] = new_lr
- if self.verbose:
- epoch_str = ("%.2f" if isinstance(epoch, float) else
- "%.5d") % epoch
- print('Epoch {}: reducing learning rate'
- ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
- @property
- def in_cooldown(self):
- return self.cooldown_counter > 0
- def is_better(self, a, best):
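- # For example, with mode='min', threshold_mode='rel' and threshold=1e-4, a new
- # value only counts as an improvement if it is more than 0.01% lower than the
- # best value seen so far (a < best * (1 - 1e-4)).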
- if self.mode == 'min' and self.threshold_mode == 'rel':
- rel_epsilon = 1. - self.threshold
- return a < best * rel_epsilon
- elif self.mode == 'min' and self.threshold_mode == 'abs':
- return a < best - self.threshold
- elif self.mode == 'max' and self.threshold_mode == 'rel':
- rel_epsilon = self.threshold + 1.
- return a > best * rel_epsilon
- else: # mode == 'max' and threshold_mode == 'abs':
- return a > best + self.threshold
- def _init_is_better(self, mode, threshold, threshold_mode):
- if mode not in {'min', 'max'}:
- raise ValueError('mode ' + mode + ' is unknown!')
- if threshold_mode not in {'rel', 'abs'}:
- raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
- if mode == 'min':
- self.mode_worse = inf
- else: # mode == 'max':
- self.mode_worse = -inf
- self.mode = mode
- self.threshold = threshold
- self.threshold_mode = threshold_mode
- def state_dict(self):
- return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
- def load_state_dict(self, state_dict):
- self.__dict__.update(state_dict)
- self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
- class CyclicLR(LRScheduler):
- r"""Sets the learning rate of each parameter group according to
- cyclical learning rate policy (CLR). The policy cycles the learning
- rate between two boundaries with a constant frequency, as detailed in
- the paper `Cyclical Learning Rates for Training Neural Networks`_.
- The distance between the two boundaries can be scaled on a per-iteration
- or per-cycle basis.
- Cyclical learning rate policy changes the learning rate after every batch.
- `step` should be called after a batch has been used for training.
- This class has three built-in policies, as put forth in the paper:
- * "triangular": A basic triangular cycle without amplitude scaling.
- * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
- * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
- at each cycle iteration.
- This implementation was adapted from the github repo: `bckenstler/CLR`_
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- base_lr (float or list): Initial learning rate which is the
- lower boundary in the cycle for each parameter group.
- max_lr (float or list): Upper learning rate boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (max_lr - base_lr).
- The lr at any cycle is the sum of base_lr
- and some scaling of the amplitude; therefore
- max_lr may not actually be reached depending on
- scaling function.
- step_size_up (int): Number of training iterations in the
- increasing half of a cycle. Default: 2000
- step_size_down (int): Number of training iterations in the
- decreasing half of a cycle. If step_size_down is None,
- it is set to step_size_up. Default: None
- mode (str): One of {triangular, triangular2, exp_range}.
- Values correspond to policies detailed above.
- If scale_fn is not None, this argument is ignored.
- Default: 'triangular'
- gamma (float): Constant in 'exp_range' scaling function:
- gamma**(cycle iterations)
- Default: 1.0
- scale_fn (function): Custom scaling policy defined by a single
- argument lambda function, where
- 0 <= scale_fn(x) <= 1 for all x >= 0.
- If specified, then 'mode' is ignored.
- Default: None
- scale_mode (str): {'cycle', 'iterations'}.
- Defines whether scale_fn is evaluated on
- cycle number or cycle iterations (training
- iterations since start of cycle).
- Default: 'cycle'
- cycle_momentum (bool): If ``True``, momentum is cycled inversely
- to learning rate between 'base_momentum' and 'max_momentum'.
- Default: True
- base_momentum (float or list): Lower momentum boundaries in the cycle
- for each parameter group. Note that momentum is cycled inversely
- to learning rate; at the peak of a cycle, momentum is
- 'base_momentum' and learning rate is 'max_lr'.
- Default: 0.8
- max_momentum (float or list): Upper momentum boundaries in the cycle
- for each parameter group. Functionally,
- it defines the cycle amplitude (max_momentum - base_momentum).
- The momentum at any cycle is the difference of max_momentum
- and some scaling of the amplitude; therefore
- base_momentum may not actually be reached depending on
- scaling function. Note that momentum is cycled inversely
- to learning rate; at the start of a cycle, momentum is 'max_momentum'
- and learning rate is 'base_lr'
- Default: 0.9
- last_epoch (int): The index of the last batch. This parameter is used when
- resuming a training job. Since `step()` should be invoked after each
- batch instead of after each epoch, this number represents the total
- number of *batches* computed, not the total number of epochs computed.
- When last_epoch=-1, the schedule is started from the beginning.
- Default: -1
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- Example:
- >>> # xdoctest: +SKIP
- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
- >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
- >>> data_loader = torch.utils.data.DataLoader(...)
- >>> for epoch in range(10):
- >>> for batch in data_loader:
- >>> train_batch(...)
- >>> scheduler.step()
- .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
- .. _bckenstler/CLR: https://github.com/bckenstler/CLR
- """
- def __init__(self,
- optimizer,
- base_lr,
- max_lr,
- step_size_up=2000,
- step_size_down=None,
- mode='triangular',
- gamma=1.,
- scale_fn=None,
- scale_mode='cycle',
- cycle_momentum=True,
- base_momentum=0.8,
- max_momentum=0.9,
- last_epoch=-1,
- verbose=False):
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError('{} is not an Optimizer'.format(
- type(optimizer).__name__))
- self.optimizer = optimizer
- base_lrs = self._format_param('base_lr', optimizer, base_lr)
- if last_epoch == -1:
- for lr, group in zip(base_lrs, optimizer.param_groups):
- group['lr'] = lr
- self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
- step_size_up = float(step_size_up)
- step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
- self.total_size = step_size_up + step_size_down
- self.step_ratio = step_size_up / self.total_size
- if mode not in ['triangular', 'triangular2', 'exp_range'] \
- and scale_fn is None:
- raise ValueError('mode is invalid and scale_fn is None')
- self.mode = mode
- self.gamma = gamma
- self._scale_fn_ref = None
- self._scale_fn_custom = scale_fn
- self.scale_mode = scale_mode
- self._init_scale_fn()
- self.cycle_momentum = cycle_momentum
- if cycle_momentum:
- if 'momentum' not in optimizer.defaults:
- raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
- base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
- if last_epoch == -1:
- for momentum, group in zip(base_momentums, optimizer.param_groups):
- group['momentum'] = momentum
- self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
- self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
- super().__init__(optimizer, last_epoch, verbose)
- self.base_lrs = base_lrs
- def _init_scale_fn(self):
- if self._scale_fn_custom is not None:
- return
- if self.mode == 'triangular':
- self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn)
- self.scale_mode = 'cycle'
- elif self.mode == 'triangular2':
- self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn)
- self.scale_mode = 'cycle'
- elif self.mode == 'exp_range':
- self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn)
- self.scale_mode = 'iterations'
- def _format_param(self, name, optimizer, param):
- """Return correctly formatted lr/momentum for each param group."""
- if isinstance(param, (list, tuple)):
- if len(param) != len(optimizer.param_groups):
- raise ValueError("expected {} values for {}, got {}".format(
- len(optimizer.param_groups), name, len(param)))
- return param
- else:
- return [param] * len(optimizer.param_groups)
- def scale_fn(self, x):
- if self._scale_fn_custom is not None:
- return self._scale_fn_custom(x)
- else:
- return self._scale_fn_ref()(x)
- def _triangular_scale_fn(self, x):
- return 1.
- def _triangular2_scale_fn(self, x):
- return 1 / (2. ** (x - 1))
- def _exp_range_scale_fn(self, x):
- return self.gamma**(x)
- def get_lr(self):
- """Calculates the learning rate at batch index. This function treats
- `self.last_epoch` as the last batch index.
- If `self.cycle_momentum` is ``True``, this function has a side effect of
- updating the optimizer's momentum.
- """
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- cycle = math.floor(1 + self.last_epoch / self.total_size)
- x = 1. + self.last_epoch / self.total_size - cycle
- if x <= self.step_ratio:
- scale_factor = x / self.step_ratio
- else:
- scale_factor = (x - 1) / (self.step_ratio - 1)
- lrs = []
- for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
- base_height = (max_lr - base_lr) * scale_factor
- if self.scale_mode == 'cycle':
- lr = base_lr + base_height * self.scale_fn(cycle)
- else:
- lr = base_lr + base_height * self.scale_fn(self.last_epoch)
- lrs.append(lr)
- if self.cycle_momentum:
- momentums = []
- for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
- base_height = (max_momentum - base_momentum) * scale_factor
- if self.scale_mode == 'cycle':
- momentum = max_momentum - base_height * self.scale_fn(cycle)
- else:
- momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
- momentums.append(momentum)
- for param_group, momentum in zip(self.optimizer.param_groups, momentums):
- param_group['momentum'] = momentum
- return lrs
- def state_dict(self):
- state = super().state_dict()
- # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled
- state.pop("_scale_fn_ref")
- return state
- def load_state_dict(self, state_dict):
- super().load_state_dict(state_dict)
- self._init_scale_fn()
- class CosineAnnealingWarmRestarts(LRScheduler):
- r"""Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
- is the number of epochs since the last restart and :math:`T_{i}` is the number
- of epochs between two warm restarts in SGDR:
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
- When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
- When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
- Args:
- optimizer (Optimizer): Wrapped optimizer.
- T_0 (int): Number of iterations for the first restart.
- T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
- eta_min (float, optional): Minimum learning rate. Default: 0.
- last_epoch (int, optional): The index of last epoch. Default: -1.
- verbose (bool): If ``True``, prints a message to stdout for
- each update. Default: ``False``.
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
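- Example:
- >>> # xdoctest: +SKIP
- >>> # A minimal usage sketch; with T_0=10 and T_mult=2 the cycle lengths are
- >>> # 10, 20, 40, ... epochs.
- >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
- >>> for epoch in range(100):
- >>> train(...)
- >>> validate(...)
- >>> scheduler.step()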
- """
- def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
- if T_0 <= 0 or not isinstance(T_0, int):
- raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
- if T_mult < 1 or not isinstance(T_mult, int):
- raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
- self.T_0 = T_0
- self.T_i = T_0
- self.T_mult = T_mult
- self.eta_min = eta_min
- self.T_cur = last_epoch
- super().__init__(optimizer, last_epoch, verbose)
- def get_lr(self):
- if not self._get_lr_called_within_step:
- warnings.warn("To get the last learning rate computed by the scheduler, "
- "please use `get_last_lr()`.", UserWarning)
- return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
- for base_lr in self.base_lrs]
- def step(self, epoch=None):
- """Step could be called after every batch update
- Example:
- >>> # xdoctest: +SKIP("Undefined vars")
- >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
- >>> iters = len(dataloader)
- >>> for epoch in range(20):
- >>> for i, sample in enumerate(dataloader):
- >>> inputs, labels = sample['inputs'], sample['labels']
- >>> optimizer.zero_grad()
- >>> outputs = net(inputs)
- >>> loss = criterion(outputs, labels)
- >>> loss.backward()
- >>> optimizer.step()
- >>> scheduler.step(epoch + i / iters)
- This function can be called in an interleaved way.
- Example:
- >>> # xdoctest: +SKIP("Undefined vars")
- >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
- >>> for epoch in range(20):
- >>> scheduler.step()
- >>> scheduler.step(26)
- >>> scheduler.step() # scheduler.step(27), instead of scheduler.step(20)
- """
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** n
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                # Intentionally return nothing: returning a truthy value here
                # would suppress exceptions raised inside the `with` block.

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
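
# Illustrative sketch (not part of the original torch API): the cycle lengths in
# SGDR form the geometric sequence T_0, T_0*T_mult, T_0*T_mult**2, ..., so the
# k-th restart occurs after T_0*(T_mult**k - 1)/(T_mult - 1) epochs when
# T_mult > 1, and after k*T_0 epochs when T_mult == 1. The hypothetical helper
# below mirrors the closed form used in `step()` and is never called at import.
def _example_sgdr_restart_epochs(T_0, T_mult, num_restarts):
    """Return the epochs at which the first `num_restarts` warm restarts occur."""
    epochs = []
    for k in range(1, num_restarts + 1):
        if T_mult == 1:
            epochs.append(k * T_0)
        else:
            epochs.append(T_0 * (T_mult ** k - 1) // (T_mult - 1))
    return epochs

# For example, _example_sgdr_restart_epochs(10, 2, 3) would give [10, 30, 70]:
# cycles of length 10, 20 and 40 epochs.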

class OneCycleLR(LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'.
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # xdoctest: +SKIP
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Validate total_steps
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
            self.total_steps = total_steps
        else:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
            self.total_steps = epochs * steps_per_epoch

        # Validate pct_start before it is used to build the phase schedule below
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError("Expected float between 0 and 1 for pct_start, but got {}".format(pct_start))

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        group['betas'] = (m_momentum, *group['betas'][1:])
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super().__init__(optimizer, last_epoch, verbose)
    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start
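
    # Illustrative note (added for clarity, not in the original source): both
    # helpers return `start` at pct=0 and `end` at pct=1. For start=0.1,
    # end=0.01 and pct=0.25, _annealing_linear gives 0.0775, while
    # _annealing_cos gives roughly 0.0868, i.e. the cosine schedule stays
    # closer to `start` early in a phase and closer to `end` late in it.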

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
                             .format(step_num, self.total_steps))

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    group['betas'] = (computed_momentum, *group['betas'][1:])
                else:
                    group['momentum'] = computed_momentum

        return lrs
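
# Illustrative sketch (not part of the original torch API): the hypothetical
# helper below reproduces the bookkeeping of OneCycleLR.__init__ above for the
# default two-phase schedule, so the phase boundaries and lr endpoints can be
# inspected without constructing an optimizer. It is never called at import time.
def _example_one_cycle_boundaries(max_lr, total_steps=None, epochs=None,
                                  steps_per_epoch=None, pct_start=0.3,
                                  div_factor=25., final_div_factor=1e4):
    """Return (total_steps, warmup_end_step, initial_lr, max_lr, min_lr)."""
    if total_steps is None:
        # Same inference rule as the scheduler: total_steps = epochs * steps_per_epoch.
        total_steps = epochs * steps_per_epoch
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    # Phase 1 anneals initial_lr -> max_lr and ends here; phase 2 anneals
    # max_lr -> min_lr and ends at total_steps - 1.
    warmup_end_step = float(pct_start * total_steps) - 1
    return total_steps, warmup_end_step, initial_lr, max_lr, min_lr

# For example, _example_one_cycle_boundaries(0.01, epochs=10, steps_per_epoch=100)
# would give total_steps=1000, warmup_end_step=299.0, initial_lr=4e-4 and min_lr=4e-8.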