123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478 |
- import warnings
- from collections import namedtuple
- from functools import partial
- from typing import Any, Callable, List, Optional, Tuple
- import torch
- import torch.nn.functional as F
- from torch import nn, Tensor
- from ..transforms._presets import ImageClassification
- from ..utils import _log_api_usage_once
- from ._api import register_model, Weights, WeightsEnum
- from ._meta import _IMAGENET_CATEGORIES
- from ._utils import _ovewrite_named_param, handle_legacy_interface
__all__ = [
    "Inception3",
    "InceptionOutputs",
    "_InceptionOutputs",
    "Inception_V3_Weights",
    "inception_v3",
]
# Result type produced by ``Inception3.forward`` when auxiliary logits are
# reported: the main classifier logits plus the (possibly absent) auxiliary ones.
InceptionOutputs = namedtuple("InceptionOutputs", "logits aux_logits")
# Explicit field types; ``aux_logits`` is Optional because the auxiliary head is
# only evaluated in training mode and may not exist at all. Presumably these
# annotations are what lets TorchScript type the tuple.
InceptionOutputs.__annotations__ = dict(logits=Tensor, aux_logits=Optional[Tensor])

# Private alias retained for backwards compatibility with older imports.
_InceptionOutputs = InceptionOutputs
class Inception3(nn.Module):
    """Inception v3 network ("Rethinking the Inception Architecture for Computer Vision").

    Args:
        num_classes: Output size of the final ``fc`` layer (and of the auxiliary
            classifier, when it is built).
        aux_logits: If True, build the auxiliary classifier ``AuxLogits`` on the
            ``Mixed_6e`` feature map. It is only evaluated in training mode.
        transform_input: If True, ``forward`` first re-normalizes the input with
            ``_transform_input``.
        inception_blocks: Optional list of exactly 7 module factories, in the order
            ``[conv_block, inception_a, inception_b, inception_c, inception_d,
            inception_e, inception_aux]``. Defaults to ``BasicConv2d``,
            ``InceptionA`` ... ``InceptionE`` and ``InceptionAux``.
        init_weights: If True, re-initialize every ``nn.Conv2d`` / ``nn.Linear``
            weight with a truncated normal and every ``nn.BatchNorm2d`` to
            weight=1, bias=0. ``None`` behaves like True but emits a ``FutureWarning``.
        dropout: Dropout probability applied to the pooled features before ``fc``.

    Raises:
        ValueError: If ``inception_blocks`` does not contain exactly 7 entries.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
                # Point the warning at whoever constructed the model, not at this line.
                stacklevel=2,
            )
            init_weights = True
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # NOTE: submodules are created in a fixed order; that order determines both
        # the state_dict key order and the RNG draws of default initialization.
        # Stem.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        # Inception stages.
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        # Declared Optional so scripting accepts the None case.
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A module may carry its own ``stddev`` attribute (InceptionAux sets
                    # one on its ``fc``); everything else uses 0.1.
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    # NOTE(review): ``a``/``b`` of ``trunc_normal_`` are absolute cutoffs, not
                    # multiples of ``std``; with std <= 0.1 a cutoff of +-2 truncates almost
                    # nothing. Confirm whether +-2*stddev was intended before changing it
                    # (doing so would alter initialization numerics).
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalize the first three channels of ``x`` (channel dim 1).

        Each channel ``c`` goes through ``x * (std_c / 0.5) + (mean_c - 0.5) / 0.5``
        with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225). This undoes
        a ``(v - mean) / std`` normalization and re-applies ``(v - 0.5) / 0.5``.
        The input is returned unchanged when ``transform_input`` is False.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run backbone and classifier; return ``(logits, aux)``.

        ``aux`` is None unless ``AuxLogits`` exists and the module is training.
        Shape comments below assume an N x 3 x 299 x 299 input.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode result: ``InceptionOutputs`` when training with aux logits, else the logits tensor."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Classify ``x``.

        When scripted, always returns ``InceptionOutputs(logits, aux)`` and warns if
        ``aux`` is not meaningful. In eager mode, returns ``InceptionOutputs`` only
        when training with ``aux_logits``; otherwise returns the logits tensor alone.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
class InceptionA(nn.Module):
    """Four-branch Inception block used for ``Mixed_5b`` .. ``Mixed_5d``.

    Branches: 1x1; 1x1 -> 5x5; 1x1 -> 3x3 -> 3x3; 3x3 avg-pool -> 1x1. Their
    outputs are concatenated along dim 1, giving ``64 + 64 + 96 + pool_features``
    channels. Every branch is padded so height/width are preserved (assuming the
    conv factory uses stride 1 by default, as ``nn.Conv2d`` does).

    Args:
        in_channels: Channels of the incoming feature map.
        pool_features: Output channels of the pooled branch.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        # Creation order is significant (state_dict order, RNG draws), so it is kept fixed.
        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, in concatenation order."""
        single = self.branch1x1(x)
        wide = self.branch5x5_2(self.branch5x5_1(x))
        stacked = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [single, wide, stacked, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)
class InceptionB(nn.Module):
    """Stride-2 reduction block used for ``Mixed_6a``.

    Branches: stride-2 3x3; 1x1 -> 3x3 -> stride-2 3x3; stride-2 3x3 max-pool.
    Output channels are ``384 + 96 + in_channels``; no branch is padded at its
    stride-2 step, so spatial size becomes ``(H - 3) // 2 + 1``.

    Args:
        in_channels: Channels of the incoming feature map.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs, in concatenation order."""
        direct = self.branch3x3(x)
        tower = self.branch3x3dbl_1(x)
        tower = self.branch3x3dbl_2(tower)
        tower = self.branch3x3dbl_3(tower)
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [direct, tower, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)
class InceptionC(nn.Module):
    """Factorized-7x7 Inception block used for ``Mixed_6b`` .. ``Mixed_6e``.

    Branches: 1x1; 1x1 -> 1x7 -> 7x1; 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7;
    3x3 avg-pool -> 1x1. Each branch ends with 192 channels, so the output has
    768 channels. Padding keeps height/width unchanged (assuming stride 1 by default).

    Args:
        in_channels: Channels of the incoming feature map.
        channels_7x7: Width of the intermediate layers of both 7x7 branches.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        mid = channels_7x7
        row = dict(kernel_size=(1, 7), padding=(0, 3))
        col = dict(kernel_size=(7, 1), padding=(3, 0))

        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

        self.branch7x7_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make_conv(mid, mid, **row)
        self.branch7x7_3 = make_conv(mid, 192, **col)

        self.branch7x7dbl_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(mid, mid, **col)
        self.branch7x7dbl_3 = make_conv(mid, mid, **row)
        self.branch7x7dbl_4 = make_conv(mid, mid, **col)
        self.branch7x7dbl_5 = make_conv(mid, 192, **row)

        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, in concatenation order."""
        single = self.branch1x1(x)

        short = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        deep = self.branch7x7dbl_1(x)
        deep = self.branch7x7dbl_2(deep)
        deep = self.branch7x7dbl_3(deep)
        deep = self.branch7x7dbl_4(deep)
        deep = self.branch7x7dbl_5(deep)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [single, short, deep, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)
class InceptionD(nn.Module):
    """Stride-2 reduction block used for ``Mixed_7a``.

    Branches: 1x1 -> stride-2 3x3; 1x1 -> 1x7 -> 7x1 -> stride-2 3x3;
    stride-2 3x3 max-pool. Output channels are ``320 + 192 + in_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs, in concatenation order."""
        reduced = self.branch3x3_2(self.branch3x3_1(x))

        tower = self.branch7x7x3_1(x)
        tower = self.branch7x7x3_2(tower)
        tower = self.branch7x7x3_3(tower)
        tower = self.branch7x7x3_4(tower)

        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [reduced, tower, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)
class InceptionE(nn.Module):
    """Expanded-filter-bank Inception block used for ``Mixed_7b`` and ``Mixed_7c``.

    Branches: 1x1 (320); 1x1 -> {1x3, 3x1} (2 x 384); 1x1 -> 3x3 -> {1x3, 3x1}
    (2 x 384); 3x3 avg-pool -> 1x1 (192). Output has 2048 channels; padding keeps
    height/width unchanged (assuming stride 1 by default).

    Args:
        in_channels: Channels of the incoming feature map.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, in concatenation order."""
        single = self.branch1x1(x)

        # Each split branch fans one shared stem out into 1x3 and 3x1 convolutions.
        stem = self.branch3x3_1(x)
        split = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        deep_stem = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        deep_split = torch.cat([self.branch3x3dbl_3a(deep_stem), self.branch3x3dbl_3b(deep_stem)], 1)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [single, split, deep_split, pooled]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head (built as ``Inception3.AuxLogits`` on the 768-channel ``Mixed_6e`` map).

    Pipeline: 5x5/stride-3 avg-pool -> 1x1 conv (128) -> 5x5 conv (768) ->
    global average pool -> flatten -> linear(768, num_classes).

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Output size of the final linear layer.
        conv_block: Factory ``(in, out, **conv_kwargs) -> nn.Module``; defaults to
            ``BasicConv2d``.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make_conv(in_channels, 128, kernel_size=1)
        self.conv1 = make_conv(128, 768, kernel_size=5)
        # Per-module init std read by Inception3's weight-initialization loop.
        # NOTE(review): that loop only reads ``stddev`` from nn.Conv2d / nn.Linear
        # instances. With the default BasicConv2d factory this attribute sits on the
        # wrapper, not on its inner ``conv``, so it appears to have no effect there.
        # Confirm whether that is intended.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # With a 17 x 17 input: 17 -> 5 after pooling, 5 -> 1 after the 5x5 conv.
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
        features = self.conv1(self.conv0(pooled))
        features = F.adaptive_avg_pool2d(features, (1, 1))
        return self.fc(torch.flatten(features, 1))
class BasicConv2d(nn.Module):
    """Conv -> BatchNorm -> ReLU unit used throughout Inception v3.

    The convolution has no bias; any extra keyword arguments (``kernel_size``,
    ``stride``, ``padding``, ...) are forwarded unchanged to ``nn.Conv2d``. The
    BatchNorm uses ``eps=0.001`` and the ReLU is applied in place.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # A per-channel conv bias would be redundant with the BatchNorm shift that follows.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
class Inception_V3_Weights(WeightsEnum):
    """Pretrained weight sets available for :func:`inception_v3`."""

    # ImageNet-1K checkpoint. Inference preprocessing is ``ImageClassification``
    # with ``resize_size=342`` and ``crop_size=299``.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
        meta={
            "num_params": 27161264,
            "min_size": (75, 75),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
            "_metrics": {"ImageNet-1K": {"acc@1": 77.294, "acc@5": 93.45}},
            "_ops": 5.713,
            "_file_size": 103.903,
            "_docs": "These weights are ported from the original paper.",
        },
    )
    # ``DEFAULT`` is an alias of ``IMAGENET1K_V1``.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
    """
    Inception v3 model architecture from
    `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.Inception_V3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.Inception3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
            for more details about this class.

    .. note::
        When ``weights`` is given, ``transform_input`` defaults to ``True`` and the
        network is always built with its auxiliary classifier so that the checkpoint
        loads. If ``aux_logits=False`` was passed, the auxiliary classifier is removed
        after the weights are loaded.

    .. autoclass:: torchvision.models.Inception_V3_Weights
        :members:
    """
    weights = Inception_V3_Weights.verify(weights)

    # Remember what the caller asked for: the checkpoint contains the auxiliary
    # head, so it must exist while loading and may be dropped afterwards.
    original_aux_logits = kwargs.get("aux_logits", True)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        # Assign directly instead of via ``_ovewrite_named_param``, which rejects a
        # conflicting caller-supplied value. That made ``aux_logits=False`` unusable
        # with weights and left the head-removal branch below unreachable.
        kwargs["aux_logits"] = True
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = Inception3(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
|