train.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788
  1. import argparse
  2. import os
  3. import warnings
  4. from pathlib import Path
  5. from typing import List, Union
  6. import numpy as np
  7. import torch
  8. import torch.distributed as dist
  9. import torchvision.models.optical_flow
  10. import torchvision.prototype.models.depth.stereo
  11. import utils
  12. import visualization
  13. from parsing import make_dataset, make_eval_transform, make_train_transform, VALID_DATASETS
  14. from torch import nn
  15. from torchvision.transforms.functional import get_dimensions, InterpolationMode, resize
  16. from utils.metrics import AVAILABLE_METRICS
  17. from utils.norm import freeze_batch_norm
  18. def make_stereo_flow(flow: Union[torch.Tensor, List[torch.Tensor]], model_out_channels: int) -> torch.Tensor:
  19. """Helper function to make stereo flow from a given model output"""
  20. if isinstance(flow, list):
  21. return [make_stereo_flow(flow_i, model_out_channels) for flow_i in flow]
  22. B, C, H, W = flow.shape
  23. # we need to add zero flow if the model outputs 2 channels
  24. if C == 1 and model_out_channels == 2:
  25. zero_flow = torch.zeros_like(flow)
  26. # by convention the flow is X-Y axis, so we need the Y flow last
  27. flow = torch.cat([flow, zero_flow], dim=1)
  28. return flow
  29. def make_lr_schedule(args: argparse.Namespace, optimizer: torch.optim.Optimizer) -> np.ndarray:
  30. """Helper function to return a learning rate scheduler for CRE-stereo"""
  31. if args.decay_after_steps < args.warmup_steps:
  32. raise ValueError(f"decay_after_steps: {args.function} must be greater than warmup_steps: {args.warmup_steps}")
  33. warmup_steps = args.warmup_steps if args.warmup_steps else 0
  34. flat_lr_steps = args.decay_after_steps - warmup_steps if args.decay_after_steps else 0
  35. decay_lr_steps = args.total_iterations - flat_lr_steps
  36. max_lr = args.lr
  37. min_lr = args.min_lr
  38. schedulers = []
  39. milestones = []
  40. if warmup_steps > 0:
  41. if args.lr_warmup_method == "linear":
  42. warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
  43. optimizer, start_factor=args.lr_warmup_factor, total_iters=warmup_steps
  44. )
  45. elif args.lr_warmup_method == "constant":
  46. warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
  47. optimizer, factor=args.lr_warmup_factor, total_iters=warmup_steps
  48. )
  49. else:
  50. raise ValueError(f"Unknown lr warmup method {args.lr_warmup_method}")
  51. schedulers.append(warmup_lr_scheduler)
  52. milestones.append(warmup_steps)
  53. if flat_lr_steps > 0:
  54. flat_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=max_lr, total_iters=flat_lr_steps)
  55. schedulers.append(flat_lr_scheduler)
  56. milestones.append(flat_lr_steps + warmup_steps)
  57. if decay_lr_steps > 0:
  58. if args.lr_decay_method == "cosine":
  59. decay_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
  60. optimizer, T_max=decay_lr_steps, eta_min=min_lr
  61. )
  62. elif args.lr_decay_method == "linear":
  63. decay_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
  64. optimizer, start_factor=max_lr, end_factor=min_lr, total_iters=decay_lr_steps
  65. )
  66. elif args.lr_decay_method == "exponential":
  67. decay_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
  68. optimizer, gamma=args.lr_decay_gamma, last_epoch=-1
  69. )
  70. else:
  71. raise ValueError(f"Unknown lr decay method {args.lr_decay_method}")
  72. schedulers.append(decay_lr_scheduler)
  73. scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=milestones)
  74. return scheduler
  75. def shuffle_dataset(dataset):
  76. """Shuffle the dataset"""
  77. perm = torch.randperm(len(dataset))
  78. return torch.utils.data.Subset(dataset, perm)
  79. def resize_dataset_to_n_steps(
  80. dataset: torch.utils.data.Dataset, dataset_steps: int, samples_per_step: int, args: argparse.Namespace
  81. ) -> torch.utils.data.Dataset:
  82. original_size = len(dataset)
  83. if args.steps_is_epochs:
  84. samples_per_step = original_size
  85. target_size = dataset_steps * samples_per_step
  86. dataset_copies = []
  87. n_expands, remainder = divmod(target_size, original_size)
  88. for idx in range(n_expands):
  89. dataset_copies.append(dataset)
  90. if remainder > 0:
  91. dataset_copies.append(torch.utils.data.Subset(dataset, list(range(remainder))))
  92. if args.dataset_shuffle:
  93. dataset_copies = [shuffle_dataset(dataset_copy) for dataset_copy in dataset_copies]
  94. dataset = torch.utils.data.ConcatDataset(dataset_copies)
  95. return dataset
  96. def get_train_dataset(dataset_root: str, args: argparse.Namespace) -> torch.utils.data.Dataset:
  97. datasets = []
  98. for dataset_name in args.train_datasets:
  99. transform = make_train_transform(args)
  100. dataset = make_dataset(dataset_name, dataset_root, transform)
  101. datasets.append(dataset)
  102. if len(datasets) == 0:
  103. raise ValueError("No datasets specified for training")
  104. samples_per_step = args.world_size * args.batch_size
  105. for idx, (dataset, steps_per_dataset) in enumerate(zip(datasets, args.dataset_steps)):
  106. datasets[idx] = resize_dataset_to_n_steps(dataset, steps_per_dataset, samples_per_step, args)
  107. dataset = torch.utils.data.ConcatDataset(datasets)
  108. if args.dataset_order_shuffle:
  109. dataset = shuffle_dataset(dataset)
  110. print(f"Training dataset: {len(dataset)} samples")
  111. return dataset
@torch.inference_mode()
def _evaluate(
    model,
    args,
    val_loader,
    *,
    padder_mode,
    print_freq=10,
    writer=None,
    step=None,
    iterations=None,
    batch_size=None,
    header=None,
):
    """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.

    Args:
        model: stereo network; switched to eval mode here.
        args: CLI namespace; reads ``device``, ``recurrent_updates``,
            ``mixed_precision``, ``metrics`` and ``rank``.
        val_loader: yields (image_left, image_right, disparity_gt, valid_mask) batches.
        padder_mode: padding convention for ``utils.InputPadder`` (e.g. "kitti"/"sintel").
        print_freq: batch interval for progress printing.
        writer: optional tensorboard SummaryWriter; scalars written on rank 0 only.
        step: global training step used as the tensorboard x-axis.
        iterations: recurrent refinement iterations; falls back to
            ``args.recurrent_updates`` when None.
        batch_size: unused here; kept for signature compatibility with callers.
        header: prefix for log lines; defaults to "Test:".
    """
    model.eval()
    header = header or "Test:"
    device = torch.device(args.device)
    # progress logger for iteration printing (separate from the metric meters below)
    metric_logger = utils.MetricLogger(delimiter=" ")
    iterations = iterations or args.recurrent_updates
    logger = utils.MetricLogger()
    for meter_name in args.metrics:
        logger.add_meter(meter_name, fmt="{global_avg:.4f}")
    # fl-all is always tracked, even when not explicitly requested
    if "fl-all" not in args.metrics:
        logger.add_meter("fl-all", fmt="{global_avg:.4f}")
    num_processed_samples = 0
    with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
        for blob in metric_logger.log_every(val_loader, print_freq, header):
            image_left, image_right, disp_gt, valid_disp_mask = (x.to(device) for x in blob)
            # pad inputs to a model-compatible size; the padding is undone on the prediction
            padder = utils.InputPadder(image_left.shape, mode=padder_mode)
            image_left, image_right = padder.pad(image_left, image_right)
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=iterations)
            # only the last (most refined) prediction, first channel (X-disparity), is scored
            disp_pred = disp_predictions[-1][:, :1, :, :]
            disp_pred = padder.unpad(disp_pred)
            metrics, _ = utils.compute_metrics(disp_pred, disp_gt, valid_disp_mask, metrics=logger.meters.keys())
            num_processed_samples += image_left.shape[0]
            for name in metrics:
                logger.meters[name].update(metrics[name], n=1)
    # aggregate the sample count across all distributed workers
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    print("Num_processed_samples: ", num_processed_samples)
    # NOTE(review): torch.distributed.get_rank() raises when the default process group
    # is not initialized — confirm this branch is only reached in distributed runs.
    if (
        hasattr(val_loader.dataset, "__len__")
        and len(val_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        warnings.warn(
            f"Number of processed samples {num_processed_samples} is different"
            f"from the dataset size {len(val_loader.dataset)}. This may happen if"
            "the dataset is not divisible by the batch size. Try lowering the batch size or GPU number for more accurate results."
        )
    if writer is not None and args.rank == 0:
        for meter_name, meter_value in logger.meters.items():
            scalar_name = f"{meter_name} {header}"
            writer.add_scalar(scalar_name, meter_value.avg, step)
    # sync meters across processes before printing the final summary
    logger.synchronize_between_processes()
    print(header, logger)
  168. def make_eval_loader(dataset_name: str, args: argparse.Namespace) -> torch.utils.data.DataLoader:
  169. if args.weights:
  170. weights = torchvision.models.get_weight(args.weights)
  171. trans = weights.transforms()
  172. def preprocessing(image_left, image_right, disp, valid_disp_mask):
  173. C_o, H_o, W_o = get_dimensions(image_left)
  174. image_left, image_right = trans(image_left, image_right)
  175. C_t, H_t, W_t = get_dimensions(image_left)
  176. scale_factor = W_t / W_o
  177. if disp is not None and not isinstance(disp, torch.Tensor):
  178. disp = torch.from_numpy(disp)
  179. if W_t != W_o:
  180. disp = resize(disp, (H_t, W_t), mode=InterpolationMode.BILINEAR) * scale_factor
  181. if valid_disp_mask is not None and not isinstance(valid_disp_mask, torch.Tensor):
  182. valid_disp_mask = torch.from_numpy(valid_disp_mask)
  183. if W_t != W_o:
  184. valid_disp_mask = resize(valid_disp_mask, (H_t, W_t), mode=InterpolationMode.NEAREST)
  185. return image_left, image_right, disp, valid_disp_mask
  186. else:
  187. preprocessing = make_eval_transform(args)
  188. val_dataset = make_dataset(dataset_name, args.dataset_root, transforms=preprocessing)
  189. if args.distributed:
  190. sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False)
  191. else:
  192. sampler = torch.utils.data.SequentialSampler(val_dataset)
  193. val_loader = torch.utils.data.DataLoader(
  194. val_dataset,
  195. sampler=sampler,
  196. batch_size=args.batch_size,
  197. pin_memory=True,
  198. num_workers=args.workers,
  199. )
  200. return val_loader
  201. def evaluate(model, loaders, args, writer=None, step=None):
  202. for loader_name, loader in loaders.items():
  203. _evaluate(
  204. model,
  205. args,
  206. loader,
  207. iterations=args.recurrent_updates,
  208. padder_mode=args.padder_type,
  209. header=f"{loader_name} evaluation",
  210. batch_size=args.batch_size,
  211. writer=writer,
  212. step=step,
  213. )
def run(model, optimizer, scheduler, train_loader, val_loaders, logger, writer, scaler, args):
    """Main optimization loop.

    Iterates ``args.total_iterations`` steps over ``train_loader``, combining the
    sequence loss with optional auxiliary losses (consistency / PSNR / photometric /
    smoothness, each gated by its weight), and periodically logs to tensorboard,
    checkpoints, and runs validation.

    Args:
        model: stereo network (possibly DDP-wrapped).
        optimizer: optimizer stepped once per iteration.
        scheduler: LR schedule stepped once per iteration.
        train_loader: training DataLoader (already shuffled upstream).
        val_loaders: dict of name -> DataLoader consumed by ``evaluate``.
        logger: utils.MetricLogger used for both iteration logging and metric meters.
        writer: optional tensorboard SummaryWriter (used on the main process only).
        scaler: optional GradScaler for mixed-precision training.
        args: parsed CLI namespace.
    """
    device = torch.device(args.device)
    # wrap the loader in a logger
    loader = iter(logger.log_every(train_loader))
    # output channels
    model_out_channels = model.module.output_channels if args.distributed else model.output_channels
    torch.set_num_threads(args.threads)
    # primary supervision: weighted loss over the recurrent prediction sequence
    sequence_criterion = utils.SequenceLoss(
        gamma=args.gamma,
        max_flow=args.max_disparity,
        exclude_large_flows=args.flow_loss_exclude_large,
    ).to(device)
    # auxiliary criteria are only instantiated when their weight is non-zero
    if args.consistency_weight:
        consistency_criterion = utils.FlowSequenceConsistencyLoss(
            args.gamma,
            resize_factor=0.25,
            rescale_factor=0.25,
            rescale_mode="bilinear",
        ).to(device)
    else:
        consistency_criterion = None
    if args.psnr_weight:
        psnr_criterion = utils.PSNRLoss().to(device)
    else:
        psnr_criterion = None
    if args.smoothness_weight:
        smoothness_criterion = utils.SmoothnessLoss().to(device)
    else:
        smoothness_criterion = None
    if args.photometric_weight:
        photometric_criterion = utils.FlowPhotoMetricLoss(
            ssim_weight=args.photometric_ssim_weight,
            max_displacement_ratio=args.photometric_max_displacement_ratio,
            ssim_use_padding=False,
        ).to(device)
    else:
        photometric_criterion = None
    # steps are 1-indexed; start_step is 0 for a fresh run, checkpoint step + 1 on resume
    for step in range(args.start_step + 1, args.total_iterations + 1):
        data_blob = next(loader)
        optimizer.zero_grad()
        # unpack the data blob
        image_left, image_right, disp_mask, valid_disp_mask = (x.to(device) for x in data_blob)
        with torch.cuda.amp.autocast(enabled=args.mixed_precision, dtype=torch.float16):
            disp_predictions = model(image_left, image_right, flow_init=None, num_iters=args.recurrent_updates)
            # different models have different outputs, make sure we get the right ones for this task
            disp_predictions = make_stereo_flow(disp_predictions, model_out_channels)
            # should the architecture or training loop require it, we have to adjust the disparity mask
            # target to possibly look like an optical flow mask
            disp_mask = make_stereo_flow(disp_mask, model_out_channels)
            # sequence loss on top of the model outputs
            loss = sequence_criterion(disp_predictions, disp_mask, valid_disp_mask) * args.flow_loss_weight
            if args.consistency_weight > 0:
                loss_consistency = consistency_criterion(disp_predictions)
                loss += loss_consistency * args.consistency_weight
            if args.psnr_weight > 0:
                loss_psnr = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_psnr += psnr_criterion(
                        pred * valid_disp_mask.unsqueeze(1),
                        disp_mask * valid_disp_mask.unsqueeze(1),
                    ).mean()  # mean the psnr loss over the batch
                loss += loss_psnr / len(disp_predictions) * args.psnr_weight
            if args.photometric_weight > 0:
                loss_photometric = 0.0
                for pred in disp_predictions:
                    # predictions might have 1 channel, therefore we need to inpute 0s for the second channel
                    if model_out_channels == 1:
                        pred = torch.cat([pred, torch.zeros_like(pred)], dim=1)
                    loss_photometric += photometric_criterion(
                        image_left, image_right, pred, valid_disp_mask
                    )  # photometric loss already comes out meaned over the batch
                loss += loss_photometric / len(disp_predictions) * args.photometric_weight
            if args.smoothness_weight > 0:
                loss_smoothness = 0.0
                for pred in disp_predictions:
                    # predictions might have 2 channels
                    loss_smoothness += smoothness_criterion(
                        image_left, pred[:, :1, :, :]
                    ).mean()  # mean the smoothness loss over the batch
                loss += loss_smoothness / len(disp_predictions) * args.smoothness_weight
        # metrics are computed outside the autograd graph, for logging only
        with torch.no_grad():
            metrics, _ = utils.compute_metrics(
                disp_predictions[-1][:, :1, :, :],  # predictions might have 2 channels
                disp_mask[:, :1, :, :],  # so does the ground truth
                valid_disp_mask,
                args.metrics,
            )
            metrics.pop("fl-all", None)
            logger.update(loss=loss, **metrics)
        # backward + optimizer step, with optional AMP grad scaling and gradient clipping
        if scaler is not None:
            scaler.scale(loss).backward()
            # gradients must be unscaled before clipping so the norm threshold is meaningful
            scaler.unscale_(optimizer)
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
            optimizer.step()
        scheduler.step()
        # tensorboard logging happens on the main process only
        if not dist.is_initialized() or dist.get_rank() == 0:
            if writer is not None and step % args.tensorboard_log_frequency == 0:
                # log the loss and metrics to tensorboard
                writer.add_scalar("loss", loss, step)
                for name, value in logger.meters.items():
                    writer.add_scalar(name, value.avg, step)
                # log the images to tensorboard
                pred_grid = visualization.make_training_sample_grid(
                    image_left, image_right, disp_mask, valid_disp_mask, disp_predictions
                )
                writer.add_image("predictions", pred_grid, step, dataformats="HWC")
                # second thing we want to see is how relevant the iterative refinement is
                pred_sequence_grid = visualization.make_disparity_sequence_grid(disp_predictions, disp_mask)
                writer.add_image("sequence", pred_sequence_grid, step, dataformats="HWC")
        if step % args.save_frequency == 0:
            # periodic checkpoint on rank 0: a step-tagged file plus a "latest" file
            if not args.distributed or args.rank == 0:
                model_without_ddp = (
                    model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
                )
                checkpoint = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "step": step,
                    "args": args,
                }
                os.makedirs(args.checkpoint_dir, exist_ok=True)
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
                torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
        if step % args.valid_frequency == 0:
            # evaluate() switches the model to eval mode, so training mode
            # (and any frozen batch norm) must be restored afterwards
            evaluate(model, val_loaders, args, writer, step)
            model.train()
            if args.freeze_batch_norm:
                if isinstance(model, nn.parallel.DistributedDataParallel):
                    freeze_batch_norm(model.module)
                else:
                    freeze_batch_norm(model)
    # one final save at the end
    if not args.distributed or args.rank == 0:
        model_without_ddp = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "step": step,
            "args": args,
        }
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}_{step}.pth")
        torch.save(checkpoint, Path(args.checkpoint_dir) / f"{args.name}.pth")
  367. def main(args):
  368. args.total_iterations = sum(args.dataset_steps)
  369. # initialize DDP setting
  370. utils.setup_ddp(args)
  371. print(args)
  372. args.test_only = args.train_datasets is None
  373. # set the appropriate devices
  374. if args.distributed and args.device == "cpu":
  375. raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
  376. device = torch.device(args.device)
  377. # select model architecture
  378. model = torchvision.prototype.models.depth.stereo.__dict__[args.model](weights=args.weights)
  379. # convert to DDP if need be
  380. if args.distributed:
  381. model = model.to(args.gpu)
  382. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  383. model_without_ddp = model.module
  384. else:
  385. model.to(device)
  386. model_without_ddp = model
  387. os.makedirs(args.checkpoint_dir, exist_ok=True)
  388. val_loaders = {name: make_eval_loader(name, args) for name in args.test_datasets}
  389. # EVAL ONLY configurations
  390. if args.test_only:
  391. evaluate(model, val_loaders, args)
  392. return
  393. # Sanity check for the parameter count
  394. print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
  395. # Compose the training dataset
  396. train_dataset = get_train_dataset(args.dataset_root, args)
  397. # initialize the optimizer
  398. if args.optimizer == "adam":
  399. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  400. elif args.optimizer == "sgd":
  401. optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
  402. else:
  403. raise ValueError(f"Unknown optimizer {args.optimizer}. Please choose between adam and sgd")
  404. # initialize the learning rate schedule
  405. scheduler = make_lr_schedule(args, optimizer)
  406. # load them from checkpoint if needed
  407. args.start_step = 0
  408. if args.resume_path is not None:
  409. checkpoint = torch.load(args.resume_path, map_location="cpu")
  410. if "model" in checkpoint:
  411. # this means the user requested to resume from a training checkpoint
  412. model_without_ddp.load_state_dict(checkpoint["model"])
  413. # this means the user wants to continue training from where it was left off
  414. if args.resume_schedule:
  415. optimizer.load_state_dict(checkpoint["optimizer"])
  416. scheduler.load_state_dict(checkpoint["scheduler"])
  417. args.start_step = checkpoint["step"] + 1
  418. # modify starting point of the dat
  419. sample_start_step = args.start_step * args.batch_size * args.world_size
  420. train_dataset = train_dataset[sample_start_step:]
  421. else:
  422. # this means the user wants to finetune on top of a model state dict
  423. # and that no other changes are required
  424. model_without_ddp.load_state_dict(checkpoint)
  425. torch.backends.cudnn.benchmark = True
  426. # enable training mode
  427. model.train()
  428. if args.freeze_batch_norm:
  429. freeze_batch_norm(model_without_ddp)
  430. # put dataloader on top of the dataset
  431. # make sure to disable shuffling since the dataset is already shuffled
  432. # in order to guarantee quasi randomness whilst retaining a deterministic
  433. # dataset consumption order
  434. if args.distributed:
  435. # the train dataset is preshuffled in order to respect the iteration order
  436. sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=False, drop_last=True)
  437. else:
  438. # the train dataset is already shuffled, so we can use a simple SequentialSampler
  439. sampler = torch.utils.data.SequentialSampler(train_dataset)
  440. train_loader = torch.utils.data.DataLoader(
  441. train_dataset,
  442. sampler=sampler,
  443. batch_size=args.batch_size,
  444. pin_memory=True,
  445. num_workers=args.workers,
  446. )
  447. # initialize the logger
  448. if args.tensorboard_summaries:
  449. from torch.utils.tensorboard import SummaryWriter
  450. tensorboard_path = Path(args.checkpoint_dir) / "tensorboard"
  451. os.makedirs(tensorboard_path, exist_ok=True)
  452. tensorboard_run = tensorboard_path / f"{args.name}"
  453. writer = SummaryWriter(tensorboard_run)
  454. else:
  455. writer = None
  456. logger = utils.MetricLogger(delimiter=" ")
  457. scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None
  458. # run the training loop
  459. # this will perform optimization, respectively logging and saving checkpoints
  460. # when need be
  461. run(
  462. model=model,
  463. optimizer=optimizer,
  464. scheduler=scheduler,
  465. train_loader=train_loader,
  466. val_loaders=val_loaders,
  467. logger=logger,
  468. writer=writer,
  469. scaler=scaler,
  470. args=args,
  471. )
  472. def get_args_parser(add_help=True):
  473. import argparse
  474. parser = argparse.ArgumentParser(description="PyTorch Stereo Matching Training", add_help=add_help)
  475. # checkpointing
  476. parser.add_argument("--name", default="crestereo", help="name of the experiment")
  477. parser.add_argument("--resume", type=str, default=None, help="from which checkpoint to resume")
  478. parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="path to the checkpoint directory")
  479. # dataset
  480. parser.add_argument("--dataset-root", type=str, default="", help="path to the dataset root directory")
  481. parser.add_argument(
  482. "--train-datasets",
  483. type=str,
  484. nargs="+",
  485. default=["crestereo"],
  486. help="dataset(s) to train on",
  487. choices=list(VALID_DATASETS.keys()),
  488. )
  489. parser.add_argument(
  490. "--dataset-steps", type=int, nargs="+", default=[300_000], help="number of steps for each dataset"
  491. )
  492. parser.add_argument(
  493. "--steps-is-epochs", action="store_true", help="if set, dataset-steps are interpreted as epochs"
  494. )
  495. parser.add_argument(
  496. "--test-datasets",
  497. type=str,
  498. nargs="+",
  499. default=["middlebury2014-train"],
  500. help="dataset(s) to test on",
  501. choices=["middlebury2014-train"],
  502. )
  503. parser.add_argument("--dataset-shuffle", type=bool, help="shuffle the dataset", default=True)
  504. parser.add_argument("--dataset-order-shuffle", type=bool, help="shuffle the dataset order", default=True)
  505. parser.add_argument("--batch-size", type=int, default=2, help="batch size per GPU")
  506. parser.add_argument("--workers", type=int, default=4, help="number of workers per GPU")
  507. parser.add_argument(
  508. "--threads",
  509. type=int,
  510. default=16,
  511. help="number of CPU threads per GPU. This can be changed around to speed-up transforms if needed. This can lead to worker thread contention so use with care.",
  512. )
  513. # model architecture
  514. parser.add_argument(
  515. "--model",
  516. type=str,
  517. default="crestereo_base",
  518. help="model architecture",
  519. choices=["crestereo_base", "raft_stereo"],
  520. )
  521. parser.add_argument("--recurrent-updates", type=int, default=10, help="number of recurrent updates")
  522. parser.add_argument("--freeze-batch-norm", action="store_true", help="freeze batch norm parameters")
  523. # loss parameters
  524. parser.add_argument("--gamma", type=float, default=0.8, help="gamma parameter for the flow sequence loss")
  525. parser.add_argument("--flow-loss-weight", type=float, default=1.0, help="weight for the flow loss")
  526. parser.add_argument(
  527. "--flow-loss-exclude-large",
  528. action="store_true",
  529. help="exclude large flow values from the loss. A large value is defined as a value greater than the ground truth flow norm",
  530. default=False,
  531. )
  532. parser.add_argument("--consistency-weight", type=float, default=0.0, help="consistency loss weight")
  533. parser.add_argument(
  534. "--consistency-resize-factor",
  535. type=float,
  536. default=0.25,
  537. help="consistency loss resize factor to account for the fact that the flow is computed on a downsampled image",
  538. )
  539. parser.add_argument("--psnr-weight", type=float, default=0.0, help="psnr loss weight")
  540. parser.add_argument("--smoothness-weight", type=float, default=0.0, help="smoothness loss weight")
  541. parser.add_argument("--photometric-weight", type=float, default=0.0, help="photometric loss weight")
  542. parser.add_argument(
  543. "--photometric-max-displacement-ratio",
  544. type=float,
  545. default=0.15,
  546. help="Only pixels with a displacement smaller than this ratio of the image width will be considered for the photometric loss",
  547. )
  548. parser.add_argument("--photometric-ssim-weight", type=float, default=0.85, help="photometric ssim loss weight")
  549. # transforms parameters
  550. parser.add_argument("--gpu-transforms", action="store_true", help="use GPU transforms")
  551. parser.add_argument(
  552. "--eval-size", type=int, nargs="+", default=[384, 512], help="size of the images for evaluation"
  553. )
  554. parser.add_argument("--resize-size", type=int, nargs=2, default=None, help="resize size")
  555. parser.add_argument("--crop-size", type=int, nargs=2, default=[384, 512], help="crop size")
  556. parser.add_argument("--scale-range", type=float, nargs=2, default=[0.6, 1.0], help="random scale range")
  557. parser.add_argument("--rescale-prob", type=float, default=1.0, help="probability of resizing the image")
  558. parser.add_argument(
  559. "--scaling-type", type=str, default="linear", help="scaling type", choices=["exponential", "linear"]
  560. )
  561. parser.add_argument("--flip-prob", type=float, default=0.5, help="probability of flipping the image")
  562. parser.add_argument(
  563. "--norm-mean", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="mean for image normalization"
  564. )
  565. parser.add_argument(
  566. "--norm-std", type=float, nargs="+", default=[0.5, 0.5, 0.5], help="std for image normalization"
  567. )
  568. parser.add_argument(
  569. "--use-grayscale", action="store_true", help="use grayscale images instead of RGB", default=False
  570. )
  571. parser.add_argument("--max-disparity", type=float, default=None, help="maximum disparity")
  572. parser.add_argument(
  573. "--interpolation-strategy",
  574. type=str,
  575. default="bilinear",
  576. help="interpolation strategy",
  577. choices=["bilinear", "bicubic", "mixed"],
  578. )
  579. parser.add_argument("--spatial-shift-prob", type=float, default=1.0, help="probability of shifting the image")
  580. parser.add_argument(
  581. "--spatial-shift-max-angle", type=float, default=0.1, help="maximum angle for the spatial shift"
  582. )
  583. parser.add_argument(
  584. "--spatial-shift-max-displacement", type=float, default=2.0, help="maximum displacement for the spatial shift"
  585. )
  586. parser.add_argument("--gamma-range", type=float, nargs="+", default=[0.8, 1.2], help="range for gamma correction")
  587. parser.add_argument(
  588. "--brightness-range", type=float, nargs="+", default=[0.8, 1.2], help="range for brightness correction"
  589. )
  590. parser.add_argument(
  591. "--contrast-range", type=float, nargs="+", default=[0.8, 1.2], help="range for contrast correction"
  592. )
  593. parser.add_argument(
  594. "--saturation-range", type=float, nargs="+", default=0.0, help="range for saturation correction"
  595. )
  596. parser.add_argument("--hue-range", type=float, nargs="+", default=0.0, help="range for hue correction")
  597. parser.add_argument(
  598. "--asymmetric-jitter-prob",
  599. type=float,
  600. default=1.0,
  601. help="probability of using asymmetric jitter instead of symmetric jitter",
  602. )
  603. parser.add_argument("--occlusion-prob", type=float, default=0.5, help="probability of occluding the rightimage")
  604. parser.add_argument(
  605. "--occlusion-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of occluded pixels"
  606. )
  607. parser.add_argument("--erase-prob", type=float, default=0.0, help="probability of erasing in both images")
  608. parser.add_argument(
  609. "--erase-px-range", type=int, nargs="+", default=[50, 100], help="range for the number of erased pixels"
  610. )
  611. parser.add_argument(
  612. "--erase-num-repeats", type=int, default=1, help="number of times to repeat the erase operation"
  613. )
  614. # optimizer parameters
  615. parser.add_argument("--optimizer", type=str, default="adam", help="optimizer", choices=["adam", "sgd"])
  616. parser.add_argument("--lr", type=float, default=4e-4, help="learning rate")
  617. parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay")
  618. parser.add_argument("--clip-grad-norm", type=float, default=0.0, help="clip grad norm")
  619. # lr_scheduler parameters
  620. parser.add_argument("--min-lr", type=float, default=2e-5, help="minimum learning rate")
  621. parser.add_argument("--warmup-steps", type=int, default=6_000, help="number of warmup steps")
  622. parser.add_argument(
  623. "--decay-after-steps", type=int, default=180_000, help="number of steps after which to start decay the lr"
  624. )
  625. parser.add_argument(
  626. "--lr-warmup-method", type=str, default="linear", help="warmup method", choices=["linear", "cosine"]
  627. )
  628. parser.add_argument("--lr-warmup-factor", type=float, default=0.02, help="warmup factor for the learning rate")
  629. parser.add_argument(
  630. "--lr-decay-method",
  631. type=str,
  632. default="linear",
  633. help="decay method",
  634. choices=["linear", "cosine", "exponential"],
  635. )
  636. parser.add_argument("--lr-decay-gamma", type=float, default=0.8, help="decay factor for the learning rate")
  637. # deterministic behaviour
  638. parser.add_argument("--seed", type=int, default=42, help="seed for random number generators")
  639. # mixed precision training
  640. parser.add_argument("--mixed-precision", action="store_true", help="use mixed precision training")
  641. # logging
  642. parser.add_argument("--tensorboard-summaries", action="store_true", help="log to tensorboard")
  643. parser.add_argument("--tensorboard-log-frequency", type=int, default=100, help="log frequency")
  644. parser.add_argument("--save-frequency", type=int, default=1_000, help="save frequency")
  645. parser.add_argument("--valid-frequency", type=int, default=1_000, help="validation frequency")
  646. parser.add_argument(
  647. "--metrics",
  648. type=str,
  649. nargs="+",
  650. default=["mae", "rmse", "1px", "3px", "5px", "relepe"],
  651. help="metrics to log",
  652. choices=AVAILABLE_METRICS,
  653. )
  654. # distributed parameters
  655. parser.add_argument("--world-size", type=int, default=8, help="number of distributed processes")
  656. parser.add_argument("--dist-url", type=str, default="env://", help="url used to set up distributed training")
  657. parser.add_argument("--device", type=str, default="cuda", help="device to use for training")
  658. # weights API
  659. parser.add_argument("--weights", type=str, default=None, help="weights API url")
  660. parser.add_argument(
  661. "--resume-path", type=str, default=None, help="a path from which to resume or start fine-tuning"
  662. )
  663. parser.add_argument("--resume-schedule", action="store_true", help="resume optimizer state")
  664. # padder parameters
  665. parser.add_argument("--padder-type", type=str, default="kitti", help="padder type", choices=["kitti", "sintel"])
  666. return parser
if __name__ == "__main__":
    # parse CLI arguments and launch training (or eval-only) accordingly
    args = get_args_parser().parse_args()
    main(args)