visualization.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import os
  2. from typing import List
  3. import numpy as np
  4. import torch
  5. from torch import Tensor
  6. from torchvision.utils import make_grid
  7. @torch.no_grad()
  8. def make_disparity_image(disparity: Tensor):
  9. # normalize image to [0, 1]
  10. disparity = disparity.detach().cpu()
  11. disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
  12. return disparity
  13. @torch.no_grad()
  14. def make_disparity_image_pairs(disparity: Tensor, image: Tensor):
  15. disparity = make_disparity_image(disparity)
  16. # image is in [-1, 1], bring it to [0, 1]
  17. image = image.detach().cpu()
  18. image = image * 0.5 + 0.5
  19. return disparity, image
  20. @torch.no_grad()
  21. def make_disparity_sequence(disparities: List[Tensor]):
  22. # convert each disparity to [0, 1]
  23. for idx, disparity_batch in enumerate(disparities):
  24. disparities[idx] = torch.stack(list(map(make_disparity_image, disparity_batch)))
  25. # make the list into a batch
  26. disparity_sequences = torch.stack(disparities)
  27. return disparity_sequences
  28. @torch.no_grad()
  29. def make_pair_grid(*inputs, orientation="horizontal"):
  30. # make a grid of images with the outputs and references side by side
  31. if orientation == "horizontal":
  32. # interleave the outputs and references
  33. canvas = torch.zeros_like(inputs[0])
  34. canvas = torch.cat([canvas] * len(inputs), dim=0)
  35. size = len(inputs)
  36. for idx, inp in enumerate(inputs):
  37. canvas[idx::size, ...] = inp
  38. grid = make_grid(canvas, nrow=len(inputs), padding=16, normalize=True, scale_each=True)
  39. elif orientation == "vertical":
  40. # interleave the outputs and references
  41. canvas = torch.cat(inputs, dim=0)
  42. size = len(inputs)
  43. for idx, inp in enumerate(inputs):
  44. canvas[idx::size, ...] = inp
  45. grid = make_grid(canvas, nrow=len(inputs[0]), padding=16, normalize=True, scale_each=True)
  46. else:
  47. raise ValueError("Unknown orientation: {}".format(orientation))
  48. return grid
  49. @torch.no_grad()
  50. def make_training_sample_grid(
  51. left_images: Tensor,
  52. right_images: Tensor,
  53. disparities: Tensor,
  54. masks: Tensor,
  55. predictions: List[Tensor],
  56. ) -> np.ndarray:
  57. # detach images and renormalize to [0, 1]
  58. images_left = left_images.detach().cpu() * 0.5 + 0.5
  59. images_right = right_images.detach().cpu() * 0.5 + 0.5
  60. # detach the disparties and predictions
  61. disparities = disparities.detach().cpu()
  62. predictions = predictions[-1].detach().cpu()
  63. # keep only the first channel of pixels, and repeat it 3 times
  64. disparities = disparities[:, :1, ...].repeat(1, 3, 1, 1)
  65. predictions = predictions[:, :1, ...].repeat(1, 3, 1, 1)
  66. # unsqueeze and repeat the masks
  67. masks = masks.detach().cpu().unsqueeze(1).repeat(1, 3, 1, 1)
  68. # make a grid that will self normalize across the batch
  69. pred_grid = make_pair_grid(images_left, images_right, masks, disparities, predictions, orientation="horizontal")
  70. pred_grid = pred_grid.permute(1, 2, 0).numpy()
  71. pred_grid = (pred_grid * 255).astype(np.uint8)
  72. return pred_grid
  73. @torch.no_grad()
  74. def make_disparity_sequence_grid(predictions: List[Tensor], disparities: Tensor) -> np.ndarray:
  75. # right most we will be adding the ground truth
  76. seq_len = len(predictions) + 1
  77. predictions = list(map(lambda x: x[:, :1, :, :].detach().cpu(), predictions + [disparities]))
  78. sequence = make_disparity_sequence(predictions)
  79. # swap axes to have the in the correct order for each batch sample
  80. sequence = torch.swapaxes(sequence, 0, 1).contiguous().reshape(-1, 1, disparities.shape[-2], disparities.shape[-1])
  81. sequence = make_grid(sequence, nrow=seq_len, padding=16, normalize=True, scale_each=True)
  82. sequence = sequence.permute(1, 2, 0).numpy()
  83. sequence = (sequence * 255).astype(np.uint8)
  84. return sequence
  85. @torch.no_grad()
  86. def make_prediction_image_side_to_side(
  87. predictions: Tensor, disparities: Tensor, valid_mask: Tensor, save_path: str, prefix: str
  88. ) -> None:
  89. import matplotlib.pyplot as plt
  90. # normalize the predictions and disparities in [0, 1]
  91. predictions = (predictions - predictions.min()) / (predictions.max() - predictions.min())
  92. disparities = (disparities - disparities.min()) / (disparities.max() - disparities.min())
  93. predictions = predictions * valid_mask
  94. disparities = disparities * valid_mask
  95. predictions = predictions.detach().cpu()
  96. disparities = disparities.detach().cpu()
  97. for idx, (pred, gt) in enumerate(zip(predictions, disparities)):
  98. pred = pred.permute(1, 2, 0).numpy()
  99. gt = gt.permute(1, 2, 0).numpy()
  100. # plot pred and gt side by side
  101. fig, ax = plt.subplots(1, 2, figsize=(10, 5))
  102. ax[0].imshow(pred)
  103. ax[0].set_title("Prediction")
  104. ax[1].imshow(gt)
  105. ax[1].set_title("Ground Truth")
  106. save_name = os.path.join(save_path, "{}_{}.png".format(prefix, idx))
  107. plt.savefig(save_name)
  108. plt.close()