12345678910111213141516171819202122232425262728293031323334353637 |
- from collections import OrderedDict
- from typing import Dict, Optional
- from torch import nn, Tensor
- from torch.nn import functional as F
- from ...utils import _log_api_usage_once
class _SimpleSegmentationModel(nn.Module):
    """Backbone plus one or two prediction heads for dense (per-pixel) output.

    The backbone runs once on the input. Its ``"out"`` entry goes through
    ``classifier``. When ``aux_classifier`` is set, its ``"aux"`` entry goes
    through ``aux_classifier``. Each head's result is resized with bilinear
    interpolation to the size of the input's last two dimensions.

    Args:
        backbone: module whose forward returns a dict-like mapping of tensors
            containing ``"out"`` (and ``"aux"`` when an auxiliary head is set).
        classifier: head applied to the backbone's ``"out"`` features.
        aux_classifier: optional head applied to the backbone's ``"aux"``
            features; ``None`` disables the auxiliary output.
    """

    # Declares aux_classifier a TorchScript constant, presumably so the
    # `is not None` branch in forward can be resolved when the model is scripted.
    __constants__ = ["aux_classifier"]

    def __init__(
        self,
        backbone: nn.Module,
        classifier: nn.Module,
        aux_classifier: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Project helper imported from ...utils; presumably records API usage.
        _log_api_usage_once(self)
        # Assignment order determines submodule registration order
        # (and therefore state_dict / named_children ordering).
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and heads, returning predictions at input size.

        Args:
            x: input tensor. Its last two dimensions are used as the output
                size for interpolation (assumes an N, C, H, W layout; TODO confirm).

        Returns:
            An ``OrderedDict`` with key ``"out"`` and, when an auxiliary head
            is configured, key ``"aux"``.
        """
        # Capture the target size before the backbone (possibly) downsamples.
        target_hw = x.shape[-2:]
        # contract: the backbone returns a dict of tensors
        feats = self.backbone(x)

        outputs = OrderedDict()
        main_map = self.classifier(feats["out"])
        outputs["out"] = F.interpolate(main_map, size=target_hw, mode="bilinear", align_corners=False)

        if self.aux_classifier is not None:
            aux_map = self.aux_classifier(feats["aux"])
            outputs["aux"] = F.interpolate(aux_map, size=target_hw, mode="bilinear", align_corners=False)

        return outputs
|