import random
from itertools import chain
from typing import Mapping, Sequence

import pytest
import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
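    """Build a ResNet+FPN backbone and sanity-check its outputs.

    The FPN should return one feature map per returned layer plus an extra
    "pool" level, and out-of-range constructor arguments should raise
    ValueError.
    """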
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
    assert isinstance(model, BackboneWithFPN)
    y = model(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
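    """Check MobileNet backbones with and without an FPN, including argument validation."""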
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])

    model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
    assert isinstance(model_fpn, BackboneWithFPN)
    model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
    assert isinstance(model, torch.nn.Sequential)


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


# Needed by TestFxFeatureExtraction. Checking that node naming conventions
# are respected. Particularly the index postfix of repeated node names.
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x
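

# Expected graph node names when tracing TestModule, in execution order.
# Repeated calls to the same op/module get an integer postfix (add_1, relu_1),
# and nodes inside the submodule are qualified with its attribute name.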
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """Create a feature extractor, defaulting tracer_kwargs to the
        class-level leaf_modules unless the caller supplies its own."""
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
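        """Sample 10 tensor-valued node names each from the train- and
        eval-mode graphs of ``model`` to use as return nodes."""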
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors, as they cause issues
        # when testing the backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_build_fx_feature_extractor(self, model_name):
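        """Building an extractor should accept a list or a dict of return
        nodes, and reject missing, conflicting, or invalid specifications."""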
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)

        # Check that it works with both a list and a dict of return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that return nodes must be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model)
        # Check mutual exclusivity of return_nodes and
        # train_return_nodes / eval_return_nodes
        with pytest.raises(ValueError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check that train_return_nodes / eval_return_nodes must both be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check that an invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
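        """Traced node names should match the qualified names and index
        postfixes expected in test_module_nodes."""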
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_forward_backward(self, model_name):
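        """Run a forward pass through the extractor, aggregate all outputs
        into a scalar, and check that backward succeeds."""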
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        out_agg = 0
        for node_out in out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        out_agg.backward()

    def test_feature_extraction_methods_equivalence(self):
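        """IntermediateLayerGetter and the FX-based extractor should expose
        the same parameters and produce identical outputs on resnet18."""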
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have the same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that the outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
            assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
            for k in ilg_out.keys():
                assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_jit_forward_backward(self, model_name):
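        """The extractor should survive torch.jit.script and still support a
        full forward and backward pass."""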
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        out_agg = 0
        for node_out in fgn_out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        out_agg.backward()

    def test_train_eval(self):
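        """Switching between train() and eval() should select the matching
        return nodes and respect training-dependent control flow."""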
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
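        """Leaf modules and autowrapped functions should be kept as single
        opaque nodes instead of being traced through."""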
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule was not traced through: its internals (relu)
        # should not appear as graph nodes, only the module call itself.
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()