1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import collections
- from itertools import repeat
- from typing import List, Dict, Any
- __all__ = ['consume_prefix_in_state_dict_if_present']
- def _ntuple(n, name="parse"):
- def parse(x):
- if isinstance(x, collections.abc.Iterable):
- return tuple(x)
- return tuple(repeat(x, n))
- parse.__name__ = name
- return parse
# Ready-made converters for one to four elements. Each one is given the
# same ``__name__`` as the module-level alias it is bound to.
_single = _ntuple(n=1, name="_single")
_pair = _ntuple(n=2, name="_pair")
_triple = _ntuple(n=3, name="_triple")
_quadruple = _ntuple(n=4, name="_quadruple")
- def _reverse_repeat_tuple(t, n):
- r"""Reverse the order of `t` and repeat each element for `n` times.
- This can be used to translate padding arg used by Conv and Pooling modules
- to the ones used by `F.pad`.
- """
- return tuple(x for x in reversed(t) for _ in range(n))
- def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
- if isinstance(out_size, int):
- return out_size
- if len(defaults) <= len(out_size):
- raise ValueError(
- "Input dimension should be at least {}".format(len(out_size) + 1)
- )
- return [
- v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
- ]
def _strip_prefix_in_metadata(metadata: Dict[str, Any], prefix: str) -> None:
    """Rename, in place, the entries of a metadata mapping that carry ``prefix``."""
    # The wrapped model itself is recorded under the prefix without its
    # trailing dot (e.g. 'module' for the prefix 'module.').
    bare_prefix = prefix[:-1] if prefix.endswith(".") else prefix
    # Snapshot the keys: the mapping is mutated while we walk it.
    for key in list(metadata):
        # for the metadata dict, the key can be:
        #     '': for the DDP module, which we want to remove.
        #     'module': for the actual model.
        #     'module.xx.xx': for the rest.
        if not key:
            continue
        # Only touch entries that really carry the prefix; anything else
        # (e.g. metadata of a state dict that was never prefixed) is kept.
        if key == bare_prefix or key.startswith(prefix):
            # 'module' maps to '', replacing the wrapper's own entry.
            metadata[key[len(prefix):]] = metadata.pop(key)


def consume_prefix_in_state_dict_if_present(
    state_dict: Dict[str, Any], prefix: str
) -> None:
    r"""Strip the prefix in state_dict in place, if any.

    Keys that do not start with ``prefix`` are left untouched. Renamed entries
    are re-inserted in their original relative order.

    .. note::
        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    if not prefix:
        # Nothing to strip; avoid popping and re-inserting every entry.
        return

    # Snapshot the keys: entries are popped and re-inserted during the loop.
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    # also strip the prefix in metadata if any.
    key_metadata = None
    if "_metadata" in state_dict:
        key_metadata = state_dict["_metadata"]
        _strip_prefix_in_metadata(key_metadata, prefix)

    # NOTE(review): state dicts returned by torch.nn.Module.state_dict() are
    # expected to carry their metadata as an `_metadata` attribute rather than
    # as a key - confirm against the torch version in use.
    attr_metadata = getattr(state_dict, "_metadata", None)
    if attr_metadata is not None and attr_metadata is not key_metadata:
        _strip_prefix_in_metadata(attr_metadata, prefix)
|