# train.py
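
"""Video classification training script for Kinetics-style datasets.

Trains a torchvision video model (r2plus1d_18 by default) with clip sampling,
optional AMP, warmup + multi-step LR scheduling, and distributed data parallelism,
and evaluates both clip-level and video-level accuracy.
"""
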
import datetime
import os
import time
import warnings

import datasets
import presets
import torch
import torch.utils.data
import torchvision
import torchvision.datasets.video_utils
import utils
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
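
# NOTE: `datasets`, `presets`, and `utils` are local helper modules assumed to sit next
# to train.py; they provide KineticsWithVideoId, the transform presets, and the
# logging/distributed helpers used below. They are not part of an installed package.
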
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))

    header = f"Epoch: [{epoch}]"
    for video, target, _ in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        video, target = video.to(device), target.to(device)
        # Run the forward pass under autocast when AMP is enabled (i.e. a scaler was passed)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(video)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            # AMP: scale the loss to avoid fp16 gradient underflow, then unscale and step
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = video.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time))
        # The scheduler is stepped per iteration, not per epoch (see main())
        lr_scheduler.step()
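
# Evaluation reports two numbers: clip-level accuracy (each sampled clip counted
# independently) and video-level accuracy (softmax scores of all clips from the same
# video are summed via `video_idx` before taking the top-k prediction).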
def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    num_processed_samples = 0
    # Group and aggregate the outputs of each video
    num_videos = len(data_loader.dataset.samples)
    num_classes = len(data_loader.dataset.classes)
    agg_preds = torch.zeros((num_videos, num_classes), dtype=torch.float32, device=device)
    agg_targets = torch.zeros((num_videos,), dtype=torch.int32, device=device)
    with torch.inference_mode():
        for video, target, video_idx in metric_logger.log_every(data_loader, 100, header):
            video = video.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(video)
            loss = criterion(output, target)

            # Use softmax to convert the output into prediction probabilities
            preds = torch.softmax(output, dim=1)
            for b in range(video.size(0)):
                idx = video_idx[b].item()
                agg_preds[idx] += preds[b].detach()
                agg_targets[idx] = target[b].detach().item()

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = video.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size

    # gather the stats from all processes
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if isinstance(data_loader.sampler, DistributedSampler):
        # Get the len of the UniformClipSampler inside the DistributedSampler
        num_data_from_sampler = len(data_loader.sampler.dataset)
    else:
        num_data_from_sampler = len(data_loader.sampler)

    if (
        hasattr(data_loader.dataset, "__len__")
        and num_data_from_sampler != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(
        " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5
        )
    )
    # Reduce agg_preds and agg_targets from all GPUs and show the result
    agg_preds = utils.reduce_across_processes(agg_preds)
    agg_targets = utils.reduce_across_processes(agg_targets, op=torch.distributed.ReduceOp.MAX)
    agg_acc1, agg_acc5 = utils.accuracy(agg_preds, agg_targets, topk=(1, 5))
    print(" * Video Acc@1 {acc1:.3f} Video Acc@5 {acc5:.3f}".format(acc1=agg_acc1, acc5=agg_acc5))
    return metric_logger.acc1.global_avg
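
# Dataset caches live under ~/.torch/vision/datasets/kinetics/ and are keyed by a SHA-1
# hash of the dataset path plus the clip parameters, so changing --clip-len,
# --kinetics-version, or --frame-rate selects a different cache file.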
def _get_cache_path(filepath, args):
    import hashlib

    value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
    h = hashlib.sha1(value.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path
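
# Each sample yielded by KineticsWithVideoId is assumed to be a
# (video, audio, label, video_idx) tuple; the collate function below keeps indices
# 0, 2 and 3, i.e. it drops the audio track before batching.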
def collate_fn(batch):
    # remove audio from the batch
    batch = [(d[0], d[2], d[3]) for d in batch]
    return default_collate(batch)

def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    val_resize_size = tuple(args.val_resize_size)
    val_crop_size = tuple(args.val_crop_size)
    train_resize_size = tuple(args.train_resize_size)
    train_crop_size = tuple(args.train_crop_size)

    traindir = os.path.join(args.data_path, "train")
    valdir = os.path.join(args.data_path, "val")

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir, args)
    transform_train = presets.VideoClassificationPresetTrain(crop_size=train_crop_size, resize_size=train_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single GPU first, as it will be faster")
        dataset = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="train",
            step_between_clips=1,
            transform=transform_train,
            frame_rate=args.frame_rate,
            extensions=("avi", "mp4"),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir, args)

    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        transform_test = weights.transforms()
    else:
        transform_test = presets.VideoClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single GPU first, as it will be faster")
        dataset_test = datasets.KineticsWithVideoId(
            args.data_path,
            frames_per_clip=args.clip_len,
            num_classes=args.kinetics_version,
            split="val",
            step_between_clips=1,
            transform=transform_test,
            frame_rate=args.frame_rate,
            extensions=("avi", "mp4"),
            output_format="TCHW",
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)
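
    # RandomClipSampler draws at most `clips_per_video` random clips from each video for
    # training; UniformClipSampler takes `clips_per_video` uniformly spaced clips per
    # video, so the video-level aggregation in evaluate() covers each video evenly.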
    print("Creating data loaders")
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
        test_sampler = DistributedSampler(test_sampler, shuffle=False)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
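
    # pin_memory=True keeps batches in page-locked host memory, which lets the
    # non_blocking device copies in evaluate() overlap with compute.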
    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # Convert the scheduler to be per iteration instead of per epoch, so that the
    # warmup can span several epochs
    iters_per_epoch = len(data_loader)
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler
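
    # With the defaults (lr=0.64, 10 warmup epochs, milestones [20, 30, 40], gamma=0.1),
    # the LR ramps up over the first 10 epochs and then drops 10x at epochs 20, 30 and 40:
    # the milestones are shifted by the warmup length above because MultiStepLR only
    # starts counting iterations once SequentialLR hands over to it.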
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        evaluate(model, criterion, data_loader_test, device=device)
        return
    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")

def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Video Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
    parser.add_argument(
        "--kinetics-version", default="400", type=str, choices=["400", "600"], help="select the Kinetics version"
    )
    parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu; default: cuda)")
    parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
    parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate")
    parser.add_argument(
        "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider"
    )
    parser.add_argument(
        "-b", "--batch-size", default=24, type=int, help="clips per GPU; the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)"
    )
    parser.add_argument("--lr", default=0.64, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--val-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for validation (default: (128, 171))",
    )
    parser.add_argument(
        "--val-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the central crop size used for validation (default: (112, 112))",
    )
    parser.add_argument(
        "--train-resize-size",
        default=(128, 171),
        nargs="+",
        type=int,
        help="the resize size used for training (default: (128, 171))",
    )
    parser.add_argument(
        "--train-crop-size",
        default=(112, 112),
        nargs="+",
        type=int,
        help="the random crop size used for training (default: (112, 112))",
    )

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser

if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
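
# Example launches (paths are placeholders; the distributed case assumes a torchrun-style
# launcher that sets the env:// rendezvous variables read by utils.init_distributed_mode):
#
#   # single GPU
#   python train.py --data-path /path/to/kinetics --kinetics-version 400 --amp
#
#   # single node, 8 GPUs
#   torchrun --nproc_per_node=8 train.py --data-path /path/to/kinetics --amp --cache-dataset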