123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503 |
- import copy
- import os
- import pickle
- import pytest
- import test_models as TM
- import torch
- from common_extended_utils import get_file_size_mb, get_ops
- from torchvision import models
- from torchvision.models import get_model_weights, Weights, WeightsEnum
- from torchvision.models._utils import handle_legacy_interface
- from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
# Marker for the slow "extended" tests in this module: they are skipped unless the
# environment variable PYTORCH_TEST_WITH_EXTENDED is set to exactly "1".
run_if_test_with_extended = pytest.mark.skipif(
    os.environ.get("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
    reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
)
@pytest.mark.parametrize(
    ("name", "model_class"),
    [
        ("resnet50", models.ResNet),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
        ("raft_large", models.optical_flow.RAFT),
        ("quantized_resnet50", models.quantization.QuantizableResNet),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
        ("mvit_v1_b", models.video.MViT),
    ],
)
def test_get_model(name, model_class):
    """``models.get_model(name)`` must build an instance of the expected model class."""
    built = models.get_model(name)
    assert isinstance(built, model_class)
@pytest.mark.parametrize(
    ("name", "model_fn"),
    [
        ("resnet50", models.resnet50),
        ("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
        ("raft_large", models.optical_flow.raft_large),
        ("quantized_resnet50", models.quantization.resnet50),
        ("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
        ("mvit_v1_b", models.video.mvit_v1_b),
    ],
)
def test_get_model_builder(name, model_fn):
    """``models.get_model_builder(name)`` must return the public builder function for ``name``."""
    builder = models.get_model_builder(name)
    assert builder == model_fn
@pytest.mark.parametrize(
    ("name", "weight"),
    [
        ("resnet50", models.ResNet50_Weights),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
        ("raft_large", models.optical_flow.Raft_Large_Weights),
        ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
        ("mvit_v1_b", models.video.MViT_V1_B_Weights),
    ],
)
def test_get_model_weights(name, weight):
    """``models.get_model_weights(name)`` must return the weights enum class for ``name``."""
    weights_enum = models.get_model_weights(name)
    assert weights_enum == weight
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
    "name",
    [
        "resnet50",
        "retinanet_resnet50_fpn_v2",
        "raft_large",
        "quantized_resnet50",
        "lraspp_mobilenet_v3_large",
        "mvit_v1_b",
    ],
)
def test_weights_copyable(copy_fn, name):
    """Shallow and deep copies of every weights member must be the member itself.

    Copying being an identity operation is somewhat surprising, but it is the default
    behavior of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
    Equality (``==``) would be sufficient (and even preferable) for our use case, should we
    ever need to drop support for the identity operation.
    """
    for member in models.get_model_weights(name):
        duplicate = copy_fn(member)
        assert duplicate is member
@pytest.mark.parametrize(
    "name",
    [
        "resnet50",
        "retinanet_resnet50_fpn_v2",
        "raft_large",
        "quantized_resnet50",
        "lraspp_mobilenet_v3_large",
        "mvit_v1_b",
    ],
)
def test_weights_deserializable(name):
    """A pickle round-trip of every weights member must return the member itself.

    Deserialization being an identity operation is somewhat surprising, but it is the default
    behavior of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
    Equality (``==``) would be sufficient (and even preferable) for our use case, should we
    ever need to drop support for the identity operation.
    """
    for member in models.get_model_weights(name):
        restored = pickle.loads(pickle.dumps(member))
        assert restored is member
def get_models_from_module(module):
    """Return the ``__name__`` of each public, lowercase-named callable in ``module.__dict__``.

    Names listed in ``models._api.__all__`` are skipped (presumably the registry helpers
    rather than model builders).
    """
    found = []
    for attr_name, attr in module.__dict__.items():
        # Same short-circuit order as before: callable first, then the name checks.
        if not callable(attr):
            continue
        initial = attr_name[0]
        if not initial.islower() or initial == "_":
            continue
        if attr_name in models._api.__all__:
            continue
        found.append(attr.__name__)
    return found
@pytest.mark.parametrize(
    "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
    """``models.list_models(module)`` must agree with the builders found by introspecting ``module``."""
    introspected = set(get_models_from_module(module))
    # Quantized builders are listed as e.g. "quantized_resnet50" while the Python function is
    # ``models.quantization.resnet50``, so drop the prefix before comparing.
    registered = {listed.replace("quantized_", "") for listed in models.list_models(module)}
    assert len(registered) > 0
    assert introspected == registered
@pytest.mark.parametrize(
    "include_filters",
    [
        None,
        [],
        (),
        "",
        "*resnet*",
        ["*alexnet*"],
        "*not-existing-model-for-test?",
        ["*resnet*", "*alexnet*"],
        ["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
        ("*resnet*", "*alexnet*"),
        set(["*resnet*", "*alexnet*"]),
    ],
)
@pytest.mark.parametrize(
    "exclude_filters",
    [
        None,
        [],
        (),
        "",
        "*resnet*",
        ["*alexnet*"],
        ["*not-existing-model-for-test?"],
        ["resnet34", "*not-existing-model-for-test?"],
        ["resnet34", "*resnet1*"],
        ("resnet34", "*resnet1*"),
        set(["resnet34", "*resnet1*"]),
    ],
)
def test_list_models_filters(include_filters, exclude_filters):
    """``list_models`` include/exclude filtering must match an independently computed expectation.

    The expectation approximates the glob patterns by stripping ``*``/``?`` and doing a plain
    substring search, which is exact for the patterns parametrized above.
    """
    actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
    candidates = set(get_models_from_module(models))

    # A bare string is a single pattern, not an iterable of characters.
    includes = [include_filters] if isinstance(include_filters, str) else include_filters
    excludes = [exclude_filters] if isinstance(exclude_filters, str) else exclude_filters

    if includes:
        include_terms = [pattern.strip("*?") for pattern in includes]
        expected = {name for name in candidates if any(term in name for term in include_terms)}
    else:
        # No (or empty) include filter means "everything".
        expected = candidates

    for pattern in excludes or ():
        term = pattern.strip("*?")
        # An empty term would match every name, so it is ignored rather than excluding all.
        if term:
            expected = expected - {name for name in candidates if term in name}

    assert expected == actual
@pytest.mark.parametrize(
    ("name", "weight"),
    [
        ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
        ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
        (
            "ResNet50_QuantizedWeights.DEFAULT",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
        ),
        (
            "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
        ),
    ],
)
def test_get_weight(name, weight):
    """A ``"<EnumName>.<MEMBER>"`` string (``DEFAULT`` included) must resolve to the expected member."""
    resolved = models.get_weight(name)
    assert resolved == weight
@pytest.mark.parametrize(
    "model_fn",
    [
        builder
        for module in (
            models,
            models.detection,
            models.quantization,
            models.segmentation,
            models.video,
            models.optical_flow,
        )
        for builder in TM.list_model_fns(module)
    ],
)
def test_naming_conventions(model_fn):
    """Every builder must have a weights enum, and a non-empty enum must define ``DEFAULT``."""
    weights_enum = get_model_weights(model_fn)
    assert weights_enum is not None
    assert not len(weights_enum) or hasattr(weights_enum, "DEFAULT")
# (height, width) passed to ``get_ops`` for detection builders whose input size differs
# from the default used when counting ops.
detection_models_input_dims = dict(
    fasterrcnn_mobilenet_v3_large_320_fpn=(320, 320),
    fasterrcnn_mobilenet_v3_large_fpn=(800, 800),
    fasterrcnn_resnet50_fpn=(800, 800),
    fasterrcnn_resnet50_fpn_v2=(800, 800),
    fcos_resnet50_fpn=(800, 800),
    keypointrcnn_resnet50_fpn=(1333, 1333),
    maskrcnn_resnet50_fpn=(800, 800),
    maskrcnn_resnet50_fpn_v2=(800, 800),
    retinanet_resnet50_fpn=(800, 800),
    retinanet_resnet50_fpn_v2=(800, 800),
    ssd300_vgg16=(300, 300),
    ssdlite320_mobilenet_v3_large=(320, 320),
)
@pytest.mark.parametrize(
    "model_fn",
    [
        builder
        for module in (
            models,
            models.detection,
            models.quantization,
            models.segmentation,
            models.video,
            models.optical_flow,
        )
        for builder in TM.list_model_fns(module)
    ],
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
    """Validate the ``meta`` dictionary of every pre-trained weight of ``model_fn``.

    For each weight this checks that:

    * every mandatory field for the builder's task is present and no unknown top-level
      key is used;
    * for the default weight, and for any weight whose ``num_params``/``_ops`` differ from
      the default's, those values match the unquantized weight (quantization builders) or
      are recomputed from the loaded model (all other builders; ``_ops`` is not recomputed
      for ``vit*`` builders);
    * the enum member name is upper case;
    * ``_file_size`` equals ``get_file_size_mb(weight)``.
    """
    model_name = model_fn.__name__
    if model_name == "maskrcnn_resnet50_fpn_v2":
        pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")

    # Every high-level key a weights ``meta`` dict is allowed to contain.
    permitted_fields = {
        "backend",
        "categories",
        "keypoint_names",
        "license",
        "_metrics",
        "min_size",
        "min_temporal_size",
        "num_params",
        "recipe",
        "unquantized",
        "_docs",
        "_ops",
        "_file_size",
    }
    # Mandatory fields per task; a tuple ("_metrics", dataset, metric) names a nested metric.
    classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
    defaults = {
        "all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
        "models": classification_fields,
        "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
        "quantization": classification_fields | {"backend", "unquantized"},
        "segmentation": {
            "categories",
            ("_metrics", "COCO-val2017-VOC-labels", "miou"),
            ("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
        },
        "video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
        "optical_flow": set(),
    }

    # The task is taken from the builder's package, e.g. "torchvision.models.detection.xxx"
    # -> "detection". This lookup deliberately happens before the "no weights" skip below.
    module_name = model_fn.__module__.split(".")[-2]
    expected_fields = defaults["all"] | defaults[module_name]

    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_meta = []
    bad_names = []
    for w in weights_enum:
        meta = w.meta

        # Flatten nested metrics into ("_metrics", dataset, metric) tuples so they can be
        # compared against the mandatory-field sets above.
        metrics = meta.get("_metrics", {})
        present_fields = set(meta.keys()) | {
            ("_metrics", dataset, metric_key)
            for dataset in metrics.keys()
            for metric_key in metrics.get(dataset, {}).keys()
        }
        missing_fields = expected_fields - present_fields
        unsupported_fields = set(meta.keys()) - permitted_fields
        if missing_fields or unsupported_fields:
            problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}

        # Only the default weight, or weights whose size/compute differ from it, get the
        # (expensive) num_params/_ops verification. ``or`` short-circuits, so the default
        # weight never reaches the subscript lookups.
        needs_size_check = w == weights_enum.DEFAULT or any(
            meta[k] != weights_enum.DEFAULT.meta[k] for k in ("num_params", "_ops")
        )
        if needs_size_check:
            if module_name == "quantization":
                # parameters() counting and ops measurement don't work well for quantized
                # models, so compare against the unquantized weight's values instead.
                reference = meta.get("unquantized")
                if reference is not None:
                    if meta.get("num_params") != reference.meta.get("num_params"):
                        incorrect_meta.append((w, "num_params"))
                    if meta["_ops"] != reference.meta.get("_ops"):
                        incorrect_meta.append((w, "_ops"))
            else:
                # Load the model with these weights and recompute the numbers.
                model = model_fn(weights=w)
                if meta.get("num_params") != sum(p.numel() for p in model.parameters()):
                    incorrect_meta.append((w, "num_params"))

                # Detection models use non-default input height and width.
                size_kwargs = {}
                if model_name in detection_models_input_dims:
                    height, width = detection_models_input_dims[model_name]
                    size_kwargs = dict(height=height, width=width)

                # FIXME: https://github.com/pytorch/vision/issues/7871
                if not model_name.startswith("vit"):
                    calculated_ops = get_ops(model=model, weight=w, **size_kwargs)
                    if calculated_ops != meta["_ops"]:
                        incorrect_meta.append((w, "_ops"))

        if not w.name.isupper():
            bad_names.append(w)
        if get_file_size_mb(w) != meta.get("_file_size"):
            incorrect_meta.append((w, "_file_size"))

    assert not problematic_weights
    assert not incorrect_meta
    assert not bad_names
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
    """Check that the preset transforms of every pre-trained weight of ``model_fn`` are scriptable.

    A random input of a task-specific shape (overridable through ``TM._model_params``) is
    fed to ``TM._check_jit_scriptable`` for each weight's ``transforms()``. Every failing
    weight is reported together with the exception it raised.
    """
    model_name = model_fn.__name__
    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    defaults = {
        "models": {
            "input_shape": (1, 3, 224, 224),
        },
        "detection": {
            # No batch dimension here (presumably detection transforms take single images).
            "input_shape": (3, 300, 300),
        },
        "quantization": {
            "input_shape": (1, 3, 224, 224),
        },
        "segmentation": {
            "input_shape": (1, 3, 520, 520),
        },
        "video": {
            "input_shape": (1, 3, 4, 112, 112),
        },
        "optical_flow": {
            "input_shape": (1, 3, 128, 128),
        },
    }
    # The task is taken from the builder's package, e.g. "torchvision.models.video.xxx" -> "video".
    module_name = model_fn.__module__.split(".")[-2]
    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")
    x = torch.rand(input_shape)
    if module_name == "optical_flow":
        # Optical-flow transforms take an image pair.
        args = (x, x)
    else:
        if module_name == "video":
            # Swap dims 1 and 2 of the 5-D video tensor; presumably the video transforms
            # expect (batch, time, channels, H, W) — TODO confirm against the transforms.
            x = x.permute(0, 2, 1, 3, 4)
        args = (x,)

    # Map each failing weight to the error it raised, so a failure explains itself instead of
    # only naming the weight (the original swallowed the exception entirely).
    problematic_weights = {}
    for w in weights_enum:
        transforms = w.transforms()
        try:
            TM._check_jit_scriptable(transforms, args)
        except Exception as exc:
            problematic_weights[w] = f"{type(exc).__name__}: {exc}"
    assert not problematic_weights
# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
    """Tests for ``handle_legacy_interface``, which maps the deprecated ``pretrained`` flag onto ``weights``.

    The class-level ``filterwarnings("error")`` mark turns any warning that a test does not
    explicitly expect (via ``pytest.warns``) into a failure.
    """

    class ModelWeights(WeightsEnum):
        # Single dummy entry; passing pretrained=True is expected to behave like weights=Sentinel.
        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

    @pytest.mark.parametrize(
        "kwargs",
        [
            pytest.param(dict(), id="empty"),
            pytest.param(dict(weights=None), id="None"),
            pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
        ],
    )
    def test_no_warn(self, kwargs):
        """Using the new ``weights`` interface (or nothing at all) must not warn."""

        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        builder(**kwargs)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_pos(self, pretrained):
        """Passing ``pretrained`` positionally must warn about positional usage."""

        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="positional"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_kw(self, pretrained):
        """Passing ``pretrained`` as a keyword must emit a deprecation warning."""

        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        # Must be a keyword argument: ``builder(pretrained)`` would only re-exercise the
        # positional path that ``test_pretrained_pos`` already covers.
        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    @pytest.mark.parametrize("positional", (True, False))
    def test_equivalent_behavior_weights(self, pretrained, positional):
        """The warning must name the equivalent ``weights=`` value (Sentinel for True, None for False)."""

        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
        with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
            builder(*args, **kwargs)

    def test_multi_params(self):
        """Each of several ``pretrained*`` aliases must be handled and warn independently."""
        weights_params = ("weights", "weights_other")
        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

        @handle_legacy_interface(
            **{
                weights_param: (pretrained_param, self.ModelWeights.Sentinel)
                for weights_param, pretrained_param in zip(weights_params, pretrained_params)
            }
        )
        def builder(*, weights=None, weights_other=None):
            pass

        for pretrained_param in pretrained_params:
            with pytest.warns(UserWarning, match="deprecated"):
                builder(**{pretrained_param: True})

    def test_default_callable(self):
        """A callable default is evaluated with the call's kwargs; a non-weights result raises ValueError."""

        @handle_legacy_interface(
            weights=(
                "pretrained",
                lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
            )
        )
        def builder(*, weights=None, flag):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=True, flag=True)

        with pytest.raises(ValueError, match="weights"):
            builder(pretrained=True, flag=False)

    @pytest.mark.parametrize(
        "model_fn",
        [fn for fn in TM.list_model_fns(models) if fn.__name__ not in {"vit_h_14", "regnet_y_128gf"}]
        + TM.list_model_fns(models.detection)
        + TM.list_model_fns(models.quantization)
        + TM.list_model_fns(models.segmentation)
        + TM.list_model_fns(models.video)
        + TM.list_model_fns(models.optical_flow)
        + [
            lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
            lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
        ],
    )
    @run_if_test_with_extended
    def test_pretrained_deprecation(self, model_fn):
        """Every real builder (and the two backbone helpers) must warn when given ``pretrained=True``."""
        with pytest.warns(UserWarning, match="deprecated"):
            model_fn(pretrained=True)
|