import torch
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch import Tensor

import collections
import copyreg
from copy import deepcopy
from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence

__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations',
           'type_before_parametrizations', 'transfer_parametrizations_and_params']

_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}

@contextmanager
def cached():
    r"""Context manager that enables the caching system within parametrizations
    registered with :func:`register_parametrization`.

    The value of the parametrized objects is computed and cached the first time
    they are required when this context manager is active. The cached values are
    discarded when leaving the context manager.

    This is useful when a parametrized parameter is used more than once in the forward pass,
    for example, when parametrizing the recurrent kernel of an RNN or when sharing weights.

    The simplest way to activate the cache is by wrapping the forward pass of the neural network

    .. code-block:: python

        import torch.nn.utils.parametrize as P
        ...
        with P.cached():
            output = model(inputs)

    in training and evaluation. One may also wrap the parts of the modules that use
    the parametrized tensors several times. For example, the loop of an RNN with a
    parametrized recurrent kernel:

    .. code-block:: python

        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
    """
    global _cache
    global _cache_enabled
    _cache_enabled += 1
    try:
        yield
    finally:
        _cache_enabled -= 1
        if not _cache_enabled:
            _cache = {}

def _register_parameter_or_buffer(module, name, X):
    if isinstance(X, Parameter):
        module.register_parameter(name, X)
    else:
        module.register_buffer(name, X)

class ParametrizationList(ModuleList):
    r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ...
    parameters or buffers of a parametrized :class:`torch.nn.Module`.

    It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
    has been parametrized with :func:`register_parametrization`.

    If the first registered parametrization has a ``right_inverse`` that returns one tensor or
    does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
    it will hold the tensor under the name ``original``.
    If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
    ``original0``, ``original1``, ...
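
    For illustration only (a minimal sketch; this container is created by
    :func:`register_parametrization` and is never instantiated directly by the user):

    .. code-block:: python

        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        m = nn.Linear(3, 3)
        P.register_parametrization(m, "weight", nn.Identity())
        # m.parametrizations.weight is a ParametrizationList holding the
        # unparametrized tensor under ``original``
        print(type(m.parametrizations.weight).__name__)  # ParametrizationList
        print(m.parametrizations.weight.original.shape)  # torch.Size([3, 3])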

    .. warning::
        This class is used internally by :func:`register_parametrization`. It is documented
        here for completeness. It shall not be instantiated by the user.

    Args:
        modules (sequence): sequence of modules representing the parametrizations
        original (Parameter or Tensor): parameter or buffer that is parametrized
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: ``False``
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.
    """
    original: Tensor
    unsafe: bool
    def __init__(
        self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False
    ) -> None:
        # We require this because we need to treat differently the first parametrization
        # This should never throw, unless this class is used from the outside
        if len(modules) == 0:
            raise ValueError("ParametrizationList requires one or more modules.")
        super().__init__(modules)
        self.unsafe = unsafe

        # In plain words:
        # module.weight must keep its dtype and shape.
        # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
        # this should be of the same dtype as the original tensor
        #
        # We check that the following invariants hold:
        #    X = module.weight
        #    Y = param.right_inverse(X)
        #    assert isinstance(Y, Tensor) or
        #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
        #    # Consistency checks
        #    assert X.dtype == Z.dtype and X.shape == Z.shape
        #    # If it has one input, this allows us to use set_ to move data to/from the
        #    # original tensor without changing its id (which is what the optimizer uses
        #    # to track parameters)
        #    if isinstance(Y, Tensor):
        #        assert X.dtype == Y.dtype
        # Below we use original = X, new = Y

        original_shape = original.shape
        original_dtype = original.dtype

        # Compute new
        with torch.no_grad():
            new = original
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    try:
                        new = module.right_inverse(new)
                    except NotImplementedError:
                        pass
                # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
            raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                             f"Got {type(new).__name__}")

        # Set the number of original tensors
        self.is_tensor = isinstance(new, Tensor)
        self.ntensors = 1 if self.is_tensor else len(new)

        # Register the tensor(s)
        if self.is_tensor:
            if original.dtype != new.dtype:
                raise ValueError(
                    "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                    f"original.dtype: {original.dtype}\n"
                    f"right_inverse(original).dtype: {new.dtype}"
                )
            # Copy new into original in-place so that the user does not need to re-register
            # the parameter manually in the optimizer
            with torch.no_grad():
                original.set_(new)  # type: ignore[call-overload]
            _register_parameter_or_buffer(self, "original", original)
        else:
            for i, originali in enumerate(new):
                if not isinstance(originali, Tensor):
                    raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors "
                                     "(list, tuple...). "
                                     f"Got element {i} of the sequence with type {type(originali).__name__}.")

                # If the original tensor was a Parameter that required grad, we expect the user to
                # add the new parameters to the optimizer after registering the parametrization
                # (this is documented)
                if isinstance(original, Parameter):
                    originali = Parameter(originali)
                originali.requires_grad_(original.requires_grad)
                _register_parameter_or_buffer(self, f"original{i}", originali)

        if not self.unsafe:
            # Consistency checks:
            # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
            # Z = forward(right_inverse(original))
            Z = self()
            if not isinstance(Z, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(Z).__name__}."
                )
            if Z.dtype != original_dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"unparametrized dtype: {original_dtype}\n"
                    f"parametrized dtype: {Z.dtype}"
                )
            if Z.shape != original_shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"unparametrized shape: {original_shape}\n"
                    f"parametrized shape: {Z.shape}"
                )
    def right_inverse(self, value: Tensor) -> None:
        r"""Calls the ``right_inverse`` methods of the parametrizations (see
        :func:`register_parametrization`) in the reverse order in which they were registered.
        Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
        or in ``self.original0``, ``self.original1``, ... if it outputs several.

        Args:
            value (Tensor): Value with which to initialize the tensor
        """
        # All the exceptions in this function should almost never throw.
        # They could throw if, for example, a right_inverse function returns a different
        # dtype when given a different input, which should most likely be caused by a
        # bug in the user's code
        with torch.no_grad():
            # See https://github.com/pytorch/pytorch/issues/53103
            for module in reversed(self):  # type: ignore[call-overload]
                if hasattr(module, "right_inverse"):
                    value = module.right_inverse(value)
                else:
                    raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
                                       "right_inverse.")
            if self.is_tensor:
                # These exceptions should only throw when a right_inverse function does not
                # return the same dtype for every input, which should most likely be caused by a bug
                if not isinstance(value, Tensor):
                    raise ValueError(
                        f"`right_inverse` should return a tensor. Got {type(value).__name__}"
                    )
                if value.dtype != self.original.dtype:
                    raise ValueError(
                        f"The tensor returned by `right_inverse` has dtype {value.dtype} "
                        f"while `original` has dtype {self.original.dtype}"
                    )
                # We know that the result is going to have the same dtype
                self.original.set_(value)  # type: ignore[call-overload]
            else:
                if not isinstance(value, collections.abc.Sequence):
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors. "
                        f"Got {type(value).__name__}."
                    )
                if len(value) != self.ntensors:
                    raise ValueError(
                        "'right_inverse' must return a sequence of tensors of length "
                        f"{self.ntensors}. Got a sequence of length {len(value)}."
                    )
                for i, tensor in enumerate(value):
                    original_i = getattr(self, f"original{i}")
                    if not isinstance(tensor, Tensor):
                        raise ValueError(
                            f"`right_inverse` must return a sequence of tensors. "
                            f"Got element {i} of type {type(tensor).__name__}"
                        )
                    if original_i.dtype != tensor.dtype:
                        raise ValueError(
                            f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
                            f"while `original{i}` has dtype {original_i.dtype}"
                        )
                    original_i.set_(tensor)
    def forward(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        # Unpack the originals for the first parametrization
        if self.is_tensor:
            x = self[0](self.original)
        else:
            originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
            x = self[0](*originals)
        # It's not possible to call self[1:] here, so we have to be a bit more cryptic
        # Also we want to skip all non-integer keys
        curr_idx = 1
        while hasattr(self, str(curr_idx)):
            x = self[curr_idx](x)
            curr_idx += 1
        return x

def _inject_new_class(module: Module) -> None:
    r"""Sets up a module to be parametrized.

    This works by substituting the class of the module with a subclass of it
    into which we can inject a property.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls

def _inject_property(module: Module, tensor_name: str) -> None:
    r"""Injects a property into module[tensor_name].

    It assumes that the class of the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out.

    Args:
        module (nn.Module): module into which to inject the property
        tensor_name (str): name of the property to create
    """
    # We check the precondition.
    # This should never fire if register_parametrization is correctly implemented
    assert not hasattr(module, tensor_name)

    @torch.jit.unused
    def get_cached_parametrization(parametrization) -> Tensor:
        global _cache
        key = (id(module), tensor_name)
        tensor = _cache.get(key)
        if tensor is None:
            tensor = parametrization()
            _cache[key] = tensor
        return tensor

    def get_parametrized(self) -> Tensor:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        parametrization = self.parametrizations[tensor_name]
        if _cache_enabled:
            if torch.jit.is_scripting():
                # Scripting
                raise RuntimeError('Caching is not implemented for scripting. '
                                   'Either disable caching or avoid scripting.')
            elif torch._C._get_tracing_state() is not None:
                # Tracing
                raise RuntimeError('Cannot trace a model while caching parametrizations.')
            else:
                return get_cached_parametrization(parametrization)
        else:
            # If caching is not active, this function just evaluates the parametrization
            return parametrization()

    def set_original(self, value: Tensor) -> None:
        if torch.jit.is_scripting():
            raise RuntimeError('Parametrization is not working with scripting.')
        self.parametrizations[tensor_name].right_inverse(value)

    setattr(module.__class__, tensor_name, property(get_parametrized, set_original))

def register_parametrization(
    module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False,
) -> Module:
    r"""Adds a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on the tensor ``weight`` will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    The training mode of a registered parametrization is updated on registration
    to match the training mode of the host module.

    Parametrized parameters and buffers have a built-in caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

    This method is called on the unparametrized tensor when the first parametrization
    is registered to compute the initial value of the original tensor.
    If this method is not implemented, the original tensor will be just the unparametrized tensor.

    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.

    It is possible for the first parametrization to depend on several inputs.
    This may be implemented by returning a tuple of tensors from ``right_inverse``
    (see the example implementation of a ``RankOne`` parametrization below).

    In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
    with names ``original0``, ``original1``, ...

    .. note::

        If ``unsafe=False`` (default), both the ``forward`` and ``right_inverse`` methods will be
        called once to perform a number of consistency checks.
        If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
        and nothing will be called otherwise.

    .. note::

        In most situations, ``right_inverse`` will be a function such that
        ``forward(right_inverse(X)) == X`` (see
        `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
        Sometimes, when the parametrization is not surjective, it may be reasonable
        to relax this.

    .. warning::

        If a parametrization depends on several inputs, :func:`~register_parametrization`
        will register a number of new parameters. If such a parametrization is registered
        after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register

    Keyword args:
        unsafe (bool): a boolean flag that denotes whether the parametrization
            may change the dtype and shape of the tensor. Default: ``False``
            Warning: the parametrization is not checked for consistency upon registration.
            Enable this flag at your own risk.

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T  # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True

        >>> class RankOne(nn.Module):
        >>>     def forward(self, x, y):
        >>>         # Form a rank-1 matrix by multiplying two vectors
        >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
        >>>
        >>>     def right_inverse(self, Z):
        >>>         # Project Z onto the rank-1 matrices
        >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
        >>>         # Return rescaled singular vectors
        >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
        >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
        >>>
        >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
        >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
        1
    """
    parametrization.train(module.training)
    if is_parametrized(module, tensor_name):
        # Correctness checks.
        # If A is the space of tensors with shape and dtype equal to module.weight
        # we check that parametrization.forward and parametrization.right_inverse are
        # functions from A to A
        if not unsafe:
            Y = getattr(module, tensor_name)
            X = parametrization(Y)
            if not isinstance(X, Tensor):
                raise ValueError(
                    f"A parametrization must return a tensor. Got {type(X).__name__}."
                )
            if X.dtype != Y.dtype:
                raise ValueError(
                    "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.dtype: {Y.dtype}\n"
                    f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                )
            if X.shape != Y.shape:
                raise ValueError(
                    "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                    f"module.{tensor_name}.shape: {Y.shape}\n"
                    f"parametrization(module.{tensor_name}).shape: {X.shape}"
                )
            if hasattr(parametrization, "right_inverse"):
                try:
                    Z = parametrization.right_inverse(X)  # type: ignore[operator]
                except NotImplementedError:
                    pass
                else:
                    if not isinstance(Z, Tensor):
                        raise ValueError(
                            f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                        )
                    if Z.dtype != Y.dtype:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same dtype "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.dtype: {Y.dtype}\n"
                            f"returned dtype: {Z.dtype}"
                        )
                    if Z.shape != Y.shape:
                        raise ValueError(
                            "The tensor returned by parametrization.right_inverse must have the same shape "
                            f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                            f"module.{tensor_name}.shape: {Y.shape}\n"
                            f"returned shape: {Z.shape}"
                        )
            # else right_inverse is assumed to be the identity

        # add the new parametrization to the parametrization list
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name].append(parametrization)
        # If unsafe was True in previous parametrization, keep it enabled
        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # We create this early to check for possible errors
        parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe)
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
            # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        module.parametrizations[tensor_name] = parametrizations
    else:
        raise ValueError(
            f"Module '{module}' does not have a parameter, a buffer, or a "
            f"parametrized element with name '{tensor_name}'"
        )
    return module

def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
    r"""Returns ``True`` if the module has an active parametrization.

    If the argument :attr:`tensor_name` is specified, returns ``True`` if
    ``module[tensor_name]`` is parametrized.

    Args:
        module (nn.Module): module to query
        tensor_name (str, optional): attribute in the module to query
            Default: ``None``
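
    Example (a minimal sketch, using :class:`~torch.nn.Identity` as a trivial
    parametrization; any module whose forward preserves shape and dtype would work the same way):

    .. code-block:: python

        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        m = nn.Linear(5, 5)
        print(P.is_parametrized(m))            # False
        P.register_parametrization(m, "weight", nn.Identity())
        print(P.is_parametrized(m))            # True
        print(P.is_parametrized(m, "weight"))  # True
        print(P.is_parametrized(m, "bias"))    # False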
    """
    parametrizations = getattr(module, "parametrizations", None)
    if parametrizations is None or not isinstance(parametrizations, ModuleDict):
        return False
    if tensor_name is None:
        # Check that there is at least one parametrized buffer or Parameter
        return len(parametrizations) > 0
    else:
        return tensor_name in parametrizations

def remove_parametrizations(
    module: Module, tensor_name: str, leave_parametrized: bool = True
) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
      This is only possible when the parametrization depends on just one tensor.

    Args:
        module (nn.Module): module from which to remove the parametrizations
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
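
    Example (a minimal sketch, using :class:`~torch.nn.Identity` as a trivial
    parametrization):

    .. code-block:: python

        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        m = nn.Linear(3, 3)
        P.register_parametrization(m, "weight", nn.Identity())
        print(P.is_parametrized(m, "weight"))  # True

        # Remove the parametrization; ``weight`` keeps its current (parametrized) value
        P.remove_parametrizations(m, "weight", leave_parametrized=True)
        print(P.is_parametrized(m, "weight"))  # False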
    """
    if not is_parametrized(module, tensor_name):
        raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}")

    # Fetch the original tensor
    assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
    parametrizations = module.parametrizations[tensor_name]
    if parametrizations.is_tensor:
        original = parametrizations.original
        if leave_parametrized:
            with torch.no_grad():
                t = getattr(module, tensor_name)
            # We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_
            # We do this so that the parameter does not change its id()
            # This way the user does not need to update the optimizer
            with torch.no_grad():
                if type(original) is torch.Tensor:
                    original.set_(t)
                else:
                    try:
                        original.set_(t)
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
                                           "for a parameter that is an instance of a tensor subclass requires "
                                           "set_() to be implemented correctly for the tensor subclass. Either "
                                           "set leave_parametrized=False or provide a working implementation for "
                                           "set_() in the tensor subclass.") from e
    else:
        if leave_parametrized:
            # We cannot use no_grad because we need to know whether one or more
            # original tensors required grad
            t = getattr(module, tensor_name)
            # We'll have to trust the user to add it to the optimizer
            original = Parameter(t) if t.requires_grad else t
        else:
            raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                             "that is parametrized in terms of a sequence of tensors.")

    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]

    # Restore the parameter / buffer into the main class
    _register_parameter_or_buffer(module, tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module

def type_before_parametrizations(module: Module) -> type:
    r"""Returns the type of the module before any parametrizations were applied.
    If the module has not been parametrized, its type is returned unchanged.

    Args:
        module (nn.Module): module to get type of
    """
    if is_parametrized(module):
        return module.__class__.__bases__[0]
    else:
        return type(module)

def transfer_parametrizations_and_params(
    from_module: Module, to_module: Module, tensor_name: Optional[str] = None
) -> Module:
    r"""Transfers parametrizations and the parameters they parametrize from ``from_module``
    to ``to_module``. If ``tensor_name`` is specified, only transfers the specified parameter; otherwise,
    transfers all parametrized parameters. If those parameters do not exist in ``to_module``, they are created.
    Does nothing if ``from_module`` is not parametrized.

    Args:
        from_module (nn.Module): module to transfer from
        to_module (nn.Module): module to transfer to
        tensor_name (str, optional): parameter to transfer

    Returns:
        Module: to_module
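
    Example (a minimal sketch, using :class:`~torch.nn.Identity` as a trivial
    parametrization):

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        src, dst = nn.Linear(3, 3), nn.Linear(3, 3)
        P.register_parametrization(src, "weight", nn.Identity())

        # dst receives both the parametrization and the original weight of src
        P.transfer_parametrizations_and_params(src, dst, "weight")
        print(P.is_parametrized(dst, "weight"))        # True
        print(torch.allclose(src.weight, dst.weight))  # True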
    """
    if is_parametrized(from_module):
        assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy

        # get list of all params or the single param to transfer
        parameters_to_transfer: Union[list, ModuleDict] = (
            from_module.parametrizations if tensor_name is None else [tensor_name]
        )

        assert hasattr(parameters_to_transfer, "__iter__")  # for mypy
        for parameter_name in parameters_to_transfer:

            # initialize the to-be-transferred param in to_module if it doesn't exist already
            if not hasattr(to_module, parameter_name):
                setattr(
                    to_module,
                    parameter_name,
                    Parameter(getattr(from_module, parameter_name)),
                )

            # apply the param's parametrizations to to_module
            for param_func in from_module.parametrizations[parameter_name]:
                register_parametrization(to_module, parameter_name, param_func)
            assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy

            # make values match, original values can be stored in either original or
            # original0, original1..., need to check both cases
            if hasattr(from_module.parametrizations[parameter_name], "original"):
                to_module.parametrizations[parameter_name].original = \
                    from_module.parametrizations[parameter_name].original
            else:
                num = 0
                orig_num = "original" + str(num)
                # loop through each original# until all values have been set
                while hasattr(from_module.parametrizations[parameter_name], orig_num):
                    setattr(
                        to_module.parametrizations[parameter_name],
                        orig_num,
                        getattr(from_module.parametrizations[parameter_name], orig_num),
                    )
                    num = num + 1
                    orig_num = "original" + str(num)

    return to_module