test_extended_models.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. import copy
  2. import os
  3. import pickle
  4. import pytest
  5. import test_models as TM
  6. import torch
  7. from common_extended_utils import get_file_size_mb, get_ops
  8. from torchvision import models
  9. from torchvision.models import get_model_weights, Weights, WeightsEnum
  10. from torchvision.models._utils import handle_legacy_interface
  11. from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
  12. run_if_test_with_extended = pytest.mark.skipif(
  13. os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
  14. reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
  15. )
  16. @pytest.mark.parametrize(
  17. "name, model_class",
  18. [
  19. ("resnet50", models.ResNet),
  20. ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
  21. ("raft_large", models.optical_flow.RAFT),
  22. ("quantized_resnet50", models.quantization.QuantizableResNet),
  23. ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
  24. ("mvit_v1_b", models.video.MViT),
  25. ],
  26. )
  27. def test_get_model(name, model_class):
  28. assert isinstance(models.get_model(name), model_class)
  29. @pytest.mark.parametrize(
  30. "name, model_fn",
  31. [
  32. ("resnet50", models.resnet50),
  33. ("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
  34. ("raft_large", models.optical_flow.raft_large),
  35. ("quantized_resnet50", models.quantization.resnet50),
  36. ("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
  37. ("mvit_v1_b", models.video.mvit_v1_b),
  38. ],
  39. )
  40. def test_get_model_builder(name, model_fn):
  41. assert models.get_model_builder(name) == model_fn
  42. @pytest.mark.parametrize(
  43. "name, weight",
  44. [
  45. ("resnet50", models.ResNet50_Weights),
  46. ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
  47. ("raft_large", models.optical_flow.Raft_Large_Weights),
  48. ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
  49. ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
  50. ("mvit_v1_b", models.video.MViT_V1_B_Weights),
  51. ],
  52. )
  53. def test_get_model_weights(name, weight):
  54. assert models.get_model_weights(name) == weight
  55. @pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
  56. @pytest.mark.parametrize(
  57. "name",
  58. [
  59. "resnet50",
  60. "retinanet_resnet50_fpn_v2",
  61. "raft_large",
  62. "quantized_resnet50",
  63. "lraspp_mobilenet_v3_large",
  64. "mvit_v1_b",
  65. ],
  66. )
  67. def test_weights_copyable(copy_fn, name):
  68. for weights in list(models.get_model_weights(name)):
  69. # It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
  70. # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
  71. # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
  72. # support for the identity operation in the future.
  73. assert copy_fn(weights) is weights
  74. @pytest.mark.parametrize(
  75. "name",
  76. [
  77. "resnet50",
  78. "retinanet_resnet50_fpn_v2",
  79. "raft_large",
  80. "quantized_resnet50",
  81. "lraspp_mobilenet_v3_large",
  82. "mvit_v1_b",
  83. ],
  84. )
  85. def test_weights_deserializable(name):
  86. for weights in list(models.get_model_weights(name)):
  87. # It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
  88. # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
  89. # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
  90. # support for the identity operation in the future.
  91. assert pickle.loads(pickle.dumps(weights)) is weights
  92. def get_models_from_module(module):
  93. return [
  94. v.__name__
  95. for k, v in module.__dict__.items()
  96. if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
  97. ]
  98. @pytest.mark.parametrize(
  99. "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
  100. )
  101. def test_list_models(module):
  102. a = set(get_models_from_module(module))
  103. b = set(x.replace("quantized_", "") for x in models.list_models(module))
  104. assert len(b) > 0
  105. assert a == b
  106. @pytest.mark.parametrize(
  107. "include_filters",
  108. [
  109. None,
  110. [],
  111. (),
  112. "",
  113. "*resnet*",
  114. ["*alexnet*"],
  115. "*not-existing-model-for-test?",
  116. ["*resnet*", "*alexnet*"],
  117. ["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
  118. ("*resnet*", "*alexnet*"),
  119. set(["*resnet*", "*alexnet*"]),
  120. ],
  121. )
  122. @pytest.mark.parametrize(
  123. "exclude_filters",
  124. [
  125. None,
  126. [],
  127. (),
  128. "",
  129. "*resnet*",
  130. ["*alexnet*"],
  131. ["*not-existing-model-for-test?"],
  132. ["resnet34", "*not-existing-model-for-test?"],
  133. ["resnet34", "*resnet1*"],
  134. ("resnet34", "*resnet1*"),
  135. set(["resnet34", "*resnet1*"]),
  136. ],
  137. )
  138. def test_list_models_filters(include_filters, exclude_filters):
  139. actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
  140. classification_models = set(get_models_from_module(models))
  141. if isinstance(include_filters, str):
  142. include_filters = [include_filters]
  143. if isinstance(exclude_filters, str):
  144. exclude_filters = [exclude_filters]
  145. if include_filters:
  146. expected = set()
  147. for include_f in include_filters:
  148. include_f = include_f.strip("*?")
  149. expected = expected | set(x for x in classification_models if include_f in x)
  150. else:
  151. expected = classification_models
  152. if exclude_filters:
  153. for exclude_f in exclude_filters:
  154. exclude_f = exclude_f.strip("*?")
  155. if exclude_f != "":
  156. a_exclude = set(x for x in classification_models if exclude_f in x)
  157. expected = expected - a_exclude
  158. assert expected == actual
  159. @pytest.mark.parametrize(
  160. "name, weight",
  161. [
  162. ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
  163. ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
  164. (
  165. "ResNet50_QuantizedWeights.DEFAULT",
  166. models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
  167. ),
  168. (
  169. "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
  170. models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
  171. ),
  172. ],
  173. )
  174. def test_get_weight(name, weight):
  175. assert models.get_weight(name) == weight
  176. @pytest.mark.parametrize(
  177. "model_fn",
  178. TM.list_model_fns(models)
  179. + TM.list_model_fns(models.detection)
  180. + TM.list_model_fns(models.quantization)
  181. + TM.list_model_fns(models.segmentation)
  182. + TM.list_model_fns(models.video)
  183. + TM.list_model_fns(models.optical_flow),
  184. )
  185. def test_naming_conventions(model_fn):
  186. weights_enum = get_model_weights(model_fn)
  187. assert weights_enum is not None
  188. assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
# (height, width) per detection model builder name. test_schema_meta_validation looks a model up here and, when
# present, forwards the pair as the `height` / `width` keyword arguments of `get_ops`.
detection_models_input_dims = {
    "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
    "fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn_v2": (800, 800),
    "fcos_resnet50_fpn": (800, 800),
    "keypointrcnn_resnet50_fpn": (1333, 1333),
    "maskrcnn_resnet50_fpn": (800, 800),
    "maskrcnn_resnet50_fpn_v2": (800, 800),
    "retinanet_resnet50_fpn": (800, 800),
    "retinanet_resnet50_fpn_v2": (800, 800),
    "ssd300_vgg16": (300, 300),
    "ssdlite320_mobilenet_v3_large": (320, 320),
}
  203. @pytest.mark.parametrize(
  204. "model_fn",
  205. TM.list_model_fns(models)
  206. + TM.list_model_fns(models.detection)
  207. + TM.list_model_fns(models.quantization)
  208. + TM.list_model_fns(models.segmentation)
  209. + TM.list_model_fns(models.video)
  210. + TM.list_model_fns(models.optical_flow),
  211. )
  212. @run_if_test_with_extended
  213. def test_schema_meta_validation(model_fn):
  214. if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
  215. pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")
  216. # list of all possible supported high-level fields for weights meta-data
  217. permitted_fields = {
  218. "backend",
  219. "categories",
  220. "keypoint_names",
  221. "license",
  222. "_metrics",
  223. "min_size",
  224. "min_temporal_size",
  225. "num_params",
  226. "recipe",
  227. "unquantized",
  228. "_docs",
  229. "_ops",
  230. "_file_size",
  231. }
  232. # mandatory fields for each computer vision task
  233. classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
  234. defaults = {
  235. "all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
  236. "models": classification_fields,
  237. "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
  238. "quantization": classification_fields | {"backend", "unquantized"},
  239. "segmentation": {
  240. "categories",
  241. ("_metrics", "COCO-val2017-VOC-labels", "miou"),
  242. ("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
  243. },
  244. "video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
  245. "optical_flow": set(),
  246. }
  247. model_name = model_fn.__name__
  248. module_name = model_fn.__module__.split(".")[-2]
  249. expected_fields = defaults["all"] | defaults[module_name]
  250. weights_enum = get_model_weights(model_fn)
  251. if len(weights_enum) == 0:
  252. pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
  253. problematic_weights = {}
  254. incorrect_meta = []
  255. bad_names = []
  256. for w in weights_enum:
  257. actual_fields = set(w.meta.keys())
  258. actual_fields |= set(
  259. ("_metrics", dataset, metric_key)
  260. for dataset in w.meta.get("_metrics", {}).keys()
  261. for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
  262. )
  263. missing_fields = expected_fields - actual_fields
  264. unsupported_fields = set(w.meta.keys()) - permitted_fields
  265. if missing_fields or unsupported_fields:
  266. problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
  267. if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
  268. if module_name == "quantization":
  269. # parameters() count doesn't work well with quantization, so we check against the non-quantized
  270. unquantized_w = w.meta.get("unquantized")
  271. if unquantized_w is not None:
  272. if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
  273. incorrect_meta.append((w, "num_params"))
  274. # the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
  275. # instead
  276. if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
  277. incorrect_meta.append((w, "_ops"))
  278. else:
  279. # loading the model and using it for parameter and ops verification
  280. model = model_fn(weights=w)
  281. if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
  282. incorrect_meta.append((w, "num_params"))
  283. kwargs = {}
  284. if model_name in detection_models_input_dims:
  285. # detection models have non default height and width
  286. height, width = detection_models_input_dims[model_name]
  287. kwargs = {"height": height, "width": width}
  288. if not model_fn.__name__.startswith("vit"):
  289. # FIXME: https://github.com/pytorch/vision/issues/7871
  290. calculated_ops = get_ops(model=model, weight=w, **kwargs)
  291. if calculated_ops != w.meta["_ops"]:
  292. incorrect_meta.append((w, "_ops"))
  293. if not w.name.isupper():
  294. bad_names.append(w)
  295. if get_file_size_mb(w) != w.meta.get("_file_size"):
  296. incorrect_meta.append((w, "_file_size"))
  297. assert not problematic_weights
  298. assert not incorrect_meta
  299. assert not bad_names
  300. @pytest.mark.parametrize(
  301. "model_fn",
  302. TM.list_model_fns(models)
  303. + TM.list_model_fns(models.detection)
  304. + TM.list_model_fns(models.quantization)
  305. + TM.list_model_fns(models.segmentation)
  306. + TM.list_model_fns(models.video)
  307. + TM.list_model_fns(models.optical_flow),
  308. )
  309. @run_if_test_with_extended
  310. def test_transforms_jit(model_fn):
  311. model_name = model_fn.__name__
  312. weights_enum = get_model_weights(model_fn)
  313. if len(weights_enum) == 0:
  314. pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")
  315. defaults = {
  316. "models": {
  317. "input_shape": (1, 3, 224, 224),
  318. },
  319. "detection": {
  320. "input_shape": (3, 300, 300),
  321. },
  322. "quantization": {
  323. "input_shape": (1, 3, 224, 224),
  324. },
  325. "segmentation": {
  326. "input_shape": (1, 3, 520, 520),
  327. },
  328. "video": {
  329. "input_shape": (1, 3, 4, 112, 112),
  330. },
  331. "optical_flow": {
  332. "input_shape": (1, 3, 128, 128),
  333. },
  334. }
  335. module_name = model_fn.__module__.split(".")[-2]
  336. kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
  337. input_shape = kwargs.pop("input_shape")
  338. x = torch.rand(input_shape)
  339. if module_name == "optical_flow":
  340. args = (x, x)
  341. else:
  342. if module_name == "video":
  343. x = x.permute(0, 2, 1, 3, 4)
  344. args = (x,)
  345. problematic_weights = []
  346. for w in weights_enum:
  347. transforms = w.transforms()
  348. try:
  349. TM._check_jit_scriptable(transforms, args)
  350. except Exception:
  351. problematic_weights.append(w)
  352. assert not problematic_weights
  353. # With this filter, every unexpected warning will be turned into an error
  354. @pytest.mark.filterwarnings("error")
  355. class TestHandleLegacyInterface:
  356. class ModelWeights(WeightsEnum):
  357. Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
  358. @pytest.mark.parametrize(
  359. "kwargs",
  360. [
  361. pytest.param(dict(), id="empty"),
  362. pytest.param(dict(weights=None), id="None"),
  363. pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
  364. ],
  365. )
  366. def test_no_warn(self, kwargs):
  367. @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
  368. def builder(*, weights=None):
  369. pass
  370. builder(**kwargs)
  371. @pytest.mark.parametrize("pretrained", (True, False))
  372. def test_pretrained_pos(self, pretrained):
  373. @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
  374. def builder(*, weights=None):
  375. pass
  376. with pytest.warns(UserWarning, match="positional"):
  377. builder(pretrained)
  378. @pytest.mark.parametrize("pretrained", (True, False))
  379. def test_pretrained_kw(self, pretrained):
  380. @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
  381. def builder(*, weights=None):
  382. pass
  383. with pytest.warns(UserWarning, match="deprecated"):
  384. builder(pretrained)
  385. @pytest.mark.parametrize("pretrained", (True, False))
  386. @pytest.mark.parametrize("positional", (True, False))
  387. def test_equivalent_behavior_weights(self, pretrained, positional):
  388. @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
  389. def builder(*, weights=None):
  390. pass
  391. args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
  392. with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
  393. builder(*args, **kwargs)
  394. def test_multi_params(self):
  395. weights_params = ("weights", "weights_other")
  396. pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]
  397. @handle_legacy_interface(
  398. **{
  399. weights_param: (pretrained_param, self.ModelWeights.Sentinel)
  400. for weights_param, pretrained_param in zip(weights_params, pretrained_params)
  401. }
  402. )
  403. def builder(*, weights=None, weights_other=None):
  404. pass
  405. for pretrained_param in pretrained_params:
  406. with pytest.warns(UserWarning, match="deprecated"):
  407. builder(**{pretrained_param: True})
  408. def test_default_callable(self):
  409. @handle_legacy_interface(
  410. weights=(
  411. "pretrained",
  412. lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
  413. )
  414. )
  415. def builder(*, weights=None, flag):
  416. pass
  417. with pytest.warns(UserWarning, match="deprecated"):
  418. builder(pretrained=True, flag=True)
  419. with pytest.raises(ValueError, match="weights"):
  420. builder(pretrained=True, flag=False)
  421. @pytest.mark.parametrize(
  422. "model_fn",
  423. [fn for fn in TM.list_model_fns(models) if fn.__name__ not in {"vit_h_14", "regnet_y_128gf"}]
  424. + TM.list_model_fns(models.detection)
  425. + TM.list_model_fns(models.quantization)
  426. + TM.list_model_fns(models.segmentation)
  427. + TM.list_model_fns(models.video)
  428. + TM.list_model_fns(models.optical_flow)
  429. + [
  430. lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
  431. lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
  432. ],
  433. )
  434. @run_if_test_with_extended
  435. def test_pretrained_deprecation(self, model_fn):
  436. with pytest.warns(UserWarning, match="deprecated"):
  437. model_fn(pretrained=True)