123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- import warnings
- from typing import Callable, List, Optional, Sequence, Tuple, Union
- import torch
- from torch import Tensor
- from ..utils import _log_api_usage_once, _make_ntuple
# Re-export: makes ``interpolate`` available as an attribute of this module,
# bound to the very same function object as ``torch.nn.functional.interpolate``.
from torch.nn.functional import interpolate as interpolate
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose batch statistics and affine terms are held fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are stored as
    buffers rather than parameters, and the forward pass only reads them.

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.eps = eps
        # Defaults describe an identity transform: unit scale, zero shift,
        # zero mean and unit variance. Registration order is significant
        # because it determines the key order of ``state_dict()``.
        buffer_factories = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, factory in buffer_factories:
            self.register_buffer(buffer_name, factory(num_features))

    def _load_from_state_dict(
        self,
        state_dict: dict,
        prefix: str,
        local_metadata: dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ):
        # This module registers no ``num_batches_tracked`` buffer, so the entry
        # is removed (if present) before the parent loader gets to see it.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the frozen normalization ``(x - mean) / sqrt(var + eps) * weight + bias``."""
        # Reshape up front so that everything below is plain broadcasting
        # arithmetic, which keeps the expression fuser-friendly.
        gamma = self.weight.reshape(1, -1, 1, 1)
        beta = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalization and the affine transform into one multiply-add.
        scale = gamma * (var + self.eps).rsqrt()
        shift = beta - mean * scale
        return x * scale + shift

    def __repr__(self) -> str:
        num_features = self.weight.shape[0]
        return f"{type(self).__name__}({num_features}, eps={self.eps})"
class ConvNormActivation(torch.nn.Sequential):
    """
    Sequential stack of a convolution, an optional normalization layer and an
    optional activation layer, in that order.

    When ``padding`` is ``None`` it is derived per dimension as
    ``(kernel_size - 1) // 2 * dilation``. When ``bias`` is ``None`` the
    convolution gets a bias only if no normalization layer is used.
    Instantiating this exact class (rather than a subclass) emits a warning.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            scalar_geometry = isinstance(kernel_size, int) and isinstance(dilation, int)
            if scalar_geometry:
                padding = (kernel_size - 1) // 2 * dilation
            else:
                # At least one of the two is a sequence; its length fixes the
                # number of spatial dimensions, and both are expanded to tuples.
                if isinstance(kernel_size, Sequence):
                    ndim = len(kernel_size)
                else:
                    ndim = len(dilation)
                kernel_size = _make_ntuple(kernel_size, ndim)
                dilation = _make_ntuple(dilation, ndim)
                padding = tuple((kernel_size[axis] - 1) // 2 * dilation[axis] for axis in range(ndim))

        if bias is None:
            # No normalization layer -> let the convolution carry the bias.
            bias = norm_layer is None

        stack = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            stack.append(norm_layer(out_channels))

        if activation_layer is not None:
            # ``inplace=None`` means "do not pass the argument at all".
            activation_kwargs = {"inplace": inplace} if inplace is not None else {}
            stack.append(activation_layer(**activation_kwargs))

        super().__init__(*stack)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if type(self) is ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )
class Conv2dNormActivation(ConvNormActivation):
    """
    Convolution2d-Normalization-Activation block with configurable layers.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 3
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None,
            in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
            convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
            on top of the normalization layer (if not None), otherwise on top of the conv layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int or tuple): Spacing between kernel elements. Default: 1
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
            in-place. If ``None`` the argument is not passed to the activation layer. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
            if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        # Everything is forwarded unchanged; only the convolution type is pinned to 2D.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            dilation=dilation,
            inplace=inplace,
            bias=bias,
            conv_layer=torch.nn.Conv2d,
        )
class Conv3dNormActivation(ConvNormActivation):
    """
    Convolution3d-Normalization-Activation block with configurable layers.

    Args:
        in_channels (int): Number of channels in the input video.
        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 3
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int, tuple or str, optional): Padding added to all sides of the input. Default: None,
            in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
            convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
            on top of the normalization layer (if not None), otherwise on top of the conv layer.
            If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        dilation (int or tuple): Spacing between kernel elements. Default: 1
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
            in-place. If ``None`` the argument is not passed to the activation layer. Default ``True``
        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
            if ``norm_layer is None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
    ) -> None:

        # Everything is forwarded unchanged; only the convolution type is pinned to 3D.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            dilation=dilation,
            inplace=inplace,
            bias=bias,
            conv_layer=torch.nn.Conv3d,
        )
class SqueezeExcitation(torch.nn.Module):
    """
    Squeeze-and-Excitation block, as described in https://arxiv.org/abs/1709.01507 (see Fig. 1).
    ``activation`` and ``scale_activation`` play the roles of ``delta`` and ``sigma`` in eq. 3.

    The input is pooled to ``1 x 1``, passed through two ``1 x 1`` convolutions
    (with ``activation`` in between and ``scale_activation`` at the end), and the
    result is multiplied back onto the input.

    Args:
        input_channels (int): Number of channels in the input image
        squeeze_channels (int): Number of squeeze channels
        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
    """

    def __init__(
        self,
        input_channels: int,
        squeeze_channels: int,
        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # Squeeze: collapse each channel's spatial extent to a single value.
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        # Excitation: channel bottleneck built from two 1x1 convolutions.
        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, kernel_size=1)
        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, kernel_size=1)
        self.activation = activation()
        self.scale_activation = scale_activation()

    def _scale(self, input: Tensor) -> Tensor:
        """Compute the per-channel scaling factors for ``input``."""
        pooled = self.avgpool(input)
        squeezed = self.activation(self.fc1(pooled))
        return self.scale_activation(self.fc2(squeezed))

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` multiplied by its per-channel scaling factors."""
        return self._scale(input) * input
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Every entry of ``hidden_channels`` except the last produces
    ``Linear -> [norm] -> [activation] -> Dropout``; the last entry produces
    ``Linear -> Dropout`` only, so it acts as the output dimension.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions. Must contain at least one entry.
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place.
            Default is ``None``, which uses the respective default values of the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        # The same keyword set is handed to both the activation and the Dropout
        # layers; an empty dict leaves each layer on its own default.
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            # ``activation_layer`` is Optional and documented as skippable, so
            # it must not be called when it is ``None``.
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        # Output projection: no normalization or activation after the last linear layer.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        _log_api_usage_once(self)
class Permute(torch.nn.Module):
    """Module wrapper around tensor permutation: returns a view of the input
    tensor with its dimensions reordered.

    Args:
        dims (List[int]): The desired ordering of dimensions
    """

    def __init__(self, dims: List[int]):
        super().__init__()
        # Kept as given and reused on every forward call.
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with its dimensions arranged according to ``self.dims``."""
        return x.permute(self.dims)
|