import itertools
import warnings
from typing import Protocol

import torch

from ..parameter import is_lazy

__all__ = ['LazyModuleMixin']


class _LazyProtocol(Protocol):
    """This class is used to avoid errors with mypy checks for the attributes in a mixin.

    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook):
        ...

    def _lazy_load_hook(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".

    .. warning::
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter` s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the :attr:`in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the expected device.
    This is because lazy modules only perform shape inference, so the usual dtype
    and device placement behavior applies.
    The lazy modules should then perform "dry runs" to initialize all the components in the module.
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> # xdoctest: +SKIP
    >>> class LazyMLP(torch.nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc1 = torch.nn.LazyLinear(10)
    ...         self.relu1 = torch.nn.ReLU()
    ...         self.fc2 = torch.nn.LazyLinear(1)
    ...         self.relu2 = torch.nn.ReLU()
    ...
    ...     def forward(self, input):
    ...         x = self.relu1(self.fc1(input))
    ...         y = self.relu2(self.fc2(x))
    ...         return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = lazy_mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP(
      (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> lazy_mlp(torch.ones(10, 10).cuda().double())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since lazy modules are always initialized after other modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of a network without lazy modules, because the order of parameter
    initializations, which often depends on a stateful random number generator, differs.
    Check :doc:`/notes/randomness` for more details.
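
    For example, to make the resulting initialization reproducible across runs, the random number
    generator can be seeded before constructing the network and running the dry run (a minimal
    sketch; this does not make a lazy network match its non-lazy counterpart, since the order of
    initializations still differs):

    >>> # xdoctest: +SKIP
    >>> torch.manual_seed(0)
    >>> lazy_mlp = LazyMLP()
    >>> lazy_mlp(torch.ones(10, 10))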

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])

    Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
    initialized LazyModules and they will remain initialized):

    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                            0.2479,  0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that loaded parameters will not be replaced by a later "dry run" if they were
    already initialized when the state was loaded. This prevents using initialized modules in
    different contexts.
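
    For example (a sketch continuing the snippet above), the parameters loaded into ``lazy_mlp``
    stay in place, so a later forward pass uses the loaded values rather than re-initializing them:

    >>> # xdoctest: +SKIP
    >>> lazy_mlp.fc1.has_uninitialized_params()
    False
    >>> y = lazy_mlp(torch.ones(10, 1))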
- """

    # Modules inheriting from this will change their ``__class__`` to the
    # specified one after they are fully initialized.
    cls_to_become = None
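    # For example, ``torch.nn.LazyLinear`` sets ``cls_to_become`` to ``torch.nn.Linear``,
    # so a ``LazyLinear`` instance reports itself as a plain ``Linear`` once its first
    # forward pass has materialized its parameters.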

    def __init__(self: _LazyProtocol, *args, **kwargs):
        # Mypy doesn't like this super call in a mixin
        super().__init__(*args, **kwargs)  # type: ignore[misc]
        self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
        self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters)
        warnings.warn('Lazy modules are a new feature under heavy development '
                      'so changes to the API or functionality can happen at any moment.')

    def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
        # Ideally this would be implemented as a hook, but that would require
        # overriding `detach` in UninitializedParameter to return itself,
        # which is not clean.
        for name, param in self._parameters.items():
            if param is not None:
                if not (is_lazy(param) or keep_vars):
                    param = param.detach()
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                if not (is_lazy(buf) or keep_vars):
                    buf = buf.detach()
                destination[prefix + name] = buf

    def _lazy_load_hook(
            self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        """load_state_dict pre-hook function for lazy buffers and parameters.

        The purpose of this hook is to adjust the current state and/or
        ``state_dict`` being loaded so that a module instance serialized in
        both un/initialized state can be deserialized onto both un/initialized
        module instances.
        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
        for the details of the hook specification.
        """
        for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
            key = prefix + name
            if key in state_dict and param is not None:
                input_param = state_dict[key]
                if is_lazy(param):
                    # The current parameter is not initialized but the one being loaded is,
                    # so create a new parameter based on the uninitialized one.
                    if not is_lazy(input_param):
                        with torch.no_grad():
                            param.materialize(input_param.shape)

    def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.

        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.
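
        A minimal sketch of how a subclass might override this method (the ``LazyScale``
        class, its shapes, and its init function are illustrative assumptions, not part
        of this module):

        >>> # xdoctest: +SKIP
        >>> class LazyScale(LazyModuleMixin, torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.weight = torch.nn.UninitializedParameter()
        ...
        ...     def initialize_parameters(self, input):
        ...         if self.has_uninitialized_params():
        ...             with torch.no_grad():
        ...                 self.weight.materialize((input.shape[-1],))
        ...                 torch.nn.init.ones_(self.weight)
        ...
        ...     def forward(self, input):
        ...         return input * self.weight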
- """
- raise NotImplementedError('initialize_parameters is not implemented for {}'.format(self.__class__.__name__))

    def has_uninitialized_params(self: _LazyProtocol):
        r"""Check if a module has parameters that are not initialized.
- """
- # This is to avoid the JIT to track this parameter and force
- # custom modules __setstate__ to add it
- params = self._parameters.values()
- buffers = self._buffers.values()
- for param in itertools.chain(params, buffers):
- if is_lazy(param):
- return True
- return False

    def _infer_parameters(self: _LazyProtocol, module, input):
        r"""Infer the size of and initialize the parameters according to the provided input batch.

        This runs as a forward pre-hook: it calls ``initialize_parameters`` with the
        inputs of the first forward pass so that every uninitialized parameter or buffer
        can be materialized with the right shape. Once the module is fully initialized,
        the lazy hooks are removed and, if ``cls_to_become`` is set, the instance's class
        is swapped to the regular (non-lazy) module class.
        """
        module.initialize_parameters(*input)
        if module.has_uninitialized_params():
            raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become

    def _replicate_for_data_parallel(self: _LazyProtocol):
        raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. '
                           'Run a dummy forward pass to correctly initialize the modules')
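
    # Illustrative usage note (a sketch; ``module`` and ``dummy_input`` are placeholder
    # names): perform a dry run to initialize the lazy parameters before wrapping, e.g.
    #
    #     module(dummy_input)                      # initializes lazy parameters
    #     replicated = torch.nn.DataParallel(module)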