import functools
import inspect
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union

from torch import nn

from .._utils import sequence_to_str
from ._api import WeightsEnum


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model.

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving them the names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
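
# A minimal usage sketch for IntermediateLayerGetter (illustrative only, not part
# of the module API; assumes torch and torchvision are installed):
#
#     import torch
#     from torchvision.models import resnet18
#
#     model = resnet18()
#     getter = IntermediateLayerGetter(model, {"layer1": "feat1", "layer4": "feat4"})
#     feats = getter(torch.rand(1, 3, 224, 224))
#     # feats is an OrderedDict with keys "feat1" and "feat4".
#     # Only direct children can be queried: {"layer1.0": "x"} would raise a
#     # ValueError, because "layer1.0" is not returned by model.named_children().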


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by
    ``divisor`` (8 in the original implementation).
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
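
# Worked example (illustrative): _make_divisible rounds to the nearest multiple
# of `divisor`, but never lands below 90% of the requested value:
#
#     _make_divisible(37.5, 8)  # int(37.5 + 4) // 8 * 8 = 40
#     _make_divisible(11, 8)    # nearest multiple is 8, but 8 < 0.9 * 11,
#                               # so the result is bumped up to 16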


D = TypeVar("D")


def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
    """Decorates a function that uses keyword-only parameters to also allow them to be passed positionally.

    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:

    .. code::

        def old_fn(foo, bar, baz=None):
            ...

        def new_fn(foo, *, bar, baz=None):
            ...

    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep
    BC and at the same time warn the user of the deprecation, this decorator can be used:

    .. code::

        @kwonly_to_pos_or_kw
        def new_fn(foo, *, bar, baz=None):
            ...

        new_fn("foo", "bar", "baz")
    """
    params = inspect.signature(fn).parameters

    try:
        keyword_only_start_idx = next(
            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
        )
    except StopIteration:
        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None

    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> D:
        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
        if keyword_only_args:
            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
            warnings.warn(
                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword "
                f"parameter(s) instead."
            )
            kwargs.update(keyword_only_kwargs)

        return fn(*args, **kwargs)

    return wrapper
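
# A minimal sketch of the decorator's behavior (hypothetical function, not part
# of the module):
#
#     @kwonly_to_pos_or_kw
#     def new_fn(foo, *, bar, baz=None):
#         return foo, bar, baz
#
#     new_fn("a", "b")      # works, but warns that `bar` was passed positionally
#     new_fn("a", bar="b")  # the preferred, warning-free call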


W = TypeVar("W", bound=WeightsEnum)
M = TypeVar("M", bound=nn.Module)
V = TypeVar("V")


def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
    """Decorates a model builder with the new interface to make it compatible with the old.

    In particular this handles two things:

    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
       :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
       ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.

    Args:
        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be
            in the dictionary is the deprecated parameter name passed as first element in the tuple. All other
            parameters should be accessed with :meth:`~dict.get`.
    """

    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
        @kwonly_to_pos_or_kw
        @functools.wraps(builder)
        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
                # If neither the weights nor the pretrained parameter was passed, or the weights argument already
                # uses the new-style arguments, there is nothing to do. Note that we cannot use `None` as sentinel
                # for the weights argument, since it is a valid value.
                sentinel = object()
                weights_arg = kwargs.get(weights_param, sentinel)
                if (
                    (weights_param not in kwargs and pretrained_param not in kwargs)
                    or isinstance(weights_arg, WeightsEnum)
                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
                    or weights_arg is None
                ):
                    continue

                # If the pretrained parameter was passed as positional argument, it is now mapped to
                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
                # signature to infer the names of positionally passed arguments and thus has no knowledge that there
                # used to be a pretrained parameter.
                pretrained_positional = weights_arg is not sentinel
                if pretrained_positional:
                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to
                    # have unified access to the value if the default value is a callable.
                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
                else:
                    pretrained_arg = kwargs[pretrained_param]

                if pretrained_arg:
                    default_weights_arg = default(kwargs) if callable(default) else default
                    if not isinstance(default_weights_arg, WeightsEnum):
                        raise ValueError(f"No weights available for model {builder.__name__}")
                else:
                    default_weights_arg = None

                if not pretrained_positional:
                    warnings.warn(
                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the "
                        f"future, please use '{weights_param}' instead."
                    )

                msg = (
                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 "
                    f"and may be removed in the future. "
                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
                )
                if pretrained_arg:
                    msg = (
                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
                        f"to get the most up-to-date weights."
                    )
                warnings.warn(msg)

                del kwargs[pretrained_param]
                kwargs[weights_param] = default_weights_arg

            return builder(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper
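
# A minimal sketch of how a model builder would use the decorator (the builder
# and weights enum here are hypothetical, not part of this module):
#
#     @handle_legacy_interface(weights=("pretrained", MyModel_Weights.IMAGENET1K_V1))
#     def my_model(*, weights: Optional[MyModel_Weights] = None, progress: bool = True) -> nn.Module:
#         ...
#
#     my_model(pretrained=True)  # warns, then behaves like
#                                # my_model(weights=MyModel_Weights.IMAGENET1K_V1)
#     my_model(weights=MyModel_Weights.IMAGENET1K_V1)  # the new-style call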


def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    if param in kwargs:
        if kwargs[param] != new_value:
            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
    else:
        kwargs[param] = new_value
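
# Illustrative use: force a keyword argument to the value required by a given
# set of weights, while rejecting a conflicting user-supplied value:
#
#     kwargs = {"num_classes": 1000}
#     _ovewrite_named_param(kwargs, "num_classes", 1000)  # no-op, values agree
#     _ovewrite_named_param(kwargs, "width_mult", 1.0)    # inserts the default
#     _ovewrite_named_param(kwargs, "num_classes", 10)    # raises ValueError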


def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
    if actual is not None:
        if actual != expected:
            raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
    return expected
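
# Illustrative use (same idea, for a plain value instead of a kwargs dict):
#
#     num_classes = _ovewrite_value_param("num_classes", None, 1000)  # -> 1000
#     num_classes = _ovewrite_value_param("num_classes", 1000, 1000)  # -> 1000
#     _ovewrite_value_param("num_classes", 10, 1000)  # raises ValueError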


class _ModelURLs(dict):
    def __getitem__(self, item):
        warnings.warn(
            "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
            "be removed in the future. Please access them via the appropriate Weights Enum instead."
        )
        return super().__getitem__(item)
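
# Sketch of the deprecation shim in action (hypothetical URL, for illustration):
#
#     model_urls = _ModelURLs({"resnet18": "https://example.com/resnet18.pth"})
#     model_urls["resnet18"]  # returns the URL, but emits the deprecation warning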