from torch.ao.pruning import BaseSparsifier

from functools import wraps
import warnings
import weakref

__all__ = ["BaseScheduler"]


class BaseScheduler:
    """Base class for schedulers that adjust a sparsifier's sparsity levels.

    On every :meth:`step` the scheduler calls :meth:`get_sl` (implemented by
    subclasses) and writes the returned values, pairwise, into the
    ``'sparsity_level'`` entry of each of ``sparsifier.groups``.

    Args:
        sparsifier (BaseSparsifier): The sparsifier whose groups' sparsity
            levels are scheduled.
        last_epoch (int): Index of the last epoch. It is incremented at the
            start of every :meth:`step`, including the one issued by this
            constructor. Default: -1.
        verbose (bool): If ``True``, :meth:`print_sl` prints a message for each
            group whenever its sparsity level is set. Default: ``False``.

    Raises:
        TypeError: If ``sparsifier`` is not a ``BaseSparsifier`` instance.
    """

    def __init__(self, sparsifier, last_epoch=-1, verbose=False):
        # Attach sparsifier
        if not isinstance(sparsifier, BaseSparsifier):
            raise TypeError('{} is not an instance of torch.ao.pruning.BaseSparsifier'.format(
                type(sparsifier).__name__))
        self.sparsifier = sparsifier

        # Initialize epoch and base sparsity levels
        self.base_sl = [group['sparsity_level'] for group in sparsifier.groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            """Wrap a bound ``step`` method so each call bumps ``_step_count``."""
            if getattr(method, '_with_counter', False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sl_called_within_step: bool = False

        # Apply the initial schedule (moves last_epoch from its initial value
        # and sets _step_count to 1).
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'}

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_sl(self):
        """Return the last sparsity levels computed by the current scheduler.

        The list holds one ``'sparsity_level'`` per sparsifier group, as
        recorded at the end of the most recent :meth:`step`.
        """
        return self._last_sl

    def get_sl(self):
        """Compute the new sparsity levels (to be implemented by subclasses).

        This method is not intended to be called directly; it is only used by
        :meth:`step`. Use :meth:`get_last_sl` to read the last computed levels.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Compute sparsity level using chainable form of the scheduler
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        raise NotImplementedError

    def print_sl(self, is_verbose, group, sl, epoch=None):
        """Display the current sparsity level.

        Args:
            is_verbose (bool): Nothing is printed unless this is truthy.
            group (int): Index of the group being updated.
            sl (float): The sparsity level that was assigned.
            epoch (int, optional): If given, it is included in the message.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting sparsity level'
                      ' of group {} to {:.4e}.'.format(group, sl))
            else:
                print('Epoch {:5d}: adjusting sparsity level'
                      ' of group {} to {:.4e}.'.format(epoch, group, sl))

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        format_string += '\n'
        format_string += 'Sparsifier {0}\n'.format(self.sparsifier)
        format_string += '    {0}: {1}\n'.format('base_sl', self.base_sl)
        format_string += ')'
        return format_string

    def step(self, epoch=None):
        """Advance the schedule by one step and apply the new sparsity levels.

        This method does the following:

        - increments ``last_epoch``;
        - calls :meth:`get_sl`;
        - writes each returned value into the matching
          ``sparsifier.groups[i]['sparsity_level']``;
        - records the resulting levels for :meth:`get_last_sl`;
        - sets ``sparsifier.enable_mask_update = True``.

        On the first step issued after construction, it warns if the
        sparsifier's ``step`` was replaced after the scheduler wrapped it, or if
        ``sparsifier.step()`` has not been called yet.

        Args:
            epoch (int, optional): Only forwarded to :meth:`print_sl` for the
                verbose message; it does not modify ``last_epoch``.
        """
        # Raise warning if trying to call scheduler step before the sparsifier.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.sparsifier.step, "_with_counter"):
                warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
                              "initialization. Please, make sure to call `sparsifier.step()` before "
                              "`scheduler.step()`.", UserWarning)

            # Just check if there were two first scheduler.step() calls before sparsifier.step()
            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
                              "You have to make sure you run the sparsifier.step() BEFORE any "
                              "calls to the scheduler.step().", UserWarning)
        self._step_count += 1

        # `get_sl()` warns when invoked outside of `step()`. Raise the flag for
        # the duration of the update and always lower it again, even if
        # `get_sl()` or `print_sl()` raises.
        self._get_sl_called_within_step = True
        try:
            self.last_epoch += 1
            values = self.get_sl()

            for i, (param_group, sl) in enumerate(zip(self.sparsifier.groups, values)):
                param_group['sparsity_level'] = sl
                self.print_sl(self.verbose, i, sl, epoch)
        finally:
            self._get_sl_called_within_step = False

        self._last_sl = [group['sparsity_level'] for group in self.sparsifier.groups]
        self.sparsifier.enable_mask_update = True

    def _make_sure_a_list(self, var):
        r"""Broadcast ``var`` to a list with one entry per sparsifier group.

        Anything that is not a list or tuple is repeated
        ``len(self.sparsifier.groups)`` times. A list or tuple must already
        have exactly that length and is returned as a new list.

        Raises:
            ValueError: If ``var`` is a list/tuple of the wrong length.
        """
        n = len(self.sparsifier.groups)
        if not isinstance(var, (list, tuple)):
            return [var] * n
        if len(var) != n:
            raise ValueError("Expected variable of length {n}, but got {got}".format(
                n=n, got=len(var)
            ))
        return list(var)  # We want the result to be in a list, not tuple