# cascade_evaluation.py
  1. import os
  2. import warnings
  3. import torch
  4. import torchvision
  5. import torchvision.prototype.models.depth.stereo
  6. import utils
  7. from torch.nn import functional as F
  8. from train import make_eval_loader
  9. from utils.metrics import AVAILABLE_METRICS
  10. from visualization import make_prediction_image_side_to_side
  11. def get_args_parser(add_help=True):
  12. import argparse
  13. parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
  14. parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
  15. parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")
  16. parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
  17. parser.add_argument("--weights", type=str, default=None, help="torchvision API weight")
  18. parser.add_argument(
  19. "--model",
  20. type=str,
  21. default="crestereo_base",
  22. help="which model to use if not speciffying a training checkpoint",
  23. )
  24. parser.add_argument("--img-folder", type=str, default="images")
  25. parser.add_argument("--batch-size", type=int, default=1, help="batch size")
  26. parser.add_argument("--workers", type=int, default=0, help="number of workers")
  27. parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
  28. parser.add_argument(
  29. "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
  30. )
  31. parser.add_argument(
  32. "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
  33. )
  34. parser.add_argument(
  35. "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
  36. )
  37. parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
  38. parser.add_argument(
  39. "--interpolation-strategy",
  40. type=str,
  41. default="bilinear",
  42. help="interpolation strategy",
  43. choices=["bilinear", "bicubic", "mixed"],
  44. )
  45. parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurent iterations")
  46. parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
  47. parser.add_argument(
  48. "--metrics",
  49. type=str,
  50. nargs="+",
  51. default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
  52. help="metrics to log",
  53. choices=AVAILABLE_METRICS,
  54. )
  55. parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
  56. parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
  57. parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
  58. parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
  59. parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
  60. parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
  61. return parser
  62. def cascade_inference(model, image_left, image_right, iterations, cascades):
  63. # check that image size is divisible by 16 * (2 ** (cascades - 1))
  64. for image in [image_left, image_right]:
  65. if image.shape[-2] % ((2 ** (cascades - 1))) != 0:
  66. raise ValueError(
  67. f"image height is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
  68. )
  69. if image.shape[-1] % ((2 ** (cascades - 1))) != 0:
  70. raise ValueError(
  71. f"image width is not divisible by {16 * (2 ** (cascades - 1))}. Image shape: {image.shape[-2]}"
  72. )
  73. left_image_pyramid = [image_left]
  74. right_image_pyramid = [image_right]
  75. for idx in range(0, cascades - 1):
  76. ds_factor = int(2 ** (idx + 1))
  77. ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
  78. left_image_pyramid += F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(0)
  79. right_image_pyramid += F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True).unsqueeze(
  80. 0
  81. )
  82. flow_init = None
  83. for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
  84. flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
  85. # flow pred is a list
  86. flow_init = flow_pred[-1]
  87. return flow_init
  88. @torch.inference_mode()
  89. def _evaluate(
  90. model,
  91. args,
  92. val_loader,
  93. *,
  94. padder_mode,
  95. print_freq=10,
  96. writer=None,
  97. step=None,
  98. iterations=10,
  99. cascades=1,
  100. batch_size=None,
  101. header=None,
  102. save_images=False,
  103. save_path="",
  104. ):
  105. """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
  106. We process as many samples as possible with ddp.
  107. """
  108. model.eval()
  109. header = header or "Test:"
  110. device = torch.device(args.device)
  111. metric_logger = utils.MetricLogger(delimiter=" ")
  112. iterations = iterations or args.recurrent_updates
  113. logger = utils.MetricLogger()
  114. for meter_name in args.metrics:
  115. logger.add_meter(meter_name, fmt="{global_avg:.4f}")
  116. if "fl-all" not in args.metrics:
  117. logger.add_meter("fl-all", fmt="{global_avg:.4f}")
  118. num_processed_samples = 0
  119. with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
  120. batch_idx = 0
  121. for blob in metric_logger.log_every(val_loader, print_freq, header):
  122. image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
  123. padder = utils.InputPadder(image_left.shape, mode=padder_mode)
  124. image_left, image_right = padder.pad(image_left, image_right)
  125. disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
  126. disp_pred = disp_pred[:, :1, :, :]
  127. disp_pred = padder.unpad(disp_pred)
  128. if save_images:
  129. if args.distributed:
  130. rank_prefix = args.rank
  131. else:
  132. rank_prefix = 0
  133. make_prediction_image_side_to_side(
  134. disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
  135. )
  136. metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
  137. num_processed_samples += image_left.shape[0]
  138. for name in metrics:
  139. logger.meters[name].update(metrics[name], n=1)
  140. batch_idx += 1
  141. num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size
  142. print("Num_processed_samples: ", num_processed_samples)
  143. if (
  144. hasattr(val_loader.dataset, "__len__")
  145. and len(val_loader.dataset) != num_processed_samples
  146. and torch.distributed.get_rank() == 0
  147. ):
  148. warnings.warn(
  149. f"Number of processed samples {num_processed_samples} is different"
  150. f"from the dataset size {len(val_loader.dataset)}. This may happen if"
  151. "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
  152. )
  153. if writer is not None and args.rank == 0:
  154. for meter_name, meter_value in logger.meters.items():
  155. scalar_name = f"{meter_name} {header}"
  156. writer.add_scalar(scalar_name, meter_value.avg, step)
  157. logger.synchronize_between_processes()
  158. print(header, logger)
  159. logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
  160. return logger_metrics
def evaluate(model, loader, args, writer=None, step=None):
    """Evaluate `model` for every (n_cascades, n_iterations) configuration.

    For each combination from ``args.n_cascades`` x ``args.n_iterations`` this
    runs ``_evaluate``, collects the metrics under a "<c>c_<i>i" config key,
    prints them, and writes a formatted summary to
    "<checkpoint_name>_eval.log" in the working directory. Prediction images
    (when ``args.save_images``) go under ``args.img_folder/<ckpt>/<dataset>/<config>``.
    """
    os.makedirs(args.img_folder, exist_ok=True)
    # prefer the checkpoint file name; fall back to the torchvision weights name
    checkpoint_name = os.path.basename(args.checkpoint) or args.weights
    image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)

    metrics = {}
    base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
    os.makedirs(base_image_folder, exist_ok=True)

    for n_cascades in args.n_cascades:
        for n_iters in args.n_iterations:
            # one image sub-folder per configuration, e.g. "2c_10i"
            config = f"{n_cascades}c_{n_iters}i"
            config_image_folder = os.path.join(base_image_folder, config)
            os.makedirs(config_image_folder, exist_ok=True)

            metrics[config] = _evaluate(
                model,
                args,
                loader,
                padder_mode=args.padder_type,
                header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
                batch_size=args.batch_size,
                writer=writer,
                step=step,
                iterations=n_iters,
                cascades=n_cascades,
                save_path=config_image_folder,
                save_images=args.save_images,
            )

    metric_log = []
    metric_log_dict = {}
    # print the final results
    for config in metrics:
        # config is "<cascades>c_<iters>i"; [:-1] strips the trailing unit letter
        config_tokens = config.split("_")
        config_iters = config_tokens[1][:-1]
        config_cascades = config_tokens[0][:-1]

        metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
        metric_log_dict[config_cascades][config_iters] = metrics[config]

        evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
        metrics_str = f"Metrics: {metrics[config]}"
        metric_log.extend([evaluation_str, metrics_str])

        print(evaluation_str)
        print(metrics_str)

    eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
    print("Saving eval log to: ", eval_log_name)
    with open(eval_log_name, "w") as f:
        f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
        # write the dict line by line for each key, and each value in the keys
        # NOTE(review): "{" is written once per cascade entry, so the log holds
        # one brace pair per cascade rather than one enclosing dict literal
        for config_cascades in metric_log_dict:
            f.write("{\n")
            f.write(f"\t{config_cascades}: {{\n")
            for config_iters in metric_log_dict[config_cascades]:
                # round every metric to 3 decimal places (the previous comment
                # said 4, but the format spec below is .3f)
                # NOTE(review): rebinding `metrics` here shadows the outer dict;
                # harmless because the outer dict is no longer read after this point
                metrics = metric_log_dict[config_cascades][config_iters]
                metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
                f.write(f"\t\t{config_iters}: {metrics}\n")
            f.write("\t},\n")
        f.write("}\n")
  216. def load_checkpoint(args):
  217. utils.setup_ddp(args)
  218. if not args.weights:
  219. checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
  220. if "model" in checkpoint:
  221. experiment_args = checkpoint["args"]
  222. model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
  223. model.load_state_dict(checkpoint["model"])
  224. else:
  225. model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
  226. model.load_state_dict(checkpoint)
  227. # set the appropriate devices
  228. if args.distributed and args.device == "cpu":
  229. raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
  230. device = torch.device(args.device)
  231. else:
  232. model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)
  233. # convert to DDP if need be
  234. if args.distributed:
  235. model = model.to(args.device)
  236. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  237. else:
  238. model.to(device)
  239. return model
  240. def main(args):
  241. model = load_checkpoint(args)
  242. loader = make_eval_loader(args.dataset, args)
  243. evaluate(model, loader, args)
  244. if __name__ == "__main__":
  245. args = get_args_parser().parse_args()
  246. main(args)