- """
- This file is part of the private API. Please do not use directly these classes as they will be modified on
- future versions without warning. The classes should be accessed only via the transforms argument of Weights.
- """
from typing import Optional, Tuple, Union

import torch
from torch import nn, Tensor

from . import functional as F, InterpolationMode


__all__ = [
    "ObjectDetection",
    "ImageClassification",
    "VideoClassification",
    "SemanticSegmentation",
    "OpticalFlow",
]
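
# These presets are not meant to be constructed by hand in user code; they are
# what the weight enums hand back via ``transforms()``. A usage sketch (the
# weight enum below is just one example):
#
#     from torchvision.models import ResNet50_Weights
#     preset = ResNet50_Weights.IMAGENET1K_V2.transforms()  # an ImageClassification instance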

class ObjectDetection(nn.Module):
    def forward(self, img: Tensor) -> Tensor:
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float)

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )
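
# Usage sketch for ObjectDetection (the sample tensor below is made up for
# illustration): uint8 input comes out as float32 rescaled to [0.0, 1.0].
#
#     preset = ObjectDetection()
#     img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
#     out = preset(img)  # float32, same shape, values in [0.0, 1.0]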

class ImageClassification(nn.Module):
    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.crop_size = [crop_size]
        self.resize_size = [resize_size]
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        img = F.center_crop(img, self.crop_size)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    crop_size={self.crop_size}"
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
        )
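
# Usage sketch for ImageClassification (the crop/antialias values below are
# illustrative, not tied to any particular weight set): the shorter side is
# resized to ``resize_size``, then a central ``crop_size`` patch is cut out,
# rescaled and normalized.
#
#     preset = ImageClassification(crop_size=224, antialias=True)
#     img = torch.randint(0, 256, (3, 500, 375), dtype=torch.uint8)
#     out = preset(img)  # float32 of shape (3, 224, 224), normalized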

class VideoClassification(nn.Module):
    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        need_squeeze = False
        if vid.ndim < 5:
            vid = vid.unsqueeze(dim=0)
            need_squeeze = True

        N, T, C, H, W = vid.shape
        vid = vid.view(-1, C, H, W)
        # We hard-code antialias=False to preserve results after we changed
        # its default from None to True (see
        # https://github.com/pytorch/vision/pull/7160)
        # TODO: we could re-train the video models with antialias=True?
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation, antialias=False)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)
        H, W = self.crop_size
        vid = vid.view(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    crop_size={self.crop_size}"
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string
    def describe(self) -> str:
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. The values are then rescaled to "
            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally, the output "
            "dimensions are permuted into ``(..., C, T, H, W)`` tensors."
        )
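
# Usage sketch for VideoClassification (the sizes below are illustrative): a
# ``(T, C, H, W)`` clip is resized, cropped, rescaled, normalized, and permuted
# so that channels come before time.
#
#     preset = VideoClassification(crop_size=(112, 112), resize_size=(128, 171))
#     vid = torch.randint(0, 256, (16, 3, 240, 320), dtype=torch.uint8)
#     out = preset(vid)  # float32 of shape (3, 16, 112, 112)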

class SemanticSegmentation(nn.Module):
    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.resize_size = [resize_size] if resize_size is not None else None
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img: Tensor) -> Tensor:
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        img = F.convert_image_dtype(img, torch.float)
        img = F.normalize(img, mean=self.mean, std=self.std)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    resize_size={self.resize_size}"
        format_string += f"\n    mean={self.mean}"
        format_string += f"\n    std={self.std}"
        format_string += f"\n    interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
        )
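
# Usage sketch for SemanticSegmentation (resize_size=520 is illustrative): the
# image is resized so that its shorter side equals ``resize_size``, then
# rescaled and normalized.
#
#     preset = SemanticSegmentation(resize_size=520, antialias=True)
#     img = torch.randint(0, 256, (3, 400, 600), dtype=torch.uint8)
#     out = preset(img)  # float32 of shape (3, 520, 780), normalized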

class OpticalFlow(nn.Module):
    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        img1 = F.convert_image_dtype(img1, torch.float)
        img2 = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[-1.0, 1.0]``."
        )
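
# Usage sketch for OpticalFlow (shapes are illustrative): both frames are
# converted to float and mapped from [0, 1] into [-1, 1].
#
#     preset = OpticalFlow()
#     img1 = torch.randint(0, 256, (3, 240, 320), dtype=torch.uint8)
#     img2 = torch.randint(0, 256, (3, 240, 320), dtype=torch.uint8)
#     img1, img2 = preset(img1, img2)  # float32 pair, values in [-1.0, 1.0]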