import functools
import inspect
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union

from torch import nn

from .._utils import sequence_to_str
from ._api import WeightsEnum


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by ``divisor``
    (8 in the original MobileNet implementation). It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
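
    Illustrative values (computed by hand from the rule below; the second call is bumped up
    because rounding 40 down to 32 would undershoot it by more than 10%)::

        >>> _make_divisible(37.5, 8)
        40
        >>> _make_divisible(40, 32)
        64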
- """
- if min_value is None:
- min_value = divisor
- new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
- # Make sure that round down does not go down by more than 10%.
- if new_v < 0.9 * v:
- new_v += divisor
- return new_v


D = TypeVar("D")


def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
    """Decorates a function that uses keyword-only parameters so that they can also be passed positionally.

    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:

    .. code::

        def old_fn(foo, bar, baz=None):
            ...

        def new_fn(foo, *, bar, baz=None):
            ...

    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
    and at the same time warn the user of the deprecation, this decorator can be used:

    .. code::

        @kwonly_to_pos_or_kw
        def new_fn(foo, *, bar, baz=None):
            ...

        new_fn("foo", "bar", "baz")

    The call above still succeeds, but emits a ``UserWarning`` telling the user to pass ``bar`` and
    ``baz`` as keyword arguments instead (see the message constructed in the wrapper below).
    """
    params = inspect.signature(fn).parameters

    try:
        keyword_only_start_idx = next(
            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
        )
    except StopIteration:
        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> D:
        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
        if keyword_only_args:
            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
            warnings.warn(
                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
                f"instead."
            )
            kwargs.update(keyword_only_kwargs)

        return fn(*args, **kwargs)

    return wrapper


W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")


def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
    """Decorates a model builder with the new interface to make it compatible with the old.

    In particular this handles two things:

    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
       :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
       ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.

    Args:
        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
            the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
            should be accessed with :meth:`~dict.get`.
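
    Example (a sketch of how a model builder might be decorated; ``ResNet18_Weights`` is the
    weights enum from ``torchvision.models`` and the builder signature is illustrative)::

        @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
        def resnet18(*, weights=None, progress=True, **kwargs):
            ...

        # Emits a deprecation warning; the call behaves like
        # resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).
        resnet18(pretrained=True)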
- """
- def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
- @kwonly_to_pos_or_kw
- @functools.wraps(builder)
- def inner_wrapper(*args: Any, **kwargs: Any) -> M:
- for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr]
                # If neither the weights nor the pretrained parameter is passed, or the weights argument already uses
                # the new-style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
                # weights argument, since it is a valid value.
                sentinel = object()
                weights_arg = kwargs.get(weights_param, sentinel)
                if (
                    (weights_param not in kwargs and pretrained_param not in kwargs)
                    or isinstance(weights_arg, WeightsEnum)
                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
                    or weights_arg is None
                ):
                    continue

                # If the pretrained parameter was passed as positional argument, it is now mapped to
                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
                # signature to infer the names of positionally passed arguments and thus has no knowledge that there
                # used to be a pretrained parameter.
                pretrained_positional = weights_arg is not sentinel
                if pretrained_positional:
                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
                    # unified access to the value if the default value is a callable.
                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
                else:
                    pretrained_arg = kwargs[pretrained_param]

                if pretrained_arg:
                    default_weights_arg = default(kwargs) if callable(default) else default
                    if not isinstance(default_weights_arg, WeightsEnum):
                        raise ValueError(f"No weights available for model {builder.__name__}")
                else:
                    default_weights_arg = None

                if not pretrained_positional:
                    warnings.warn(
                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
                        f"please use '{weights_param}' instead."
                    )

                msg = (
                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
                    f"may be removed in the future. "
                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
                )
                if pretrained_arg:
                    msg = (
                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
                        f"to get the most up-to-date weights."
                    )
                warnings.warn(msg)

                del kwargs[pretrained_param]
                kwargs[weights_param] = default_weights_arg

            return builder(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    """Set ``kwargs[param]`` to ``new_value``, raising if a conflicting value was already passed."""
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value
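

# Usage sketch for `_ovewrite_named_param` (hypothetical call): a model builder can pin a
# constructor kwarg that a given set of weights requires, e.g.
#
#   _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))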


def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
    """Return ``expected``, raising if ``actual`` was given and disagrees with it."""
    if actual is not None:
        if actual != expected:
            raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
    return expected
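

# Usage sketch for `_ovewrite_value_param` (hypothetical call): same idea for an explicitly
# passed value rather than a kwargs entry, e.g.
#
#   num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))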


class _ModelURLs(dict):
    def __getitem__(self, item):
        warnings.warn(
            "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
            "be removed in the future. Please access them via the appropriate Weights Enum instead."
        )
        return super().__getitem__(item)
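

# Usage sketch (hypothetical mapping, placeholder URL): individual model files can wrap their
# legacy URL dictionaries so that direct access emits the deprecation warning above, e.g.
#
#   model_urls = _ModelURLs({"resnet18": "https://example.com/resnet18.pth"})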
|