import io
from collections import OrderedDict
from typing import List, Optional, Tuple

import pytest
import torch
from common_utils import assert_equal, set_rng_seed
from torchvision import models, ops
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import _register_onnx_ops

# In environments without onnxruntime we prefer to skip this module
# rather than fail the whole test run.
onnxruntime = pytest.importorskip("onnxruntime")


class TestONNXExporter:
    @classmethod
    def setup_class(cls):
        torch.manual_seed(123)

    def run_model(
        self,
        model,
        inputs_list,
        do_constant_folding=True,
        dynamic_axes=None,
        output_names=None,
        input_names=None,
        opset_version: Optional[int] = None,
    ):
        if opset_version is None:
            opset_version = _register_onnx_ops.BASE_ONNX_OPSET_VERSION

        model.eval()

        onnx_io = io.BytesIO()
        if isinstance(inputs_list[0][-1], dict):
            torch_onnx_input = inputs_list[0] + ({},)
        else:
            torch_onnx_input = inputs_list[0]
        # export to onnx with the first input
        torch.onnx.export(
            model,
            torch_onnx_input,
            onnx_io,
            do_constant_folding=do_constant_folding,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            verbose=True,
        )
        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs)
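
    # Note: the model is exported once, using the first entry of `inputs_list`;
    # every entry is then run through both eager PyTorch and the exported graph.
    # Entries of a different size only pass this check when the varying
    # dimensions are declared in `dynamic_axes`; otherwise the exported graph
    # bakes in the shapes seen at export time.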

    def ort_validate(self, onnx_io, inputs, outputs):
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, inputs))
        outputs = list(map(to_numpy, outputs))

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        ort_inputs = {ort_session.get_inputs()[i].name: inpt for i, inpt in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i in range(len(outputs)):
            torch.testing.assert_close(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
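
    # ONNX Runtime binds inputs by name, so the dict comprehension above pairs
    # the session's declared inputs with the flattened test tensors
    # positionally; the loose rtol/atol absorb small numeric differences
    # between the two runtimes.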

    def test_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        # make the boxes well-formed (x1, y1, x2, y2) with x2 >= x1 and y2 >= y1
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        self.run_model(Module(), [(boxes, scores)])
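
    # Wrapping the functional op in a small nn.Module, as above, is what lets
    # torch.onnx.export trace it; the same pattern is used for the other
    # functional ops exercised in this file.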

    def test_batched_nms(self):
        num_boxes = 100
        boxes = torch.rand(num_boxes, 4)
        boxes[:, 2:] += boxes[:, :2]
        scores = torch.randn(num_boxes)
        idxs = torch.randint(0, 5, size=(num_boxes,))

        class Module(torch.nn.Module):
            def forward(self, boxes, scores, idxs):
                return ops.batched_nms(boxes, scores, idxs, 0.5)

        self.run_model(Module(), [(boxes, scores, idxs)])

    def test_clip_boxes_to_image(self):
        boxes = torch.randn(5, 4) * 500
        boxes[:, 2:] += boxes[:, :2]
        # the size tensors are only used for their shapes
        size = torch.randn(200, 300)
        size_2 = torch.randn(300, 400)

        class Module(torch.nn.Module):
            def forward(self, boxes, size):
                return ops.boxes.clip_boxes_to_image(boxes, size.shape)

        self.run_model(
            Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]}
        )

    def test_roi_align(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(model, [(x, single_roi)])

        # sampling_ratio=-1 selects an adaptive number of sampling points
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, -1)
        self.run_model(model, [(x, single_roi)])

    def test_roi_align_aligned(self):
        supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
        model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)
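
    # Note: aligned=True pins these exports to opset 16, where ONNX RoiAlign
    # gained the coordinate_transformation_mode attribute that expresses the
    # half-pixel offset used by torchvision's aligned mode.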

    def test_roi_align_malformed_boxes(self):
        supported_onnx_version = _register_onnx_ops._ONNX_OPSET_VERSION_16
        x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
        # a malformed roi: x2 < x1
        single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
        self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

    def test_roi_pool(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        pool_h = 5
        pool_w = 5
        model = ops.RoIPool((pool_h, pool_w), 2)
        self.run_model(model, [(x, rois)])

    def test_resize_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform.resize(images, None)[0]

        input = torch.rand(3, 10, 20)
        input_test = torch.rand(3, 100, 150)
        self.run_model(
            TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}
        )
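
    # Note: in TransformModule above (and in similar modules below), the
    # methods name their first parameter `self_module` instead of `self` so
    # that the enclosing test's `self` (the TestONNXExporter instance) stays
    # visible in the closure and its _init_test_* helpers can be called.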

    def test_transform_images(self):
        class TransformModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()

            def forward(self_module, images):
                return self_module.transform(images)[0].tensors

        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
        self.run_model(TransformModule(), [(input,), (input_test,)])

    def _init_test_generalized_rcnn_transform(self):
        min_size = 100
        max_size = 200
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
        return transform
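
    # GeneralizedRCNNTransform normalizes each image with the given mean/std,
    # resizes it so the shorter side is min_size (capped so the longer side
    # does not exceed max_size), and batches the images into a single tensor.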

    def _init_test_rpn(self):
        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        out_channels = 256
        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.3
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5
        rpn_pre_nms_top_n = dict(training=2000, testing=1000)
        rpn_post_nms_top_n = dict(training=2000, testing=1000)
        rpn_nms_thresh = 0.7
        rpn_score_thresh = 0.0
        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )
        return rpn

    def _init_test_roi_heads_faster_rcnn(self):
        out_channels = 256
        num_classes = 91

        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100

        box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        return roi_heads
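
    # TwoMLPHead's input width is out_channels * resolution**2 because
    # MultiScaleRoIAlign returns an (out_channels, 7, 7) feature per proposal,
    # which the box head flattens before its two fully connected layers.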

    def get_features(self, images):
        s0, s1 = images.shape[-2:]
        features = [
            ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
            ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
            ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
            ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
            ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
        ]
        features = OrderedDict(features)
        return features
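
    # The five random feature maps at strides 4, 8, 16, 32 and 64 stand in
    # for an FPN pyramid: the RPN consumes all five levels, while the roi
    # heads above pool only from levels "0"-"3".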

    def test_rpn(self):
        set_rng_seed(0)

        class RPNModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.rpn = self._init_test_rpn()

            def forward(self_module, images, features):
                images = ImageList(images, [i.shape[-2:] for i in images])
                return self_module.rpn(images, features)

        images = torch.rand(2, 3, 150, 150)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 80, 80)
        test_features = self.get_features(images2)

        model = RPNModule()
        model.eval()
        model(images, features)
        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )
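
    # Six input names are declared because the exporter flattens the inputs:
    # the image batch plus the five feature maps in the OrderedDict each
    # become a separate graph input, and every dimension is marked dynamic so
    # the 80x80 test images exercise the dynamic shapes.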

    def test_multi_scale_roi_align(self):
        class TransformModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
                self.image_sizes = [(512, 512)]

            def forward(self, input, boxes):
                return self.model(input, boxes, self.image_sizes)

        i = OrderedDict()
        i["feat1"] = torch.rand(1, 5, 64, 64)
        i["feat2"] = torch.rand(1, 5, 16, 16)
        boxes = torch.rand(6, 4) * 256
        boxes[:, 2:] += boxes[:, :2]

        i1 = OrderedDict()
        i1["feat1"] = torch.rand(1, 5, 64, 64)
        i1["feat2"] = torch.rand(1, 5, 16, 16)
        boxes1 = torch.rand(6, 4) * 256
        boxes1[:, 2:] += boxes1[:, :2]

        self.run_model(
            TransformModule(),
            [(i, [boxes]), (i1, [boxes1])],
        )
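
    # MultiScaleRoIAlign assigns each box to a pyramid level based on its
    # area, following the FPN heuristic k = k0 + log2(sqrt(wh) / 224), so
    # larger boxes are pooled from the coarser "feat2" map.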

    def test_roi_heads(self):
        class RoiHeadsModule(torch.nn.Module):
            def __init__(self_module):
                super().__init__()
                self_module.transform = self._init_test_generalized_rcnn_transform()
                self_module.rpn = self._init_test_rpn()
                self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()

            def forward(self_module, images, features):
                original_image_sizes = [img.shape[-2:] for img in images]

                images = ImageList(images, [i.shape[-2:] for i in images])
                proposals, _ = self_module.rpn(images, features)
                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes)
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = self.get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = self.get_features(images2)

        model = RoiHeadsModule()
        model.eval()
        model(images, features)
        self.run_model(
            model,
            [(images, features), (images2, test_features)],
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
        )

    def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
        import os

        from PIL import Image
        from torchvision.transforms import functional as F

        data_dir = os.path.join(os.path.dirname(__file__), "assets")
        path = os.path.join(data_dir, *rel_path.split("/"))
        image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)

        return F.convert_image_dtype(F.pil_to_tensor(image))

    def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        return (
            [self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
            [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))],
        )

    def test_faster_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(
            weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test the exported model on images of different sizes and on a dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
        )
        # Test a model exported from an image with no detections on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
        )

    # Verify that paste_masks_in_image behaves the same in tracing.
    # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
    # (since jit_trace will call _onnx_paste_masks_in_image).
    def test_paste_mask_in_image(self):
        masks = torch.rand(10, 1, 26, 26)
        boxes = torch.rand(10, 4)
        boxes[:, 2:] += torch.rand(10, 2)
        boxes *= 50
        o_im_s = (100, 100)
        from torchvision.models.detection.roi_heads import paste_masks_in_image

        out = paste_masks_in_image(masks, boxes, o_im_s)
        jit_trace = torch.jit.trace(
            paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])
        )
        out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])])

        assert torch.all(out.eq(out_trace))

        masks2 = torch.rand(20, 1, 26, 26)
        boxes2 = torch.rand(20, 4)
        boxes2[:, 2:] += torch.rand(20, 2)
        boxes2 *= 100
        o_im_s2 = (200, 200)
        out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
        out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])])

        assert torch.all(out2.eq(out_trace2))
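
    # Passing the image size as a list of tensors keeps it symbolic in the
    # trace; rerunning the same trace on larger masks, boxes, and a 200x200
    # image then checks that no shapes were baked in at trace time.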

    def test_mask_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(
            weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        # Test the exported model on images of different sizes and on a dummy input
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
        )
        # Test a model exported from an image with no detections on other images
        self.run_model(
            model,
            [(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
        )

    # Verify that heatmaps_to_keypoints behaves the same in tracing.
    # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
    # (since jit_trace will call _onnx_heatmaps_to_keypoints).
    def test_heatmaps_to_keypoints(self):
        maps = torch.rand(10, 1, 26, 26)
        rois = torch.rand(10, 4)
        from torchvision.models.detection.roi_heads import heatmaps_to_keypoints

        out = heatmaps_to_keypoints(maps, rois)
        jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
        out_trace = jit_trace(maps, rois)

        assert_equal(out[0], out_trace[0])
        assert_equal(out[1], out_trace[1])

        maps2 = torch.rand(20, 2, 21, 21)
        rois2 = torch.rand(20, 4)
        out2 = heatmaps_to_keypoints(maps2, rois2)
        out_trace2 = jit_trace(maps2, rois2)

        assert_equal(out2[0], out_trace2[0])
        assert_equal(out2[1], out_trace2[1])

    def test_keypoint_rcnn(self):
        images, test_images = self.get_test_images()
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
            weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300
        )
        model.eval()
        model(images)
        self.run_model(
            model,
            [(images,), (test_images,), (dummy_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
        )
        self.run_model(
            model,
            [(dummy_images,), (test_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
        )

    def test_shufflenet_v2_dynamic_axes(self):
        model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0)

        self.run_model(
            model,
            [(dummy_input,), (test_inputs,)],
            input_names=["input_images"],
            output_names=["output"],
            dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
        )
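
    # Note: dynamic_axes accepts either a list of dynamic dimension indices
    # (as in the detection tests above) or, as here, a dict mapping each
    # index to a human-readable name that is recorded in the exported graph.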


if __name__ == "__main__":
    pytest.main([__file__])