import random
from itertools import chain
from typing import Mapping, Sequence

import pytest
import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
- @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
- def test_resnet_fpn_backbone(backbone_name):
- x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
- model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
- assert isinstance(model, BackboneWithFPN)
- y = model(x)
- assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
- with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
- resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
- with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
- resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
- with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
- resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])
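

# Illustrative usage sketch, not part of the test suite (the helper name is
# ours): how the FPN backbone built above is typically consumed. Every pyramid
# level shares the FPN's out_channels, so downstream heads can treat all
# levels uniformly.
def _sketch_resnet_fpn_usage():
    backbone = resnet_fpn_backbone(backbone_name="resnet18", weights=None)
    feats = backbone(torch.rand(1, 3, 300, 300))
    for level, feat in feats.items():
        # Levels "0".."3" and "pool" all carry backbone.out_channels channels
        assert feat.shape[1] == backbone.out_channels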
- @pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
- def test_mobilenet_backbone(backbone_name):
- with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
- mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
- with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
- mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
- with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
- mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])
- model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
- assert isinstance(model_fpn, BackboneWithFPN)
- model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
- assert isinstance(model, torch.nn.Sequential)
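

# Illustrative sketch, not part of the test suite (the helper name is ours).
# With fpn=False the builder returns a plain nn.Sequential that produces a
# single feature map, rather than the OrderedDict of pyramid levels that a
# BackboneWithFPN returns.
def _sketch_mobilenet_backbone_usage():
    backbone = mobilenet_backbone(backbone_name="mobilenet_v2", weights=None, fpn=False)
    feat = backbone(torch.rand(1, 3, 224, 224))
    assert isinstance(feat, torch.Tensor)  # one map, not a dict of levels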


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


# Needed by TestFxFeatureExtraction. Checks that node naming conventions are
# respected, particularly the index postfix of repeated node names.
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]
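

# Illustrative sketch, not part of the test suite (the helper name is ours).
# It demonstrates the naming conventions encoded in test_module_nodes: a
# repeated op gains an "_1" postfix, and ops inside a submodule carry its
# qualified path as a prefix.
def _sketch_node_name_postfix():
    train_nodes, _ = get_graph_node_names(TestModule())
    assert "add" in train_nodes and "add_1" in train_nodes  # repeated top-level adds
    assert "submodule.relu_1" in train_nodes  # qualified path + postfix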


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Apply the class-level leaf_modules unless the caller passes its own
        tracer_kwargs, and always suppress the train/eval graph-diff warning.
        """
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors, as they cause issues
        # when testing the backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
- @pytest.mark.parametrize("model_name", models.list_models(models))
- def test_build_fx_feature_extractor(self, model_name):
- set_rng_seed(0)
- model = models.get_model(model_name, **self.model_defaults).eval()
- train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
- # Check that it works with both a list and dict for return nodes
- self._create_feature_extractor(
- model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
- )
- self._create_feature_extractor(
- model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
- )
- # Check must specify return nodes
- with pytest.raises(ValueError):
- self._create_feature_extractor(model)
- # Check return_nodes and train_return_nodes / eval_return nodes
- # mutual exclusivity
- with pytest.raises(ValueError):
- self._create_feature_extractor(
- model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
- )
- # Check train_return_nodes / eval_return nodes must both be specified
- with pytest.raises(ValueError):
- self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
- # Check invalid node name raises ValueError
- with pytest.raises(ValueError):
- # First just double check that this node really doesn't exist
- if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
- self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
- else: # otherwise skip this check
- raise ValueError
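
    # Illustrative sketch, not a collected test (the method name is ours). It
    # shows the dict form of return_nodes, which maps graph node names to
    # user-chosen output keys so callers can rename the outputs.
    def _sketch_renamed_return_nodes(self):
        model = models.resnet18(**self.model_defaults)
        extractor = self._create_feature_extractor(
            model, return_nodes={"layer1": "feat1", "layer4": "feat4"}
        )
        out = extractor(self.inp)
        assert set(out.keys()) == {"feat1", "feat4"}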

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
- @pytest.mark.parametrize("model_name", models.list_models(models))
- def test_forward_backward(self, model_name):
- model = models.get_model(model_name, **self.model_defaults).train()
- train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
- model = self._create_feature_extractor(
- model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
- )
- out = model(self.inp)
- out_agg = 0
- for node_out in out.values():
- if isinstance(node_out, Sequence):
- out_agg += sum(o.float().mean() for o in node_out if o is not None)
- elif isinstance(node_out, Mapping):
- out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
- else:
- # Assume that the only other alternative at this point is a Tensor
- out_agg += node_out.float().mean()
- out_agg.backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have the same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that the outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
            assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
            for k in ilg_out.keys():
                assert ilg_out[k].equal(fgn_out[k])
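
    # Illustrative sketch, not a collected test (the method name is ours).
    # IntermediateLayerGetter is the older, module-level utility: it can only
    # return the outputs of the model's direct child modules, whereas
    # create_feature_extractor can target any node of the traced graph.
    def _sketch_intermediate_layer_getter(self):
        model = models.resnet18(**self.model_defaults)
        getter = IntermediateLayerGetter(model, {"layer2": "mid", "layer4": "deep"})
        out = getter(self.inp)
        assert list(out.keys()) == ["mid", "deep"]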
- @pytest.mark.parametrize("model_name", models.list_models(models))
- def test_jit_forward_backward(self, model_name):
- set_rng_seed(0)
- model = models.get_model(model_name, **self.model_defaults).train()
- train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
- model = self._create_feature_extractor(
- model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
- )
- model = torch.jit.script(model)
- fgn_out = model(self.inp)
- out_agg = 0
- for node_out in fgn_out.values():
- if isinstance(node_out, Sequence):
- out_agg += sum(o.float().mean() for o in node_out if o is not None)
- elif isinstance(node_out, Mapping):
- out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
- else:
- # Assume that the only other alternative at this point is a Tensor
- out_agg += node_out.float().mean()
- out_agg.backward()
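
    # Illustrative sketch, not a collected test (the method name is ours).
    # Extractors built by create_feature_extractor are fx.GraphModules and can
    # be scripted with TorchScript like any other nn.Module.
    def _sketch_scripted_extractor(self):
        model = models.resnet18(**self.model_defaults)
        extractor = self._create_feature_extractor(model, return_nodes=["layer4"])
        scripted = torch.jit.script(extractor)
        out = scripted(self.inp)
        assert "layer4" in out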

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule is kept as a single node: its internal relu
        # does not appear in the graph
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()