_utils.py

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
    h, w = img.shape[-2:]

    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (w - 1) - 1
    # Condition on h > 1 so that this function can also be reused in raft-stereo
    if h > 1:
        ygrid = 2 * ygrid / (h - 1) - 1
    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)

    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
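

# A minimal usage sketch (not part of the original module; the helper name and
# tensor sizes are arbitrary): with an identity absolute grid, bilinear sampling
# with align_corners=True should reproduce the input, since x = 0 maps to -1 and
# x = w - 1 maps to 1 (and likewise for y).
def _example_grid_sample():
    img = torch.rand(2, 4, 8, 16)  # (N, C, H, W)
    ys, xs = torch.meshgrid(torch.arange(8.0), torch.arange(16.0), indexing="ij")
    absolute_grid = torch.stack([xs, ys], dim=-1).expand(2, -1, -1, -1)  # (N, H, W, 2), (x, y) in pixels
    out = grid_sample(img, absolute_grid, align_corners=True)
    assert torch.allclose(out, img, atol=1e-5)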


def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
    """Return a (batch_size, 2, h, w) grid of absolute pixel coordinates, with x in channel 0 and y in channel 1."""
    device = torch.device(device)
    coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch_size, 1, 1, 1)
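

# A minimal sketch of the returned layout (assumed example, not part of the original
# module): channel 0 holds x (column) coordinates and channel 1 holds y (row)
# coordinates, so permuting to (N, h, w, 2) gives an identity grid for grid_sample above.
def _example_coords_grid():
    coords = make_coords_grid(batch_size=1, h=2, w=3)
    assert coords.shape == (1, 2, 2, 3)  # (N, 2, h, w)
    assert torch.equal(coords[0, 0], torch.tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]))  # x
    assert torch.equal(coords[0, 1], torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]))  # y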


def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8):
    """Upsample flow by the input factor (default 8).

    If up_mask is None we just interpolate.
    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.

    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
    """
    batch_size, num_channels, h, w = flow.shape
    new_h, new_w = h * factor, w * factor

    if up_mask is None:
        return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)

    up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w)
    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1

    upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w)
    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)

    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w)
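

# A minimal shape sketch (assumed example, not part of the original module; sizes and
# the helper name are arbitrary): both branches multiply the flow values by `factor`
# as well as upsampling spatially, since a one-pixel displacement at 1/8 resolution
# corresponds to eight pixels at full resolution. The convex branch expects a mask of
# shape (N, 9 * factor**2, h, w), e.g. as predicted by RAFT's mask head.
def _example_upsample_flow():
    flow = torch.rand(2, 2, 4, 5)  # (N, 2, H / 8, W / 8)
    assert upsample_flow(flow).shape == (2, 2, 32, 40)  # bilinear branch

    up_mask = torch.rand(2, 9 * 8 * 8, 4, 5)
    assert upsample_flow(flow, up_mask=up_mask).shape == (2, 2, 32, 40)  # convex-combination branch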