import argparse
import os
import warnings
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
import torchvision.models.optical_flow
import torchvision.prototype.models.depth.stereo
import utils
import visualization
from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
from torch import nn
from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
from utils.metrics import AVAILABLE_METRICS
from utils.norm import freeze_batch_norm
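
# Illustrative single-node launch (the file name, paths and flag values below are placeholders,
# not prescribed by this script; see ``get_args_parser`` for the full set of options):
#
#   torchrun --nproc_per_node=8 train.py \
#       --dataset-root /path/to/datasets \
#       --train-datasets crestereo --dataset-steps 300000 \
#       --test-datasets middlebury2014-train \
#       --batch-size 2 --lr 4e-4 --mixed-precision --tensorboard-summaries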

def make_stereo_flow(
    flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Helper function to make stereo flow from a given model output"""
    if isinstance(flow, list):
        return [make_stereo_flow(flow_i, model_out_channels) for flow_i in flow]

    B, C, H, W = flow.shape
    # we need to add zero flow if the model outputs 2 channels
    if C == 1 and model_out_channels == 2:
        zero_flow = torch.zeros_like(flow)
        # by convention the flow is X-Y axis, so we need the Y flow last
        flow = torch.cat([flow, zero_flow], dim=1)
    return flow
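
# Shape example for ``make_stereo_flow`` (illustrative): a disparity target of shape [B, 1, H, W]
# fed to a model with ``output_channels == 2`` (optical-flow style) comes back as [B, 2, H, W] with
# a zero Y-component appended; inputs that already match the model's channel count pass through
# unchanged, and lists are converted element-wise.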

def make_lr_schedule(args: argparse.Namespace, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.SequentialLR:
    """Helper function to return a learning rate scheduler for CRE-stereo"""
    if args.decay_after_steps < args.warmup_steps:
        raise ValueError(
            f"decay_after_steps: {args.decay_after_steps} must be greater than warmup_steps: {args.warmup_steps}"
        )

    warmup_steps = args.warmup_steps if args.warmup_steps else 0
    flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
    # the decay phase covers whatever is left once the warmup and flat phases are over
    decay_lr_steps = args.total_iterations - warmup_steps - flat_lr_steps

    max_lr = args.lr
    min_lr = args.min_lr

    schedulers = []
    milestones = []

    if warmup_steps > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
            )
        else:
            raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
        schedulers.append(warmup_lr_scheduler)
        milestones.append(warmup_steps)

    if flat_lr_steps > 0:
        # ConstantLR multiplies the base lr by ``factor``; use 1.0 to keep the lr flat at ``max_lr``
        flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=flat_lr_steps)
        schedulers.append(flat_lr_scheduler)
        milestones.append(flat_lr_steps + warmup_steps)

    if decay_lr_steps > 0:
        if args.lr_decay_method == "cosine":
            decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_lr_steps, eta_min=min_lr
            )
        elif args.lr_decay_method == "linear":
            # LinearLR scales the base lr by a factor, so the decay is expressed relative to ``max_lr``
            decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0, end_factor=min_lr / max_lr, total_iters=decay_lr_steps
            )
        elif args.lr_decay_method == "exponential":
            decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
            )
        else:
            raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
        schedulers.append(decay_lr_scheduler)

    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
    return scheduler
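
# With the defaults from ``get_args_parser`` below this yields, roughly: a linear warmup over the
# first 6_000 steps, a flat phase at ``--lr`` until step 180_000, and a decay phase towards
# ``--min-lr`` for the remaining iterations (an illustrative description, not a guarantee for every
# flag combination).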

def shuffle_dataset(dataset):
    """Shuffle the dataset"""
    perm = torch.randperm(len(dataset))
    return torch.utils.data.Subset(dataset, perm)

def resize_dataset_to_n_steps(
    dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
) -> torch.utils.data.Dataset:
    original_size = len(dataset)
    if args.steps_is_epochs:
        samples_per_step = original_size
    target_size = dataset_steps * samples_per_step

    dataset_copies = []
    n_expands, remainder = divmod(target_size, original_size)
    for idx in range(n_expands):
        dataset_copies.append(dataset)
    if remainder > 0:
        dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))

    if args.dataset_shuffle:
        dataset_copies = [shuffle_dataset(dataset_copy) for dataset_copy in dataset_copies]

    dataset = torch.utils.data.ConcatDataset(dataset_copies)
    return dataset
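
# Worked example (illustrative): a 10_000-sample dataset with ``dataset_steps=300_000`` and
# ``samples_per_step = world_size * batch_size = 16`` is tiled to 300_000 * 16 = 4_800_000 samples,
# i.e. 480 independently shuffled copies, so the loader cannot run out of data before the last step.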

def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
    datasets = []
    for dataset_name in args.train_datasets:
        transform = make_train_transform(args)
        dataset = make_dataset(dataset_name, dataset_root, transform)
        datasets.append(dataset)

    if len(datasets) == 0:
        raise ValueError("No datasets specified for training")

    samples_per_step = args.world_size * args.batch_size

    for idx, (dataset, steps_per_dataset) in enumerate(zip(datasets, args.dataset_steps)):
        datasets[idx] = resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)

    dataset = torch.utils.data.ConcatDataset(datasets)
    if args.dataset_order_shuffle:
        dataset = shuffle_dataset(dataset)

    print(f"Training dataset: {len(dataset)} samples")
    return dataset
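
# Note: shuffling is baked into the dataset itself here (per-copy and, optionally, across the
# concatenated datasets), which is why the training sampler created in ``main`` below is
# deliberately built with shuffling disabled.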

@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset."""
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    metric_logger = utils.MetricLogger(delimiter=" ")

    iterations = iterations or args.recurrent_updates

    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")

    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)

            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)

            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print("Num_processed_samples: ", num_processed_samples)
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        # guard the rank query so this also works when the process group is not initialized
        and (not dist.is_initialized() or dist.get_rank() == 0)
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different "
            f"from the dataset size {len(val_loader.dataset)}. This may happen if "
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )

    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)

    logger.synchronize_between_processes()
    print(header, logger)

def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
    if args.weights:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(image_left, image_right, disp, valid_disp_mask):
            C_o, H_o, W_o = get_dimensions(image_left)
            image_left, image_right = trans(image_left, image_right)

            C_t, H_t, W_t = get_dimensions(image_left)
            scale_factor = W_t / W_o

            if disp is not None and not isinstance(disp, torch.Tensor):
                disp = torch.from_numpy(disp)
                if W_t != W_o:
                    # torchvision's functional ``resize`` takes an ``interpolation`` keyword (not ``mode``);
                    # disparities are rescaled together with the image width
                    disp = resize(disp, (H_t, W_t), interpolation=InterpolationMode.BILINEAR) * scale_factor
            if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
                valid_disp_mask = torch.from_numpy(valid_disp_mask)
                if W_t != W_o:
                    valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), interpolation=InterpolationMode.NEAREST)
            return image_left, image_right, disp, valid_disp_mask
    else:
        preprocessing = make_eval_transform(args)

    val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )
    return val_loader
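
# For evaluation the sampler deliberately keeps the natural sample order and does not drop the tail
# batch; under DDP the DistributedSampler may still pad a few samples so every rank sees the same
# number of batches, which is what the processed-sample-count warning in ``_evaluate`` refers to.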

def evaluate(model, loaders, args, writer=None, step=None):
    for loader_name, loader in loaders.items():
        _evaluate(
            model,
            args,
            loader,
            iterations=args.recurrent_updates,
            padder_mode=args.padder_type,
            header=f"{loader_name} evaluation",
            batch_size=args.batch_size,
            writer=writer,
            step=step,
        )

def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    device = torch.device(args.device)
    # wrap the loader in a logger
    loader = iter(logger.log_every(train_loader))
    # output channels
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels

    torch.set_num_threads(args.threads)

    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)

    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            resize_factor=0.25,
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None

    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None

    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None

    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None

    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()

        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)

            # sequence loss on top of the model outputs
            loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight

            if args.consistency_weight > 0:
                loss_consistency = consistency_criterion(disp_predictions)
                loss += loss_consistency * args.consistency_weight

            if args.psnr_weight > 0:
                loss_psnr = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_psnr += psnr_criterion(
                        pred * valid_disp_mask.unsqueeze(1),
                        disp_mask * valid_disp_mask.unsqueeze(1),
                    ).mean()  # mean the psnr loss over the batch

                loss += loss_psnr / len(disp_predictions) * args.psnr_weight

            if args.photometric_weight > 0:
                loss_photometric = 0.0
                for pred in disp_predictions:
                    # predictions might have 1 channel, therefore we pad the second channel with zeros
                    if model_out_channels == 1:
                        pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)

                    loss_photometric += photometric_criterion(
                        image_left, image_right, pred, valid_disp_mask
                    )  # the photometric loss is already averaged over the batch

                loss += loss_photometric / len(disp_predictions) * args.photometric_weight

            if args.smoothness_weight > 0:
                loss_smoothness = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_smoothness += smoothness_criterion(
                        image_left, pred[:, :1, :, :]
                    ).mean()  # mean the smoothness loss over the batch
                loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight

        with torch.no_grad():
            metrics, _ = utils.compute_metrics(
                disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                disp_mask[:, :1, :, :],  # so does the ground truth
                valid_disp_mask,
                args.metrics,
            )

        metrics.pop("fl-all", None)
        logger.update(loss=loss, **metrics)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()

        scheduler.step()

        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard
                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")

                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")

        if step % args.save_frequency == 0:
            if not args.distributed or args.rank == 0:
                model_without_ddp = (
                    model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
                )
                checkpoint = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": step,
                    "args": args,
                }
                os.makedirs(args.checkpoint_dir, exist_ok=True)
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")

        if step % args.valid_frequency == 0:
            evaluate(model, val_loaders, args, writer, step)
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)

    # one final save at the end
    if not args.distributed or args.rank == 0:
        model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
            "args": args,
        }
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")

def main(args):
    args.total_iterations = sum(args.dataset_steps)

    # initialize DDP setting
    utils.setup_ddp(args)
    print(args)

    args.test_only = args.train_datasets is None

    # set the appropriate devices
    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    # select model architecture
    model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)

    # convert to DDP if need be
    if args.distributed:
        model = model.to(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    os.makedirs(args.checkpoint_dir, exist_ok=True)

    val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}

    # EVAL ONLY configurations
    if args.test_only:
        evaluate(model, val_loaders, args)
        return

    # Sanity check for the parameter count
    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Compose the training dataset
    train_dataset = get_train_dataset(args.dataset_root, args)

    # initialize the optimizer
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
    else:
        raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")

    # initialize the learning rate schedule
    scheduler = make_lr_schedule(args, optimizer)

    # load them from checkpoint if needed
    args.start_step = 0
    if args.resume_path is not None:
        checkpoint = torch.load(args.resume_path, map_location="cpu")
        if "model" in checkpoint:
            # this means the user requested to resume from a training checkpoint
            model_without_ddp.load_state_dict(checkpoint["model"])
            # this means the user wants to continue training from where it was left off
            if args.resume_schedule:
                optimizer.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                args.start_step = checkpoint["step"] + 1
                # modify the starting point of the dataset consumption so already-seen samples are skipped
                # (ConcatDataset does not support slicing, so we wrap it in a Subset instead)
                sample_start_step = args.start_step * args.batch_size * args.world_size
                train_dataset = torch.utils.data.Subset(train_dataset, range(sample_start_step, len(train_dataset)))
        else:
            # this means the user wants to finetune on top of a model state dict
            # and that no other changes are required
            model_without_ddp.load_state_dict(checkpoint)

    torch.backends.cudnn.benchmark = True

    # enable training mode
    model.train()
    if args.freeze_batch_norm:
        freeze_batch_norm(model_without_ddp)

    # put dataloader on top of the dataset
    # make sure to disable shuffling since the dataset is already shuffled
    # in order to guarantee quasi randomness whilst retaining a deterministic
    # dataset consumption order
    if args.distributed:
        # the train dataset is preshuffled in order to respect the iteration order
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
    else:
        # the train dataset is already shuffled, so we can use a simple SequentialSampler
        sampler = torch.utils.data.SequentialSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    # initialize the logger
    if args.tensorboard_summaries:
        from torch.utils.tensorboard import SummaryWriter

        tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
        os.makedirs(tensorboard_path, exist_ok=True)

        tensorboard_run = tensorboard_path / f"{args.name}"
        writer = SummaryWriter(tensorboard_run)
    else:
        writer = None

    logger = utils.MetricLogger(delimiter=" ")

    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None

    # run the training loop; this performs the optimization and handles logging,
    # checkpointing and periodic validation as configured
    run(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loaders=val_loaders,
        logger=logger,
        writer=writer,
        scaler=scaler,
        args=args,
    )

def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)

    # checkpointing
    parser.add_argument("--name", default="crestereo", help="name of the experiment")
    parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")

    # dataset
    parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
    parser.add_argument(
        "--train-datasets",
        type=str,
        nargs="+",
        default=["crestereo"],
        help="dataset(s) to train on",
        choices=list(VALID_DATASETS.keys()),
    )
    parser.add_argument(
        "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
    )
    parser.add_argument(
        "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
    )
    parser.add_argument(
        "--test-datasets",
        type=str,
        nargs="+",
        default=["middlebury2014-train"],
        help="dataset(s) to test on",
        choices=["middlebury2014-train"],
    )
    parser.add_argument("--dataset-shuffle", type=bool, help="shuffle the dataset", default=True)
    parser.add_argument("--dataset-order-shuffle", type=bool, help="shuffle the dataset order", default=True)
    parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
    parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
    parser.add_argument(
        "--threads",
        type=int,
        default=16,
        help="number of CPU threads per GPU. Increasing this can speed up transforms, but too many threads can cause worker contention, so use with care.",
    )

    # model architecture
    parser.add_argument(
        "--model",
        type=str,
        default="crestereo_base",
        help="model architecture",
        choices=["crestereo_base", "raft_stereo"],
    )
    parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
    parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")

    # loss parameters
    parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
    parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
    parser.add_argument(
        "--flow-loss-exclude-large",
        action="store_true",
        help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
        default=False,
    )
    parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
    parser.add_argument(
        "--consistency-resize-factor",
        type=float,
        default=0.25,
        help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
    )
    parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
    parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
    parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
    parser.add_argument(
        "--photometric-max-displacement-ratio",
        type=float,
        default=0.15,
        help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
    )
    parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")

    # transforms parameters
    parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
    parser.add_argument(
        "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
    )
    parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
    parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
    parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
    parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
    parser.add_argument(
        "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
    )
    parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
    parser.add_argument(
        "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
    )
    parser.add_argument(
        "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
    )
    parser.add_argument(
        "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
    )
    parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
    parser.add_argument(
        "--interpolation-strategy",
        type=str,
        default="bilinear",
        help="interpolation strategy",
        choices=["bilinear", "bicubic", "mixed"],
    )
    parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
    parser.add_argument(
        "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
    )
    parser.add_argument(
        "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
    )
    parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
    parser.add_argument(
        "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
    )
    parser.add_argument(
        "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
    )
    parser.add_argument(
        "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
    )
    parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
    parser.add_argument(
        "--asymmetric-jitter-prob",
        type=float,
        default=1.0,
        help="probability of using asymmetric jitter instead of symmetric jitter",
    )
- parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
    parser.add_argument(
        "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
    )
    parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
    parser.add_argument(
        "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
    )
    parser.add_argument(
        "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
    )

    # optimizer parameters
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
    parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
    parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")

    # lr_scheduler parameters
    parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
    parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
    parser.add_argument(
        "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decaying the lr"
    )
    parser.add_argument(
        "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "constant"]
    )
- parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
- parser.add_argument(
- "--lr-decay-method",
- type=str,
- default="linear",
- help="decay method",
- choices=["linear", "cosine", "exponential"],
- )
- parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")
- # deterministic behaviour
- parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")
- # mixed precision training
- parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
- # logging
- parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
- parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
- parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
- parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
- parser.add_argument(
- "--metrics",
- type=str,
- nargs="+",
- default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
- help="metrics to log",
- choices=AVAILABLE_METRICS,
- )
- # distributed parameters
- parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
- parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
- parser.add_argument("--device", type=str, default="cuda", help="device to use for training")

    # weights API
    parser.add_argument("--weights", type=str, default=None, help="weights enum name from the torchvision weights API")
    parser.add_argument(
        "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
    )
    parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")

    # padder parameters
    parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])

    return parser

if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)