raft.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947
  1. from typing import List, Optional
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch import Tensor
  6. from torch.nn.modules.batchnorm import BatchNorm2d
  7. from torch.nn.modules.instancenorm import InstanceNorm2d
  8. from torchvision.ops import Conv2dNormActivation
  9. from ...transforms._presets import OpticalFlow
  10. from ...utils import _log_api_usage_once
  11. from .._api import register_model, Weights, WeightsEnum
  12. from .._utils import handle_legacy_interface
  13. from ._utils import grid_sample, make_coords_grid, upsample_flow
# Public API of this module.
__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)
  21. class ResidualBlock(nn.Module):
  22. """Slightly modified Residual block with extra relu and biases."""
  23. def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
  24. super().__init__()
  25. # Note regarding bias=True:
  26. # Usually we can pass bias=False in conv layers followed by a norm layer.
  27. # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset,
  28. # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful
  29. # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm
  30. # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights).
  31. self.convnormrelu1 = Conv2dNormActivation(
  32. in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
  33. )
  34. self.convnormrelu2 = Conv2dNormActivation(
  35. out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
  36. )
  37. # make mypy happy
  38. self.downsample: nn.Module
  39. if stride == 1 and not always_project:
  40. self.downsample = nn.Identity()
  41. else:
  42. self.downsample = Conv2dNormActivation(
  43. in_channels,
  44. out_channels,
  45. norm_layer=norm_layer,
  46. kernel_size=1,
  47. stride=stride,
  48. bias=True,
  49. activation_layer=None,
  50. )
  51. self.relu = nn.ReLU(inplace=True)
  52. def forward(self, x):
  53. y = x
  54. y = self.convnormrelu1(y)
  55. y = self.convnormrelu2(y)
  56. x = self.downsample(x)
  57. return self.relu(x + y)
  58. class BottleneckBlock(nn.Module):
  59. """Slightly modified BottleNeck block (extra relu and biases)"""
  60. def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
  61. super().__init__()
  62. # See note in ResidualBlock for the reason behind bias=True
  63. self.convnormrelu1 = Conv2dNormActivation(
  64. in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
  65. )
  66. self.convnormrelu2 = Conv2dNormActivation(
  67. out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
  68. )
  69. self.convnormrelu3 = Conv2dNormActivation(
  70. out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
  71. )
  72. self.relu = nn.ReLU(inplace=True)
  73. if stride == 1:
  74. self.downsample = nn.Identity()
  75. else:
  76. self.downsample = Conv2dNormActivation(
  77. in_channels,
  78. out_channels,
  79. norm_layer=norm_layer,
  80. kernel_size=1,
  81. stride=stride,
  82. bias=True,
  83. activation_layer=None,
  84. )
  85. def forward(self, x):
  86. y = x
  87. y = self.convnormrelu1(y)
  88. y = self.convnormrelu2(y)
  89. y = self.convnormrelu3(y)
  90. x = self.downsample(x)
  91. return self.relu(x + y)
  92. class FeatureEncoder(nn.Module):
  93. """The feature encoder, used both as the actual feature encoder, and as the context encoder.
  94. It must downsample its input by 8.
  95. """
  96. def __init__(
  97. self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
  98. ):
  99. super().__init__()
  100. if len(layers) != 5:
  101. raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}")
  102. # See note in ResidualBlock for the reason behind bias=True
  103. self.convnormrelu = Conv2dNormActivation(
  104. 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True
  105. )
  106. self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1])
  107. self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2])
  108. self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3])
  109. self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)
  110. for m in self.modules():
  111. if isinstance(m, nn.Conv2d):
  112. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  113. elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
  114. if m.weight is not None:
  115. nn.init.constant_(m.weight, 1)
  116. if m.bias is not None:
  117. nn.init.constant_(m.bias, 0)
  118. num_downsamples = len(list(filter(lambda s: s == 2, strides)))
  119. self.output_dim = layers[-1]
  120. self.downsample_factor = 2**num_downsamples
  121. def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
  122. block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
  123. block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
  124. return nn.Sequential(block1, block2)
  125. def forward(self, x):
  126. x = self.convnormrelu(x)
  127. x = self.layer1(x)
  128. x = self.layer2(x)
  129. x = self.layer3(x)
  130. x = self.conv(x)
  131. return x
  132. class MotionEncoder(nn.Module):
  133. """The motion encoder, part of the update block.
  134. Takes the current predicted flow and the correlation features as input and returns an encoded version of these.
  135. """
  136. def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128):
  137. super().__init__()
  138. if len(flow_layers) != 2:
  139. raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}")
  140. if len(corr_layers) not in (1, 2):
  141. raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}")
  142. self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
  143. if len(corr_layers) == 2:
  144. self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
  145. else:
  146. self.convcorr2 = nn.Identity()
  147. self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
  148. self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)
  149. # out_channels - 2 because we cat the flow (2 channels) at the end
  150. self.conv = Conv2dNormActivation(
  151. corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
  152. )
  153. self.out_channels = out_channels
  154. def forward(self, flow, corr_features):
  155. corr = self.convcorr1(corr_features)
  156. corr = self.convcorr2(corr)
  157. flow_orig = flow
  158. flow = self.convflow1(flow)
  159. flow = self.convflow2(flow)
  160. corr_flow = torch.cat([corr, flow], dim=1)
  161. corr_flow = self.conv(corr_flow)
  162. return torch.cat([corr_flow, flow_orig], dim=1)
  163. class ConvGRU(nn.Module):
  164. """Convolutional Gru unit."""
  165. def __init__(self, *, input_size, hidden_size, kernel_size, padding):
  166. super().__init__()
  167. self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
  168. self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
  169. self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
  170. def forward(self, h, x):
  171. hx = torch.cat([h, x], dim=1)
  172. z = torch.sigmoid(self.convz(hx))
  173. r = torch.sigmoid(self.convr(hx))
  174. q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
  175. h = (1 - z) * h + z * q
  176. return h
def _pass_through_h(h, _):
    # Declared here for torchscript
    # Identity on the hidden state: used by RecurrentBlock in place of a second
    # ConvGRU when only one kernel size is given.
    return h
  180. class RecurrentBlock(nn.Module):
  181. """Recurrent block, part of the update block.
  182. Takes the current hidden state and the concatenation of (motion encoder output, context) as input.
  183. Returns an updated hidden state.
  184. """
  185. def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))):
  186. super().__init__()
  187. if len(kernel_size) != len(padding):
  188. raise ValueError(
  189. f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}"
  190. )
  191. if len(kernel_size) not in (1, 2):
  192. raise ValueError(f"kernel_size should either 1 or 2, instead got {len(kernel_size)}")
  193. self.convgru1 = ConvGRU(
  194. input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0]
  195. )
  196. if len(kernel_size) == 2:
  197. self.convgru2 = ConvGRU(
  198. input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1]
  199. )
  200. else:
  201. self.convgru2 = _pass_through_h
  202. self.hidden_size = hidden_size
  203. def forward(self, h, x):
  204. h = self.convgru1(h, x)
  205. h = self.convgru2(h, x)
  206. return h
  207. class FlowHead(nn.Module):
  208. """Flow head, part of the update block.
  209. Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow".
  210. """
  211. def __init__(self, *, in_channels, hidden_size):
  212. super().__init__()
  213. self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
  214. self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1)
  215. self.relu = nn.ReLU(inplace=True)
  216. def forward(self, x):
  217. return self.conv2(self.relu(self.conv1(x)))
  218. class UpdateBlock(nn.Module):
  219. """The update block which contains the motion encoder, the recurrent block, and the flow head.
  220. It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block.
  221. """
  222. def __init__(self, *, motion_encoder, recurrent_block, flow_head):
  223. super().__init__()
  224. self.motion_encoder = motion_encoder
  225. self.recurrent_block = recurrent_block
  226. self.flow_head = flow_head
  227. self.hidden_state_size = recurrent_block.hidden_size
  228. def forward(self, hidden_state, context, corr_features, flow):
  229. motion_features = self.motion_encoder(flow, corr_features)
  230. x = torch.cat([context, motion_features], dim=1)
  231. hidden_state = self.recurrent_block(hidden_state, x)
  232. delta_flow = self.flow_head(hidden_state)
  233. return hidden_state, delta_flow
  234. class MaskPredictor(nn.Module):
  235. """Mask predictor to be used when upsampling the predicted flow.
  236. It takes the hidden state of the recurrent unit as input and outputs the mask.
  237. This is not used in the raft-small model.
  238. """
  239. def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
  240. super().__init__()
  241. self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
  242. # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder,
  243. # and we interpolate with all 9 surrounding neighbors. See paper and appendix B.
  244. self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)
  245. # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch.
  246. # See e.g. https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419
  247. # or https://github.com/princeton-vl/RAFT/issues/24.
  248. # It doesn't seem to affect epe significantly and can likely be set to 1.
  249. self.multiplier = multiplier
  250. def forward(self, x):
  251. x = self.convrelu(x)
  252. x = self.conv(x)
  253. return self.multiplier * x
  254. class CorrBlock(nn.Module):
  255. """The correlation block.
  256. Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder,
  257. and then indexes from this pyramid to create correlation features.
  258. The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that
  259. are within a ``radius``, according to the infinity norm (see paper section 3.2).
  260. Note: typo in the paper, it should be infinity norm, not 1-norm.
  261. """
  262. def __init__(self, *, num_levels: int = 4, radius: int = 4):
  263. super().__init__()
  264. self.num_levels = num_levels
  265. self.radius = radius
  266. self.corr_pyramid: List[Tensor] = [torch.tensor(0)] # useless, but torchscript is otherwise confused :')
  267. # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius}
  268. # so it's a square surrounding x', and its sides have a length of 2 * radius + 1
  269. # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo:
  270. # https://github.com/princeton-vl/RAFT/issues/122
  271. self.out_channels = num_levels * (2 * radius + 1) ** 2
  272. def build_pyramid(self, fmap1, fmap2):
  273. """Build the correlation pyramid from two feature maps.
  274. The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2)
  275. The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
  276. to build the correlation pyramid.
  277. """
  278. if fmap1.shape != fmap2.shape:
  279. raise ValueError(
  280. f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
  281. )
  282. # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
  283. # The last corr_volume most have at least 2 values (hence the 2* factor), otherwise grid_sample() would
  284. # produce nans in its output.
  285. min_fmap_size = 2 * (2 ** (self.num_levels - 1))
  286. if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
  287. raise ValueError(
  288. "Feature maps are too small to be down-sampled by the correlation pyramid. "
  289. f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
  290. "Remember that input images to the model are downsampled by 8, so that means their "
  291. f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
  292. )
  293. corr_volume = self._compute_corr_volume(fmap1, fmap2)
  294. batch_size, h, w, num_channels, _, _ = corr_volume.shape # _, _ = h, w
  295. corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w)
  296. self.corr_pyramid = [corr_volume]
  297. for _ in range(self.num_levels - 1):
  298. corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2)
  299. self.corr_pyramid.append(corr_volume)
  300. def index_pyramid(self, centroids_coords):
  301. """Return correlation features by indexing from the pyramid."""
  302. neighborhood_side_len = 2 * self.radius + 1 # see note in __init__ about out_channels
  303. di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
  304. dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
  305. delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
  306. delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)
  307. batch_size, _, h, w = centroids_coords.shape # _ = 2
  308. centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)
  309. indexed_pyramid = []
  310. for corr_volume in self.corr_pyramid:
  311. sampling_coords = centroids_coords + delta # end shape is (batch_size * h * w, side_len, side_len, 2)
  312. indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
  313. batch_size, h, w, -1
  314. )
  315. indexed_pyramid.append(indexed_corr_volume)
  316. centroids_coords = centroids_coords / 2
  317. corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
  318. expected_output_shape = (batch_size, self.out_channels, h, w)
  319. if corr_features.shape != expected_output_shape:
  320. raise ValueError(
  321. f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
  322. )
  323. return corr_features
  324. def _compute_corr_volume(self, fmap1, fmap2):
  325. batch_size, num_channels, h, w = fmap1.shape
  326. fmap1 = fmap1.view(batch_size, num_channels, h * w)
  327. fmap2 = fmap2.view(batch_size, num_channels, h * w)
  328. corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
  329. corr = corr.view(batch_size, h, w, 1, h, w)
  330. return corr / torch.sqrt(torch.tensor(num_channels))
  331. class RAFT(nn.Module):
  332. def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None):
  333. """RAFT model from
  334. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
  335. args:
  336. feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8.
  337. Its input is the concatenation of ``image1`` and ``image2``.
  338. context_encoder (nn.Module): The context encoder. It must downsample the input by 8.
  339. Its input is ``image1``. As in the original implementation, its output will be split into 2 parts:
  340. - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
  341. - one part will be used to initialize the hidden state of the recurrent unit of
  342. the ``update_block``
  343. These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output
  344. of the ``context_encoder`` must be strictly greater than ``hidden_state_size``.
  345. corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the
  346. ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose
  347. 2 methods:
  348. - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the
  349. output of the ``feature_encoder``).
  350. - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns
  351. the correlation features. See paper section 3.2.
  352. It must expose an ``out_channels`` attribute.
  353. update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the
  354. flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation
  355. features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow``
  356. prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute.
  357. mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
  358. The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B.
  359. If ``None`` (default), the flow is upsampled using interpolation.
  360. """
  361. super().__init__()
  362. _log_api_usage_once(self)
  363. self.feature_encoder = feature_encoder
  364. self.context_encoder = context_encoder
  365. self.corr_block = corr_block
  366. self.update_block = update_block
  367. self.mask_predictor = mask_predictor
  368. if not hasattr(self.update_block, "hidden_state_size"):
  369. raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")
  370. def forward(self, image1, image2, num_flow_updates: int = 12):
  371. batch_size, _, h, w = image1.shape
  372. if (h, w) != image2.shape[-2:]:
  373. raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
  374. if not (h % 8 == 0) and (w % 8 == 0):
  375. raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
  376. fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
  377. fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
  378. if fmap1.shape[-2:] != (h // 8, w // 8):
  379. raise ValueError("The feature encoder should downsample H and W by 8")
  380. self.corr_block.build_pyramid(fmap1, fmap2)
  381. context_out = self.context_encoder(image1)
  382. if context_out.shape[-2:] != (h // 8, w // 8):
  383. raise ValueError("The context encoder should downsample H and W by 8")
  384. # As in the original paper, the actual output of the context encoder is split in 2 parts:
  385. # - one part is used to initialize the hidden state of the recurent units of the update block
  386. # - the rest is the "actual" context.
  387. hidden_state_size = self.update_block.hidden_state_size
  388. out_channels_context = context_out.shape[1] - hidden_state_size
  389. if out_channels_context <= 0:
  390. raise ValueError(
  391. f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than hidden_state={hidden_state_size} channels"
  392. )
  393. hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
  394. hidden_state = torch.tanh(hidden_state)
  395. context = F.relu(context)
  396. coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
  397. coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
  398. flow_predictions = []
  399. for _ in range(num_flow_updates):
  400. coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
  401. corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)
  402. flow = coords1 - coords0
  403. hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)
  404. coords1 = coords1 + delta_flow
  405. up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
  406. upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
  407. flow_predictions.append(upsampled_flow)
  408. return flow_predictions
  409. _COMMON_META = {
  410. "min_size": (128, 128),
  411. }
class Raft_Large_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the author of the dataset and used for the
    Kitti leaderboard. It corresponds to the average of pixels whose epe is
    either <3px, or <5% of flow's 2-norm.
    """

    # Checkpoint names encode the training datasets (see each entry's "_docs":
    # C=FlyingChairs, T=FlyingThings3D, S=Sintel, K=Kitti, H=HD1K). Per the "_docs"
    # entries, V1 weights were ported from the original repository and V2 weights
    # were trained with torchvision's reference recipe.

    C_T_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.4411},
                "Sintel-Train-Finalpass": {"epe": 2.7894},
                "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """These weights were ported from the original paper. They
            are trained on :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.3822},
                "Sintel-Train-Finalpass": {"epe": 2.7161},
                "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """These weights were trained from scratch on
            :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_SKHT_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Test-Cleanpass": {"epe": 1.94},
                "Sintel-Test-Finalpass": {"epe": 3.18},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """
                These weights were ported from the original paper. They are
                trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
                Sintel. The Sintel fine-tuning step is a combination of
                :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )

    C_T_SKHT_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Test-Cleanpass": {"epe": 1.819},
                "Sintel-Test-Finalpass": {"epe": 3.067},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and then
                fine-tuned on Sintel. The Sintel fine-tuning step is a
                combination of :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )

    C_T_SKHT_K_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Kitti-Test": {"fl_all": 5.10},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """
                These weights were ported from the original paper. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )

    C_T_SKHT_K_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Kitti-Test": {"fl_all": 5.19},
            },
            "_ops": 211.007,
            "_file_size": 20.129,
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )

    # Default checkpoint returned when callers request ``weights="DEFAULT"``.
    DEFAULT = C_T_SKHT_V2
class Raft_Small_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the author of the dataset and used for the
    Kitti leaderboard. It corresponds to the average of pixels whose epe is
    either <3px, or <5% of flow's 2-norm.
    """

    # C_T = trained on FlyingChairs + FlyingThings3D (see "_docs" below);
    # V1 was ported from the original repository, V2 was trained with
    # torchvision's reference recipe.

    C_T_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT
        url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 990162,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 2.1231},
                "Sintel-Train-Finalpass": {"epe": 3.2790},
                "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
            },
            "_ops": 47.655,
            "_file_size": 3.821,
            "_docs": """These weights were ported from the original paper. They
            are trained on :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_V2 = Weights(
        url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 990162,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.9901},
                "Sintel-Train-Finalpass": {"epe": 3.2831},
                "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
            },
            "_ops": 47.655,
            "_file_size": 3.821,
            "_docs": """These weights were trained from scratch on
            :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    # Default checkpoint returned when callers request ``weights="DEFAULT"``.
    DEFAULT = C_T_V2
  612. def _raft(
  613. *,
  614. weights=None,
  615. progress=False,
  616. # Feature encoder
  617. feature_encoder_layers,
  618. feature_encoder_block,
  619. feature_encoder_norm_layer,
  620. # Context encoder
  621. context_encoder_layers,
  622. context_encoder_block,
  623. context_encoder_norm_layer,
  624. # Correlation block
  625. corr_block_num_levels,
  626. corr_block_radius,
  627. # Motion encoder
  628. motion_encoder_corr_layers,
  629. motion_encoder_flow_layers,
  630. motion_encoder_out_channels,
  631. # Recurrent block
  632. recurrent_block_hidden_state_size,
  633. recurrent_block_kernel_size,
  634. recurrent_block_padding,
  635. # Flow Head
  636. flow_head_hidden_size,
  637. # Mask predictor
  638. use_mask_predictor,
  639. **kwargs,
  640. ):
  641. feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
  642. block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
  643. )
  644. context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder(
  645. block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer
  646. )
  647. corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius)
  648. update_block = kwargs.pop("update_block", None)
  649. if update_block is None:
  650. motion_encoder = MotionEncoder(
  651. in_channels_corr=corr_block.out_channels,
  652. corr_layers=motion_encoder_corr_layers,
  653. flow_layers=motion_encoder_flow_layers,
  654. out_channels=motion_encoder_out_channels,
  655. )
  656. # See comments in forward pass of RAFT class about why we split the output of the context encoder
  657. out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size
  658. recurrent_block = RecurrentBlock(
  659. input_size=motion_encoder.out_channels + out_channels_context,
  660. hidden_size=recurrent_block_hidden_state_size,
  661. kernel_size=recurrent_block_kernel_size,
  662. padding=recurrent_block_padding,
  663. )
  664. flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size)
  665. update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
  666. mask_predictor = kwargs.pop("mask_predictor", None)
  667. if mask_predictor is None and use_mask_predictor:
  668. mask_predictor = MaskPredictor(
  669. in_channels=recurrent_block_hidden_state_size,
  670. hidden_size=256,
  671. multiplier=0.25, # See comment in MaskPredictor about this
  672. )
  673. model = RAFT(
  674. feature_encoder=feature_encoder,
  675. context_encoder=context_encoder,
  676. corr_block=corr_block,
  677. update_block=update_block,
  678. mask_predictor=mask_predictor,
  679. **kwargs, # not really needed, all params should be consumed by now
  680. )
  681. if weights is not None:
  682. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  683. return model
  684. @register_model()
  685. @handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
  686. def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
  687. """RAFT model from
  688. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
  689. Please see the example below for a tutorial on how to use this model.
  690. Args:
  691. weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
  692. pretrained weights to use. See
  693. :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
  694. below for more details, and possible values. By default, no
  695. pre-trained weights are used.
  696. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  697. **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
  698. base class. Please refer to the `source code
  699. <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
  700. for more details about this class.
  701. .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
  702. :members:
  703. """
  704. weights = Raft_Large_Weights.verify(weights)
  705. return _raft(
  706. weights=weights,
  707. progress=progress,
  708. # Feature encoder
  709. feature_encoder_layers=(64, 64, 96, 128, 256),
  710. feature_encoder_block=ResidualBlock,
  711. feature_encoder_norm_layer=InstanceNorm2d,
  712. # Context encoder
  713. context_encoder_layers=(64, 64, 96, 128, 256),
  714. context_encoder_block=ResidualBlock,
  715. context_encoder_norm_layer=BatchNorm2d,
  716. # Correlation block
  717. corr_block_num_levels=4,
  718. corr_block_radius=4,
  719. # Motion encoder
  720. motion_encoder_corr_layers=(256, 192),
  721. motion_encoder_flow_layers=(128, 64),
  722. motion_encoder_out_channels=128,
  723. # Recurrent block
  724. recurrent_block_hidden_state_size=128,
  725. recurrent_block_kernel_size=((1, 5), (5, 1)),
  726. recurrent_block_padding=((0, 2), (2, 0)),
  727. # Flow head
  728. flow_head_hidden_size=256,
  729. # Mask predictor
  730. use_mask_predictor=True,
  731. **kwargs,
  732. )
  733. @register_model()
  734. @handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
  735. def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
  736. """RAFT "small" model from
  737. `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.
  738. Please see the example below for a tutorial on how to use this model.
  739. Args:
  740. weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
  741. pretrained weights to use. See
  742. :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
  743. below for more details, and possible values. By default, no
  744. pre-trained weights are used.
  745. progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
  746. **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
  747. base class. Please refer to the `source code
  748. <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
  749. for more details about this class.
  750. .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
  751. :members:
  752. """
  753. weights = Raft_Small_Weights.verify(weights)
  754. return _raft(
  755. weights=weights,
  756. progress=progress,
  757. # Feature encoder
  758. feature_encoder_layers=(32, 32, 64, 96, 128),
  759. feature_encoder_block=BottleneckBlock,
  760. feature_encoder_norm_layer=InstanceNorm2d,
  761. # Context encoder
  762. context_encoder_layers=(32, 32, 64, 96, 160),
  763. context_encoder_block=BottleneckBlock,
  764. context_encoder_norm_layer=None,
  765. # Correlation block
  766. corr_block_num_levels=4,
  767. corr_block_radius=3,
  768. # Motion encoder
  769. motion_encoder_corr_layers=(96,),
  770. motion_encoder_flow_layers=(64, 32),
  771. motion_encoder_out_channels=82,
  772. # Recurrent block
  773. recurrent_block_hidden_state_size=96,
  774. recurrent_block_kernel_size=(3,),
  775. recurrent_block_padding=(1,),
  776. # Flow head
  777. flow_head_hidden_size=128,
  778. # Mask predictor
  779. use_mask_predictor=False,
  780. **kwargs,
  781. )