train.py

import argparse
import warnings
from math import ceil
from pathlib import Path

import torch
import torchvision.models.optical_flow
import utils
from presets import OpticalFlowPresetEval, OpticalFlowPresetTrain
from torchvision.datasets import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel


def get_train_dataset(stage, dataset_root):
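    """Return the training dataset, with training-time transforms, for the given stage.

    The stages follow the RAFT training schedule: "chairs", "things", "sintel_SKH"
    (a weighted mix of Sintel, FlyingThings3D clean, Kitti and HD1K), and "kitti".
    """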
    if stage == "chairs":
        transforms = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=transforms)
    elif stage == "things":
        transforms = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=transforms)
    elif stage == "sintel_SKH":  # S + K + H as in the paper
        crop_size = (368, 768)
        transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.2, max_scale=0.6, do_flip=True)

        things_clean = FlyingThings3D(root=dataset_root, split="train", pass_name="clean", transforms=transforms)
        sintel = Sintel(root=dataset_root, split="train", pass_name="both", transforms=transforms)

        kitti_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti = KittiFlow(root=dataset_root, split="train", transforms=kitti_transforms)

        hd1k_transforms = OpticalFlowPresetTrain(crop_size=crop_size, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k = HD1K(root=dataset_root, split="train", transforms=hd1k_transforms)
        # As a future improvement, we could probably use a distributed sampler here.
        # The sampling distribution is S(.71), T(.135), K(.135), H(.02).
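        # `n * dataset` repeats a dataset n times and `+` concatenates them, so the mix
        # below approximates those weights by oversampling the smaller datasets.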
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean
    elif stage == "kitti":
        transforms = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root, split="train", transforms=transforms)
    else:
        raise ValueError(f"Unknown stage {stage}")


@torch.no_grad()
def _evaluate(model, args, val_dataset, *, padder_mode, num_flow_updates=None, batch_size=None, header=None):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    We process as many samples as possible with DDP, and process the rest on a single worker.
    """
    batch_size = batch_size or args.batch_size
    device = torch.device(args.device)

    model.eval()

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        sampler = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=sampler,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    num_flow_updates = num_flow_updates or args.num_flow_updates
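
    # inner_loop runs one forward pass and updates the metric meters. It accepts either a
    # batched blob from the DataLoader, or a single (unbatched) sample when the leftover
    # samples are processed individually below.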
    def inner_loop(blob):
        if blob[0].dim() == 3:
            # input is not batched, so we add an extra dim for consistency
            blob = [x[None, :, :, :] if x is not None else None for x in blob]

        image1, image2, flow_gt = blob[:3]
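        # Datasets like Kitti also return a valid-flow mask as a 4th element; the others
        # return only 3 elements and all pixels are used.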
        valid_flow_mask = None if len(blob) == 3 else blob[-1]

        image1, image2 = image1.to(device), image2.to(device)

        padder = utils.InputPadder(image1.shape, mode=padder_mode)
        image1, image2 = padder.pad(image1, image2)

        flow_predictions = model(image1, image2, num_flow_updates=num_flow_updates)
        flow_pred = flow_predictions[-1]
        flow_pred = padder.unpad(flow_pred).cpu()

        metrics, num_pixels_tot = utils.compute_metrics(flow_pred, flow_gt, valid_flow_mask)

        # We compute per-pixel epe (epe) and per-image epe (called f1-epe in the RAFT paper).
        # per-pixel epe: average epe over all pixels of all images
        # per-image epe: average epe on each image independently, then average over images
        for name in ("epe", "1px", "3px", "5px", "f1"):  # f1 is called f1-all in the paper
            logger.meters[name].update(metrics[name], n=num_pixels_tot)
        logger.meters["per_image_epe"].update(metrics["epe"], n=batch_size)
    logger = utils.MetricLogger()
    for meter_name in ("epe", "1px", "3px", "5px", "per_image_epe", "f1"):
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")

    num_processed_samples = 0
    for blob in logger.log_every(val_loader, header=header, print_freq=None):
        inner_loop(blob)
        num_processed_samples += blob[0].shape[0]  # batch size
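
    # Because the DistributedSampler above uses drop_last=True, some tail samples may not
    # have been batch-processed; rank 0 handles those leftovers one by one below.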
    if args.distributed:
        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
        print(
            f"Batch-processed {num_processed_samples} / {len(val_dataset)} samples. "
            "Going to process the remaining samples individually, if any."
        )
        if args.rank == 0:  # we only need to process the rest on a single worker
            for i in range(num_processed_samples, len(val_dataset)):
                inner_loop(val_dataset[i])

    logger.synchronize_between_processes()
    print(header, logger)


def evaluate(model, args):
    val_datasets = args.val_dataset or []

    if args.weights and args.test_only:
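        # When evaluating pretrained weights, use the preprocessing transforms bundled
        # with them, so inputs match what the weights were trained with.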
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img1, img2, flow, valid_flow_mask):
            img1, img2 = trans(img1, img2)
            if flow is not None and not isinstance(flow, torch.Tensor):
                flow = torch.from_numpy(flow)
            if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
                valid_flow_mask = torch.from_numpy(valid_flow_mask)
            return img1, img2, flow, valid_flow_mask

    else:
        preprocessing = OpticalFlowPresetEval()
    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes, so we need to pad them individually; we can't batch.
            # See the comment in InputPadder.
            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                warnings.warn(
                    f"batch_size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch size of 1."
                )
            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _evaluate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            warnings.warn(f"Can't validate on {name}, skipping.")


def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
    device = torch.device(args.device)
    for data_blob in logger.log_every(train_loader):

        optimizer.zero_grad()

        image1, image2, flow_gt, valid_flow_mask = (x.to(device) for x in data_blob)
        flow_predictions = model(image1, image2, num_flow_updates=args.num_flow_updates)
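
        # sequence_loss is the RAFT training loss: an L1 loss over every intermediate flow
        # prediction, exponentially weighted by gamma so later refinements count more.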
        loss = utils.sequence_loss(flow_predictions, flow_gt, valid_flow_mask, args.gamma)
        metrics, _ = utils.compute_metrics(flow_predictions[-1], flow_gt, valid_flow_mask)

        metrics.pop("f1")
        logger.update(loss=loss, **metrics)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()
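        # OneCycleLR is stepped once per batch; main() sets steps_per_epoch accordingly.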
        scheduler.step()


def main(args):
    utils.setup_ddp(args)
    args.test_only = args.train_dataset is None

    if args.distributed and args.device == "cpu":
        raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    model = torchvision.models.get_model(args.model, weights=args.weights)

    if args.distributed:
        model = model.to(args.local_rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
        model_without_ddp = model.module
    else:
        model.to(device)
        model_without_ddp = model

    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])

    if args.test_only:
        # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, args)
        return

    print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    train_dataset = get_train_dataset(args.train_dataset, args.dataset_root)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.adamw_eps)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=ceil(len(train_dataset) / (args.world_size * args.batch_size)),
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )

    if args.resume is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
    else:
        args.start_epoch = 0

    # Don't re-enable cudnn benchmarking if deterministic algorithms were requested above.
    if not args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = True

    model.train()
    if args.freeze_batch_norm:
        utils.freeze_batch_norm(model_without_ddp)  # model.module would fail in non-distributed mode

    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True, drop_last=True)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
    )

    logger = utils.MetricLogger()

    for epoch in range(args.start_epoch, args.epochs):
        print(f"EPOCH {epoch}")
        if args.distributed:
            # needed in distributed mode, otherwise the data loading order would be the same for all epochs
            sampler.set_epoch(epoch)
        train_one_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            logger=logger,
            args=args,
        )

        # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
        print(f"Epoch {epoch} done. ", logger)

        if not args.distributed or args.rank == 0:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
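            # Save both a per-epoch snapshot and a rolling "latest" checkpoint.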
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}_{epoch}.pth")
            torch.save(checkpoint, Path(args.output_dir) / f"{args.name}.pth")

        if epoch % args.val_freq == 0:
            evaluate(model, args)
            model.train()
            if args.freeze_batch_norm:
                utils.freeze_batch_norm(model_without_ddp)


def get_args_parser(add_help=True):
    parser = argparse.ArgumentParser(add_help=add_help, description="Train or evaluate an optical-flow model.")
    parser.add_argument(
        "--name",
        default="raft",
        type=str,
        help="The name of the experiment - determines the name of the files where weights are saved.",
    )
    parser.add_argument("--output-dir", default=".", type=str, help="Output dir where checkpoints will be stored.")
    parser.add_argument(
        "--resume",
        type=str,
        help="A path to previously saved weights. Used to restart training from, or to evaluate a pre-saved model.",
    )

    parser.add_argument("--workers", type=int, default=12, help="Number of workers for the data loading part.")

    parser.add_argument(
        "--train-dataset",
        type=str,
        help="The dataset to use for training. If not passed, only validation is performed (and you probably want to pass --resume).",
    )
    parser.add_argument("--val-dataset", type=str, nargs="+", help="The dataset(s) to use for validation.")
    parser.add_argument("--val-freq", type=int, default=2, help="Validate every X epochs.")
    parser.add_argument("--epochs", type=int, default=20, help="The total number of epochs to train.")
    parser.add_argument("--batch-size", type=int, default=2)

    parser.add_argument("--lr", type=float, default=0.00002, help="Learning rate for the AdamW optimizer.")
    parser.add_argument("--weight-decay", type=float, default=0.00005, help="Weight decay for the AdamW optimizer.")
    parser.add_argument("--adamw-eps", type=float, default=1e-8, help="eps value for the AdamW optimizer.")

    parser.add_argument(
        "--freeze-batch-norm", action="store_true", help="Set BatchNorm modules of the model in eval mode."
    )

    parser.add_argument(
        "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small."
    )

    parser.add_argument(
        "--num_flow_updates",
        type=int,
        default=12,
        help="Number of updates (or 'iters') in the update operator of the model.",
    )

    parser.add_argument("--gamma", type=float, default=0.8, help="Exponential weighting for the loss. Must be < 1.")

    parser.add_argument("--dist-url", default="env://", help="URL used to set up distributed training.")

    parser.add_argument(
        "--dataset-root",
        help="Root folder where the datasets are stored. Will be passed as the 'root' parameter of the datasets.",
        required=True,
    )

    # TODO: resume and weights should be in a mutually exclusive arg group
    parser.add_argument("--weights", default=None, type=str, help="The weights enum name to load.")
    parser.add_argument("--device", default="cuda", type=str, help="Device to use (cuda or cpu). Default: cuda.")
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    return parser
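

# Example invocations (paths and values are illustrative; adjust to your setup):
#
#   Distributed training on the FlyingChairs stage:
#     torchrun --nproc_per_node 8 train.py --dataset-root /path/to/datasets \
#         --name raft_chairs --train-dataset chairs --val-dataset chairs --batch-size 2
#
#   Evaluating pretrained weights only (no --train-dataset, so test_only is set):
#     torchrun --nproc_per_node 1 train.py --dataset-root /path/to/datasets \
#         --val-dataset sintel --model raft_large --weights Raft_Large_Weights.C_T_V2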


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)