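"""Training and evaluation script for optical-flow models (RAFT) built on the torchvision reference utilities.

The `utils` and `presets` imports below are local helper modules expected to live next to this script.

Example invocation (illustrative only; the file name, GPU count and hyper-parameter values are placeholders,
not prescriptions from this script):

    torchrun --nproc_per_node 8 train.py --dataset-root /path/to/datasets --name raft_chairs --train-dataset chairs --val-dataset sintel --batch-size 8 --lr 0.0004

Omitting --train-dataset (typically together with --resume or --weights) runs evaluation only.
"""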
import argparse
import warnings
from math import ceil
from pathlib import Path

import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
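

# Each training "stage" handled below corresponds to one phase of the usual RAFT training schedule
# (chairs -> things -> sintel_SKH -> kitti). The returned datasets come already wrapped with their
# stage-specific training transforms.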
def get_train_dataset(stage, dataset_root):
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
- elif stage == "sintel_SKH": # S + K + H as from paper
- crop_size = (368, 768)
- transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)
- things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
- sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)
- kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
- kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)
- hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
- hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
- # As future improvement, we could probably be using a distributed sampler here
- # The distribution is S(.71), T(.135), K(.135), H(.02)
- return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
- elif stage == "kitti":
- transforms = OpticalFlowPresetTrain(
- # resize and crop params
- crop_size=(288, 960),
- min_scale=-0.2,
- max_scale=0.4,
- stretch_prob=0,
- # flip params
- do_flip=False,
- # jitter params
- brightness=0.3,
- contrast=0.3,
- saturation=0.3,
- hue=0.3 / 3.14,
- asymmetric_jitter_prob=0,
- )
- return KittiFlow(root=dataset_root, split="train", transforms=transforms)
- else:
- raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with ddp, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size
    device = torch.device(args.device)

    model.eval()

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates

    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched, so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.to(device), image2.to(device)

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in RAFT paper).
        # per-pixel epe: average epe of all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)

        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)

    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size

    if args.distributed:
        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
        print(
            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
            "Going to process the remaining samples individually, if any."
        )
        if args.rank == 0:  # we only need to process the rest on a single worker
            for i in range(num_processed_samples, len(val_dataset)):
                inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)
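

# evaluate() below dispatches to _evaluate() with dataset-specific settings: Kitti uses 24 flow updates,
# per-image padding ("kitti" padder mode) and a forced batch size of 1, while the Sintel clean and final
# passes use 32 flow updates with "sintel" padding.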
def evaluate(model, args):
    val_datasets = args.val_dataset or []

    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img1, img2, flow, valid_flow_mask):
            img1, img2 = trans(img1, img2)
            if flow is not None and not isinstance(flow, torch.Tensor):
                flow = torch.from_numpy(flow)
            if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
                valid_flow_mask = torch.from_numpy(valid_flow_mask)
            return img1, img2, flow, valid_flow_mask

    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti images have different sizes, so they need to be padded individually; we can't batch them.
            # See the comment in InputPadder.
            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _evaluate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
- elif name == "sintel":
- for pass_name in ("clean", "final"):
- val_dataset = Sintel(
- root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
- )
- _evaluate(
- model,
- args,
- val_dataset,
- num_flow_updates=32,
- padder_mode="sintel",
- header=f"Sintel val {pass_name}",
- )
        else:
            warnings.warn(f"Can't validate on {name}, skipping.")
def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
    device = torch.device(args.device)
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)

        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
        scheduler.step()


def main(args):
    utils.setup_ddp(args)
    args.test_only = args.train_dataset is None

    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    model = torchvision.models.get_model(args.model, weights=args.weights)

    if args.distributed:
        model = model.to(args.local_rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])

    if args.test_only:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)
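
    # The scheduler is stepped once per batch in train_one_epoch, so steps_per_epoch is the number of
    # iterations each process sees per epoch: ceil(dataset size / (world_size * batch_size)).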
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    if args.resume is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
    else:
        args.start_epoch = 0

    torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model_without_ddp)

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    logger = utils.MetricLogger()

    for epoch in range(args.start_epoch, args.epochs):
        print(f"EPOCH {epoch}")
        if args.distributed:
            # needed on distributed mode, otherwise the data loading order would be the same for all epochs
            sampler.set_epoch(epoch)

        train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {epoch} done. ", logger)

        if not args.distributed or args.rank == 0:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

        # Validate every val_freq epochs, and always after the last epoch.
        if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
            evaluate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model_without_ddp)


def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to re-start training from, or evaluate a pre-saved model.",
    )

    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs")

    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
    parser.add_argument("--batch-size", type=int, default=2)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for AdamW optimizer")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for AdamW optimizer")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for AdamW optimizer")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument(
        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small"
    )

    # TODO: resume and weights should be in an exclusive arg group
    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="exponential weighting for loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )
- parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
- parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
- parser.add_argument(
- "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
- )
- return parser
- if __name__ == "__main__":
- args = get_args_parser().parse_args()
- Path(args.output_dir).mkdir(exist_ok=True)
- main(args)