from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def grid_sample(
    img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None
) -> Tensor:
    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.

    Args:
        img: Input whose last two dimensions are ``(h, w)``.
        absolute_grid: Sampling locations in pixel units. The last dimension
            must have size 2 and is read as ``(x, y)``.
        mode: Forwarded unchanged to ``F.grid_sample``.
        align_corners: Forwarded unchanged to ``F.grid_sample``.

    Returns:
        The result of ``F.grid_sample`` evaluated on the normalized grid.

    Note:
        Each axis is rescaled with ``2 * coord / (size - 1) - 1``, i.e. pixel 0
        maps to -1 and pixel ``size - 1`` maps to +1. An axis of size 1 is left
        untouched, since the rescaling would divide by zero.
    """
    h, w = img.shape[-2:]

    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
    # Skip degenerate axes: (size - 1) == 0 would turn the coordinates into
    # NaN/inf. The h > 1 condition also enables this function to be reused in
    # raft-stereo (single-row inputs); w > 1 mirrors it for single-column ones.
    if w > 1:
        xgrid = 2 * xgrid / (w - 1) - 1
    if h > 1:
        ygrid = 2 * ygrid / (h - 1) - 1
    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)


def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu") -> Tensor:
    """Build a batch of absolute pixel-coordinate grids.

    Args:
        batch_size: Number of copies of the grid along the first dimension.
        h: Grid height.
        w: Grid width.
        device: Device on which the grid is created.

    Returns:
        A float tensor of shape ``(batch_size, 2, h, w)``. Channel 0 holds the
        x (column) index of every location and channel 1 the y (row) index.
    """
    device = torch.device(device)
    # With indexing="ij" meshgrid yields (rows, cols), i.e. (y, x).
    coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
    # Reverse to (x, y) so channel order matches what grid_sample expects.
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch_size, 1, 1, 1)


def upsample_flow(flow: Tensor, up_mask: Optional[Tensor] = None, factor: int = 8) -> Tensor:
    """Upsample flow by the input factor (default 8).

    If up_mask is None we just interpolate.
    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.

    Args:
        flow: Tensor of shape ``(batch_size, num_channels, h, w)``.
        up_mask: Optional unnormalized weights. Must be viewable as
            ``(batch_size, 1, 9, factor, factor, h, w)``, e.g. a tensor with
            ``9 * factor * factor`` channels; the 9 entries weight the 3x3
            neighbourhood of each coarse pixel.
        factor: Spatial upsampling factor. Flow values are multiplied by it as
            well.

    Returns:
        Tensor of shape ``(batch_size, num_channels, h * factor, w * factor)``.
    """
    batch_size, num_channels, h, w = flow.shape
    new_h, new_w = h * factor, w * factor

    if up_mask is None:
        return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)

    up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w)
    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1

    # Gather the (zero-padded) 3x3 neighbourhood of every coarse pixel; the two
    # singleton dims broadcast against the (factor, factor) sub-pixel dims.
    upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w)
    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)

    # (B, C, factor, factor, h, w) -> (B, C, h, factor, w, factor) so that the
    # reshape interleaves the sub-pixel offsets with the coarse positions.
    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w)