import argparse
import warnings
from math import ceil
from pathlib import Path

import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel


def get_train_dataset(stage, dataset_root):
    """Build the training dataset (with its augmentation preset) for a training stage.

    Args:
        stage: One of ``"chairs"``, ``"things"``, ``"sintel_SKH"`` or ``"kitti"``.
        dataset_root: Passed as the ``root`` argument of every dataset that is built.

    Returns:
        The dataset for the requested stage. For ``"sintel_SKH"`` this is a
        combination of Sintel, Kitti, HD1K and FlyingThings3D (clean pass).

    Raises:
        ValueError: If ``stage`` is not one of the supported names.
    """
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as from paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)

        # As future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.

    Args:
        model: The model to evaluate; it is put in eval mode and left that way.
        args: Parsed command-line namespace. Reads ``batch_size``, ``device``, ``distributed``,
            ``workers``, ``num_flow_updates`` and, when distributed, ``rank``.
        val_dataset: The dataset to evaluate on.
        padder_mode: Forwarded as ``mode`` to ``utils.InputPadder``.
        num_flow_updates: Number of flow updates passed to the model. Falls back to
            ``args.num_flow_updates`` when falsy.
        batch_size: Batch size of the validation loader. Falls back to ``args.batch_size`` when falsy.
        header: Label used for progress logging and for the final printed summary.

    Returns:
        None. The aggregated metrics are printed.
    """
    batch_size = batch_size or args.batch_size
    device = torch.device(args.device)

    model.eval()

    if args.distributed:
        # drop_last=True: the samples that don't divide evenly are handled one by one further down.
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        """Run the model on one (possibly unbatched) blob and record its metrics in ``logger``."""
        if blob[0].dim() == 3:
            # input is not batched, so we add an extra dim for consistency.
            # ``x[None]`` prepends the batch dim whatever the rank of ``x`` is, so entries that
            # have fewer dims than the images (e.g. a 2-D validity mask) don't raise.
            blob = [x[None] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.to(device), image2.to(device)

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]  # only the last refinement is scored
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)
        # Weight by the number of images actually in this blob: the last batch may be smaller than
        # ``batch_size`` and the left-over samples are processed one at a time.
        logger.meters["per_image_epe"].update(metrics["epe"], n=flow_pred.shape[0])

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    if args.distributed:
        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
        print(
            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
            "Going to process the remaining samples individually, if any."
        )
        if args.rank == 0:  # we only need to process the rest on a single worker
            for i in range(num_processed_samples, len(val_dataset)):
                inner_loop(val_dataset[i])

        logger.synchronize_between_processes()

    print(header, logger)


def evaluate(model, args):
    """Evaluate ``model`` on every dataset named in ``args.val_dataset``.

    Supported names are ``"kitti"`` and ``"sintel"`` (both the clean and final passes);
    any other name triggers a warning and is skipped.

    Args:
        model: The model to evaluate.
        args: Parsed command-line namespace. Reads ``val_dataset``, ``weights``, ``test_only``,
            ``dataset_root``, ``batch_size``, ``distributed`` and ``rank`` (plus whatever
            ``_evaluate`` reads).
    """
    val_datasets = args.val_dataset or []

    if args.weights and args.test_only:
        # Evaluating pre-trained weights: use the preprocessing shipped with those weights.
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img1, img2, flow, valid_flow_mask):
            img1, img2 = trans(img1, img2)
            # Convert whatever isn't a tensor yet so that everything downstream sees tensors.
            if flow is not None and not isinstance(flow, torch.Tensor):
                flow = torch.from_numpy(flow)
            if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
                valid_flow_mask = torch.from_numpy(valid_flow_mask)
            return img1, img2, flow, valid_flow_mask

    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes, so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be "
                    "done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _evaluate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the requested *name*: no dataset object exists for an unknown name.
            warnings.warn(f"Can't validate on {name}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
    """Run one pass over ``train_loader``, updating the model after every batch.

    Args:
        model: The model being trained.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        train_loader: Iterable yielding ``(image1, image2, flow_gt, valid_flow_mask)`` batches;
            every element is moved to ``args.device``.
        logger: Metric logger used both to iterate over the loader and to record loss/metrics.
        args: Parsed command-line namespace. Reads ``device``, ``num_flow_updates`` and ``gamma``.
    """
    device = torch.device(args.device)
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")  # not logged during training
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()


def main(args):
    """Train and/or evaluate an optical-flow model according to ``args``.

    When ``args.train_dataset`` is None only evaluation is performed. Otherwise the model is
    trained for ``args.epochs`` epochs, a checkpoint is written after every epoch (by rank 0
    when distributed), and validation runs every ``args.val_freq`` epochs as well as after the
    last epoch.

    Args:
        args: Namespace produced by ``get_args_parser().parse_args()``. It is mutated:
            ``test_only`` and ``start_epoch`` are set here, and ``utils.setup_ddp`` is called on it.

    Raises:
        ValueError: If distributed mode is requested together with ``--device cpu``.
    """
    utils.setup_ddp(args)
    args.test_only = args.train_dataset is None

    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    model = torchvision.models.get_model(args.model, weights=args.weights)

    if args.distributed:
        model = model.to(args.local_rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    if args.resume is not None:
        # NOTE(security): torch.load unpickles the file, so only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])

    if args.test_only:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    if args.resume is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
    else:
        args.start_epoch = 0

    # Benchmark mode speeds training up, but must stay off when determinism was requested.
    torch.backends.cudnn.benchmark = not args.use_deterministic_algorithms

    model.train()
    if args.freeze_batch_norm:
        # Use the unwrapped model: a non-DDP model has no ``.module`` attribute.
        utils.freeze_batch_norm(model_without_ddp)

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    logger = utils.MetricLogger()

    for epoch in range(args.start_epoch, args.epochs):
        print(f"EPOCH {epoch}")
        if args.distributed:
            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
            sampler.set_epoch(epoch)

        train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {epoch} done. ", logger)

        if not args.distributed or args.rank == 0:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            # One file per epoch, plus a rolling "latest" file.
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

        # Always validate the final model, even when the last epoch isn't a multiple of val_freq.
        is_last_epoch = epoch == args.epochs - 1
        if epoch % args.val_freq == 0 or is_last_epoch:
            evaluate(model, args)
            # evaluate() leaves the model in eval mode: restore the training configuration.
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model_without_ddp)


def get_args_parser(add_help=True):
    """Create the command-line parser of this script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; pass False when this parser is used
            as a parent of another parser.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")
    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
    parser.add_argument("--batch-size", type=int, default=2)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument(
        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
    )
    # TODO: resume and weights should be in an exclusive arg group

    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    # parents=True so that a nested --output-dir can be created in one go.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)