_utils.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
from collections import OrderedDict
from typing import Dict, Optional

from torch import nn, Tensor
from torch.nn import functional as F

from ...utils import _log_api_usage_once


  6. class _SimpleSegmentationModel(nn.Module):
  7. __constants__ = ["aux_classifier"]
  8. def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
  9. super().__init__()
  10. _log_api_usage_once(self)
  11. self.backbone = backbone
  12. self.classifier = classifier
  13. self.aux_classifier = aux_classifier
  14. def forward(self, x: Tensor) -> Dict[str, Tensor]:
  15. input_shape = x.shape[-2:]
  16. # contract: features is a dict of tensors
  17. features = self.backbone(x)
  18. result = OrderedDict()
  19. x = features["out"]
  20. x = self.classifier(x)
  21. x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
  22. result["out"] = x
  23. if self.aux_classifier is not None:
  24. x = features["aux"]
  25. x = self.aux_classifier(x)
  26. x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
  27. result["aux"] = x
  28. return result