import os
import warnings

import torch
import torchvision
import torchvision.prototype.models.depth.stereo
import utils
from torch.nn import functional as F
from train import make_eval_loader
from utils.metrics import AVAILABLE_METRICS
from visualization import make_prediction_image_side_to_side

def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Evaluation", add_help=add_help)
    parser.add_argument("--dataset", type=str, default="middlebury2014-train", help="dataset to use")
    parser.add_argument("--dataset-root", type=str, default="", help="root of the dataset")
    parser.add_argument("--checkpoint", type=str, default="", help="path to weights")
    parser.add_argument("--weights", type=str, default=None, help="torchvision API weights")
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="which model to use if not specifying a training checkpoint",
    )
    parser.add_argument("--img-folder", type=str, default="images")
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument("--workers", type=int, default=0, help="number of workers")
    parser.add_argument("--eval-size", type=int, nargs="+", default=[384, 512], help="resize size")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--n_iterations", nargs="+", type=int, default=[10], help="number of recurrent iterations")
    parser.add_argument("--n_cascades", nargs="+", type=int, default=[1], help="number of cascades")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="+",
        default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
        help="metrics to log",
        choices=AVAILABLE_METRICS,
    )
    parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision during evaluation")
    parser.add_argument("--world-size", type=int, default=1, help="number of distributed processes")
    parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    parser.add_argument("--device", type=str, default="cuda", help="device to use for evaluation")
    parser.add_argument("--save-images", action="store_true", help="save images of the predictions")
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])

    return parser

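# Example invocation (script name and paths are illustrative; torchrun matches
# the default --dist-url of "env://"):
#
#   torchrun --nproc_per_node=1 cascade_evaluation.py \
#       --dataset middlebury2014-train --dataset-root /data/middlebury2014 \
#       --checkpoint checkpoints/crestereo.pth \
#       --n_cascades 1 2 --n_iterations 10 20 --save-images
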
def cascade_inference(model, image_left, image_right, iterations, cascades):
    # check that the image size is divisible by 16 * (2 ** (cascades - 1))
    divisor = 16 * (2 ** (cascades - 1))
    for image in [image_left, image_right]:
        if image.shape[-2] % divisor != 0:
            raise ValueError(f"image height is not divisible by {divisor}. Image height: {image.shape[-2]}")
        if image.shape[-1] % divisor != 0:
            raise ValueError(f"image width is not divisible by {divisor}. Image width: {image.shape[-1]}")

    # build a coarse-to-fine image pyramid, one level per cascade
    left_image_pyramid = [image_left]
    right_image_pyramid = [image_right]
    for idx in range(0, cascades - 1):
        ds_factor = int(2 ** (idx + 1))
        ds_shape = (image_left.shape[-2] // ds_factor, image_left.shape[-1] // ds_factor)
        left_image_pyramid.append(F.interpolate(image_left, size=ds_shape, mode="bilinear", align_corners=True))
        right_image_pyramid.append(F.interpolate(image_right, size=ds_shape, mode="bilinear", align_corners=True))

    # run the model from the coarsest level up, seeding each level with the
    # previous (coarser) flow estimate
    flow_init = None
    for left_image, right_image in zip(reversed(left_image_pyramid), reversed(right_image_pyramid)):
        flow_pred = model(left_image, right_image, flow_init, num_iters=iterations)
        # flow_pred is a list of per-iteration predictions; keep the last one
        flow_init = flow_pred[-1]

    return flow_init

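# Illustrative sketch of how `cascade_inference` walks the pyramid, using a
# hypothetical stand-in model that only mimics the expected interface (a list
# of per-iteration flow predictions); a real run would pass a CREStereo-style
# model instead. Defined for reference, never called.
def _demo_cascade_inference():
    class _ConstantFlowModel(torch.nn.Module):
        def forward(self, left, right, flow_init=None, num_iters=10):
            b, _, h, w = left.shape
            return [torch.zeros(b, 2, h, w)]

    # 64 and 128 are divisible by 16 * 2 ** (cascades - 1) = 32 for cascades=2
    left = torch.rand(1, 3, 64, 128)
    right = torch.rand(1, 3, 64, 128)
    return cascade_inference(_ConstantFlowModel(), left, right, iterations=10, cascades=2)
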
@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=10,
    cascades=1,
    batch_size=None,
    header=None,
    save_images=False,
    save_path="",
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    Processes as many samples as possible when running under DDP.
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)

    metric_logger = utils.MetricLogger(delimiter=" ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")
    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        batch_idx = 0
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_pred = cascade_inference(model, image_left, image_right, iterations, cascades)
            disp_pred = disp_pred[:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            if save_images:
                rank_prefix = args.rank if args.distributed else 0
                make_prediction_image_side_to_side(
                    disp_pred, disp_gt, valid_disp_mask, save_path, prefix=f"batch_{rank_prefix}_{batch_idx}"
                )

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)
            batch_idx += 1

    num_processed_samples = utils.reduce_across_processes(num_processed_samples) / args.world_size
    print("Num_processed_samples: ", num_processed_samples)
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and (not args.distributed or torch.distributed.get_rank() == 0)
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size for more accurate results."
        )
    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)

    logger_metrics = {k: v.global_avg for k, v in logger.meters.items()}
    return logger_metrics

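# `_evaluate` returns a plain dict of global averages keyed by metric name,
# e.g. (values illustrative):
#   {"mae": 1.2340, "rmse": 2.3450, "1px": 0.1230, "3px": 0.0450, "5px": 0.0210, "relepe": 0.0890, "fl-all": 0.9120}
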
def evaluate(model, loader, args, writer=None, step=None):
    os.makedirs(args.img_folder, exist_ok=True)
    checkpoint_name = os.path.basename(args.checkpoint) or args.weights
    image_checkpoint_folder = os.path.join(args.img_folder, checkpoint_name)

    metrics = {}
    base_image_folder = os.path.join(image_checkpoint_folder, args.dataset)
    os.makedirs(base_image_folder, exist_ok=True)

    for n_cascades in args.n_cascades:
        for n_iters in args.n_iterations:
            config = f"{n_cascades}c_{n_iters}i"
            config_image_folder = os.path.join(base_image_folder, config)
            os.makedirs(config_image_folder, exist_ok=True)

            metrics[config] = _evaluate(
                model,
                args,
                loader,
                padder_mode=args.padder_type,
                header=f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{n_cascades} n_iters:{n_iters}",
                batch_size=args.batch_size,
                writer=writer,
                step=step,
                iterations=n_iters,
                cascades=n_cascades,
                save_path=config_image_folder,
                save_images=args.save_images,
            )
    metric_log = []
    metric_log_dict = {}

    # print the final results
    for config in metrics:
        config_tokens = config.split("_")
        config_cascades = config_tokens[0][:-1]
        config_iters = config_tokens[1][:-1]

        metric_log_dict[config_cascades] = metric_log_dict.get(config_cascades, {})
        metric_log_dict[config_cascades][config_iters] = metrics[config]

        evaluation_str = f"{args.dataset} evaluation@ size:{args.eval_size} n_cascades:{config_cascades} recurrent_updates:{config_iters}"
        metrics_str = f"Metrics: {metrics[config]}"
        metric_log.extend([evaluation_str, metrics_str])

        print(evaluation_str)
        print(metrics_str)

    eval_log_name = f"{checkpoint_name.replace('.pth', '')}_eval.log"
    print("Saving eval log to: ", eval_log_name)
    with open(eval_log_name, "w") as f:
        f.write(f"Dataset: {args.dataset} @size: {args.eval_size}:\n")
        # write the dict line by line: one block per cascade count, one row per iteration count
        f.write("{\n")
        for config_cascades in metric_log_dict:
            f.write(f"\t{config_cascades}: {{\n")
            for config_iters in metric_log_dict[config_cascades]:
                # round every metric to 3 decimal places
                metrics = metric_log_dict[config_cascades][config_iters]
                metrics = {k: float(f"{v:.3f}") for k, v in metrics.items()}
                f.write(f"\t\t{config_iters}: {metrics}\n")
            f.write("\t},\n")
        f.write("}\n")

def load_checkpoint(args):
    utils.setup_ddp(args)

    if not args.weights:
        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
        if "model" in checkpoint:
            # training checkpoint: recover the architecture from the stored args
            experiment_args = checkpoint["args"]
            model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
            model.load_state_dict(checkpoint["model"])
        else:
            # raw state dict: the architecture comes from --model
            model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=None)
            model.load_state_dict(checkpoint)
    else:
        model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # set the appropriate device on every loading path
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    else:
        model.to(device)

    return model

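# Checkpoint layouts handled by `load_checkpoint`, for reference (the Namespace
# contents are illustrative):
#   {"model": <state_dict>, "args": Namespace(model="crestereo_base", ...)}  # training checkpoint
#   <state_dict>  # raw weights; the architecture is taken from --model
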
def main(args):
    model = load_checkpoint(args)
    loader = make_eval_loader(args.dataset, args)
    evaluate(model, loader, args)


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)