|
- import fnmatch
- import importlib
- import inspect
- import sys
- from dataclasses import dataclass
- from enum import Enum
- from functools import partial
- from inspect import signature
- from types import ModuleType
- from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
- from torch import nn
- from .._internally_replaced_utils import load_state_dict_from_url
# Public API of this module; ``register_model`` is deliberately not exported here.
__all__ = [
    "WeightsEnum",
    "Weights",
    "get_model",
    "get_model_builder",
    "get_model_weights",
    "get_weight",
    "list_models",
]
@dataclass
class Weights:
    """
    Groups the attributes that describe one set of pre-trained weights.

    Args:
        url (str): Location the weights are fetched from.
        transforms (Callable): A constructor for the preprocessing method (or validation preset transforms) the
            model expects. A constructor is stored instead of a ready-made object because that object might hold
            memory, so its creation is delayed until it is actually needed.
        meta (Dict[str, Any]): Meta-data about the weights and the model configuration. This covers informative
            attributes (e.g. number of parameters/flops, training recipe link/methods), configuration parameters
            needed to build the model (e.g. ``num_classes``) and data needed to use it (e.g. the ``classes`` of a
            classification model).
    """

    url: str
    transforms: Callable
    meta: Dict[str, Any]

    def __eq__(self, other: Any) -> bool:
        """
        Compares ``url``, ``meta`` and ``transforms``. When both ``transforms`` are ``functools.partial`` objects,
        they are compared by their ``func``, ``args`` and ``keywords`` instead of by identity.
        """
        # Why this exists: once an Enum is defined, building a new instance of it (e.g. through deep-copying or
        # deserialization) involves an equality check against the defined members. ``transforms`` is frequently a
        # ``functools.partial``, and ``fn = partial(...); deepcopy(fn) != fn``. With plain dataclass equality that
        # check would fail and the weights could not be deep-copied or deserialized.
        # See https://github.com/pytorch/vision/pull/7107 for details.
        if not isinstance(other, Weights):
            return NotImplemented

        if self.url != other.url or self.meta != other.meta:
            return False

        mine, theirs = self.transforms, other.transforms
        if not (isinstance(mine, partial) and isinstance(theirs, partial)):
            return mine == theirs

        # Structural comparison of the two partials.
        return mine.func == theirs.func and mine.args == theirs.args and mine.keywords == theirs.keywords
class WeightsEnum(Enum):
    """
    Parent class of all model weights enums. Every model building method takes an optional ``weights`` parameter
    holding its associated pre-trained weights. Subclasses inherit from ``Enum`` and their member values are
    expected to be of type ``Weights``.

    Args:
        value (Weights): The data class entry with the weight information.
    """

    @classmethod
    def verify(cls, obj: Any) -> Any:
        """
        Normalizes a ``weights`` argument into a member of this enum.

        Args:
            obj: ``None``, a member of this enum, or a string naming a member. Every occurrence of
                ``"<ClassName>."`` is stripped from the string before the lookup, so both ``"IMAGENET1K_V1"`` and
                ``"ResNet50_Weights.IMAGENET1K_V1"`` style names are accepted.

        Returns:
            ``None`` when ``obj`` is ``None``, otherwise the matching member (or ``obj`` itself).

        Raises:
            KeyError: If ``obj`` is a string that does not name a member.
            TypeError: If ``obj`` is neither ``None``, an exact ``str``, nor an instance of this enum.
        """
        if obj is None:
            return None
        if type(obj) is str:
            return cls[obj.replace(cls.__name__ + ".", "")]
        if not isinstance(obj, cls):
            raise TypeError(
                f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
            )
        return obj

    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
        """Loads the state dict located at ``self.url``; extra arguments go to ``load_state_dict_from_url``."""
        return load_state_dict_from_url(self.url, *args, **kwargs)

    def __repr__(self) -> str:
        """Renders the member as ``<EnumClassName>.<MEMBER_NAME>``."""
        return f"{type(self).__name__}.{self._name_}"

    @property
    def url(self):
        """Shortcut for ``self.value.url``."""
        return self.value.url

    @property
    def transforms(self):
        """Shortcut for ``self.value.transforms``."""
        return self.value.transforms

    @property
    def meta(self):
        """Shortcut for ``self.value.meta``."""
        return self.value.meta
def get_weight(name: str) -> WeightsEnum:
    """
    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

    The enum class is searched for in the package containing this module and in those of its direct members
    that are sub-packages (modules whose file is an ``__init__.py``).

    Args:
        name (str): The name of the weight enum entry, formatted as ``"<EnumClass>.<MEMBER>"``.

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``name`` does not contain exactly one ``"."``, or if no ``WeightsEnum`` subclass named
            ``<EnumClass>`` is found in the searched modules.
        KeyError: If the enum class is found but has no member named ``<MEMBER>``.
    """
    try:
        enum_name, value_name = name.split(".")
    except ValueError:
        # The unpacking error carries no extra information for the caller; hide it.
        raise ValueError(f"Invalid weight name provided: '{name}'.") from None

    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
    base_module = importlib.import_module(base_module_name)
    model_modules = [base_module] + [
        member
        for _, member in inspect.getmembers(base_module, inspect.ismodule)
        # Modules may lack ``__file__`` or have it set to ``None`` (e.g. builtin or namespace modules); treat them
        # as non-packages instead of crashing with AttributeError.
        if (getattr(member, "__file__", None) or "").endswith("__init__.py")
    ]

    weights_enum = None
    for m in model_modules:
        candidate = m.__dict__.get(enum_name, None)
        # ``enum_name`` may resolve to something that is not a class (e.g. a model builder function with the same
        # name); ``issubclass`` would raise TypeError on it, so check for a class first.
        if isinstance(candidate, type) and issubclass(candidate, WeightsEnum):
            weights_enum = candidate
            break

    if weights_enum is None:
        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

    return weights_enum[value_name]
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
    """
    Returns the weights enum class associated to the given model.

    Args:
        name (callable or str): The model builder function, or the name under which it is registered (resolved
            through ``get_model_builder``).

    Returns:
        weights_enum (WeightsEnum): The weights enum class associated with the model, derived from the builder's
        ``weights`` parameter annotation.
    """
    if isinstance(name, str):
        builder = get_model_builder(name)
    else:
        builder = name
    return _get_enum_from_fn(builder)
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
    """
    Internal method that gets the weight enum of a specific model builder method.

    The enum is taken from the type annotation of the builder's ``weights`` parameter. That annotation may be the
    enum class itself or a typing construct (e.g. ``Optional[ResNet50_Weights]``) listing it among its type
    arguments.

    Args:
        fn (Callable): The builder method used to create the model.

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``fn`` has no ``weights`` parameter, or if its annotation does not reference a
            ``WeightsEnum`` subclass. This includes an unannotated parameter, a plain non-enum class and a string
            annotation.
    """
    params = signature(fn).parameters
    if "weights" not in params:
        raise ValueError("The method is missing the 'weights' argument.")

    ann = params["weights"].annotation

    weights_enum = None
    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
        weights_enum = ann
    else:
        # Handle cases like Optional[T] / Union[T, None] by searching the annotation's type arguments.
        # Annotations without ``__args__`` yield no candidates and fall through to the ValueError below instead
        # of raising AttributeError. Examples: a plain class, a string, or ``inspect.Parameter.empty`` when
        # ``weights`` is unannotated.
        # TODO: consider typing.get_args(ann) (Python >= 3.8) here.
        for t in getattr(ann, "__args__", ()):
            if isinstance(t, type) and issubclass(t, WeightsEnum):
                weights_enum = t
                break

    if weights_enum is None:
        raise ValueError(
            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
        )

    return weights_enum
M = TypeVar("M", bound=nn.Module)

# Registry of model builders, keyed by the exact name they were registered under (see ``register_model``).
BUILTIN_MODELS: Dict[str, Callable[..., nn.Module]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    """
    Creates a decorator that records a model builder in ``BUILTIN_MODELS``.

    Args:
        name (str, optional): Key to register the builder under. When ``None``, the builder's ``__name__`` is used.

    Returns:
        Callable: A decorator that registers the builder and hands it back unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the key is already present in ``BUILTIN_MODELS``.
    """

    def _register(builder: Callable[..., M]) -> Callable[..., M]:
        registry_key = builder.__name__ if name is None else name
        if registry_key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{registry_key}'.")
        BUILTIN_MODELS[registry_key] = builder
        return builder

    return _register
def list_models(
    module: Optional[ModuleType] = None,
    include: Union[Iterable[str], str, None] = None,
    exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
    """
    Returns a sorted list with the names of registered models.

    Args:
        module (ModuleType, optional): Restricts the listing to builders whose ``__module__``, with its last dotted
            component removed, equals ``module.__name__``. In other words, builders defined in modules one level
            below ``module``.
        include (str or Iterable[str], optional): Filter(s) selecting which models to keep. Filters are passed to
            `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style wildcards. With
            several filters, the result is the union of the models matched by each filter. When omitted or empty,
            all models are kept.
        exclude (str or Iterable[str], optional): Filter(s) applied after ``include`` to remove models. Filters
            are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
            wildcards. A model is removed if it matches any of the filters.

    Returns:
        models (list): A sorted list with the names of available models.
    """
    candidates = {
        model_name
        for model_name, builder in BUILTIN_MODELS.items()
        if module is None or builder.__module__.rsplit(".", 1)[0] == module.__name__
    }

    if include:
        include_patterns = (include,) if isinstance(include, str) else include
        selected: Set[str] = set().union(*(fnmatch.filter(candidates, pattern) for pattern in include_patterns))
    else:
        selected = candidates

    if exclude:
        exclude_patterns = (exclude,) if isinstance(exclude, str) else exclude
        # Exclusion filters are matched against the full candidate set, like the include filters.
        selected = selected.difference(*(fnmatch.filter(candidates, pattern) for pattern in exclude_patterns))

    return sorted(selected)
def get_model_builder(name: str) -> Callable[..., nn.Module]:
    """
    Gets the model name and returns the model builder method.

    The name is looked up exactly as given first, then in lower case. ``register_model`` stores keys verbatim, so
    the exact lookup keeps builders registered under names containing upper-case characters reachable. The
    lower-case fallback preserves case-insensitive lookups such as ``"ResNet50"`` finding a builder registered as
    ``"resnet50"``.

    Args:
        name (str): The name under which the model is registered.

    Returns:
        fn (Callable): The model builder method.

    Raises:
        ValueError: If no builder is registered under ``name`` or ``name.lower()``.
    """
    for key in (name, name.lower()):
        if key in BUILTIN_MODELS:
            return BUILTIN_MODELS[key]
    raise ValueError(f"Unknown model {name.lower()}")
def get_model(name: str, **config: Any) -> nn.Module:
    """
    Instantiates a registered model.

    Args:
        name (str): The name under which the model is registered (resolved through ``get_model_builder``).
        **config (Any): Keyword arguments forwarded unchanged to the model builder method.

    Returns:
        model (nn.Module): The model returned by the builder.

    Raises:
        ValueError: Propagated from ``get_model_builder`` when ``name`` is not registered.
    """
    builder = get_model_builder(name)
    return builder(**config)
|