123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- from typing import Optional
- import torch
- import torch.nn.functional as F
- from torch import Tensor
def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.

    The last dimension of ``absolute_grid`` holds ``(x, y)`` pixel positions.
    Each axis is rescaled so that pixel ``0`` maps to ``-1`` and pixel
    ``size - 1`` maps to ``+1`` before delegating to ``F.grid_sample``.

    An axis of size 1 is left untouched: rescaling it would divide by zero
    and produce NaN/inf coordinates. On such an axis the absolute coordinate
    ``0`` is already the normalized center of the single pixel.

    Args:
        img: Input tensor; its last two dimensions are read as ``(h, w)``.
        absolute_grid: Sampling locations with a trailing dimension of size 2.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        align_corners: Forwarded unchanged to ``F.grid_sample``.

    Returns:
        The tensor produced by ``F.grid_sample`` on the normalized grid.
    """
    h, w = img.shape[-2:]

    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
    # Guard against w == 1 for the same reason as the h == 1 case below:
    # (w - 1) would be zero and the normalized coordinates would be NaN/inf.
    if w > 1:
        xgrid = 2 * xgrid / (w - 1) - 1
    # Skipping the rescale when h == 1 lets this function be reused in
    # raft-stereo, where the sampled volume has a single row.
    if h > 1:
        ygrid = 2 * ygrid / (h - 1) - 1
    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
    """Build a batched grid of absolute pixel coordinates.

    Returns a float tensor of shape ``(batch_size, 2, h, w)`` where channel 0
    holds the x (column) index and channel 1 holds the y (row) index of each
    position. Every batch entry is an identical copy of the grid.
    """
    target = torch.device(device)
    row_idx = torch.arange(h, device=target)
    col_idx = torch.arange(w, device=target)
    # "ij" indexing yields (rows, cols); stack as (x, y) so x comes first.
    ys, xs = torch.meshgrid(row_idx, col_idx, indexing="ij")
    xy = torch.stack((xs, ys), dim=0).float()
    return xy.unsqueeze(0).repeat(batch_size, 1, 1, 1)
def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8):
    """Upsample a flow field spatially by ``factor`` (default 8).

    Flow values are multiplied by ``factor`` as well, so displacements stay
    expressed in pixels of the upsampled resolution.

    Without ``up_mask`` the flow is resized with bilinear interpolation.
    With ``up_mask``, each output pixel is a convex combination of the 3x3
    neighbourhood of its coarse source pixel, using the softmax of the mask
    as weights. See the RAFT paper, page 8 and appendix B (note that the
    picture in appendix B assumes a downsample factor of 4 instead of 8).
    """
    n, c, h, w = flow.shape
    out_h = h * factor
    out_w = w * factor

    if up_mask is None:
        resized = F.interpolate(flow, size=(out_h, out_w), mode="bilinear", align_corners=True)
        return factor * resized

    # One weight per (3x3 neighbour, sub-pixel row, sub-pixel col, coarse pixel);
    # softmax over the 9 neighbours makes the weights sum to 1 ("convex").
    weights = up_mask.view(n, 1, 9, factor, factor, h, w).softmax(dim=2)

    # Gather the 3x3 neighbourhood of every coarse pixel of the scaled flow.
    neighbours = F.unfold(factor * flow, kernel_size=3, padding=1)
    neighbours = neighbours.view(n, c, 9, 1, 1, h, w)

    blended = (weights * neighbours).sum(dim=2)
    # (n, c, fh, fw, h, w) -> (n, c, h, fh, w, fw) so each coarse pixel
    # expands into a factor x factor tile when flattened.
    tiled = blended.permute(0, 1, 4, 2, 5, 3)
    return tiled.reshape(n, c, out_h, out_w)
|