# deform_conv.py

import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once


def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None``, and
    Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__ if :attr:`mask` is ``None``.

    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
            offsets to be applied for each position in the convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,
            split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
            masks to be applied for each position in the convolution kernel. Default: None

    Returns:
        Tensor[batch_size, out_channels, out_height, out_width]: result of convolution

    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        torch.Size([4, 5, 8, 8])
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(deform_conv2d)
    _assert_has_ops()
    out_channels = weight.shape[0]

    use_mask = mask is not None

    if mask is None:
        mask = torch.zeros((input.shape[0], 1), device=input.device, dtype=input.dtype)

    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, _, _ = input.shape

    n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
    n_weight_grps = n_in_channels // weight.shape[1]

    if n_offset_grps == 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={2 * weights_h * weights_w}"
        )

    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )
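

# A minimal sanity-check sketch, not part of the torchvision API: with all-zero
# offsets and no mask, deform_conv2d samples the regular integer grid, so it
# should reproduce a plain convolution (bilinear interpolation at integer
# points is exact). The helper name below is illustrative only.
def _zero_offset_matches_plain_conv() -> bool:
    import torch.nn.functional as F

    input = torch.rand(2, 3, 10, 10)
    weight = torch.rand(5, 3, 3, 3)
    # One (dy, dx) pair per kernel position, all zero -> undeformed kernel.
    offset = torch.zeros(2, 2 * 3 * 3, 8, 8)
    expected = F.conv2d(input, weight)
    actual = deform_conv2d(input, offset, weight)
    return torch.allclose(actual, expected, atol=1e-5)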


class DeformConv2d(nn.Module):
    """
    See :func:`deform_conv2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight = Parameter(
            torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1])
        )

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
                offsets to be applied for each position in the convolution kernel.
            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
                masks to be applied for each position in the convolution kernel.
        """
        return deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"{self.in_channels}"
            f", {self.out_channels}"
            f", kernel_size={self.kernel_size}"
            f", stride={self.stride}"
        )
        s += f", padding={self.padding}" if self.padding != (0, 0) else ""
        s += f", dilation={self.dilation}" if self.dilation != (1, 1) else ""
        s += f", groups={self.groups}" if self.groups != 1 else ""
        s += ", bias=False" if self.bias is None else ""
        s += ")"
        return s
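

# A usage sketch: in practice, offsets (and the optional modulation mask) are
# usually predicted by small side convolutions. The layer names below are
# illustrative assumptions, not part of this module, and this block only runs
# in package context (the relative import above prevents running the file
# standalone).
if __name__ == "__main__":
    kh, kw = 3, 3
    x = torch.rand(4, 3, 10, 10)
    deform = DeformConv2d(in_channels=3, out_channels=5, kernel_size=kh)
    # Predict one (dy, dx) pair per kernel position and output location; the
    # predictor must use the same stride/padding/dilation as the main layer so
    # its output matches the convolution's output spatial size (8x8 here).
    offset_predictor = nn.Conv2d(3, 2 * kh * kw, kernel_size=kh)
    out = deform(x, offset_predictor(x))
    print(out.shape)  # torch.Size([4, 5, 8, 8])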