import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import utils
from coco_utils import get_coco
from torch import nn
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.transforms import functional as F, InterpolationMode


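# get_dataset maps the --dataset flag to a (root, builder_fn, num_classes) triple.
# The sbd/voc wrappers pop the use_v2 kwarg because SBDataset and VOCSegmentation
# do not accept it.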
def get_dataset(args, is_train):
    def sbd(*args, **kwargs):
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    def voc(*args, **kwargs):
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*args, **kwargs)

    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
    }
    p, ds_fn, num_classes = paths[args.dataset]

    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=get_transform(is_train, args), use_v2=args.use_v2)
    return ds, num_classes


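# Transform selection: training uses the random-resize/crop preset; --test-only with
# --weights reuses the weights' own inference transforms (the mask is resized with
# NEAREST interpolation so class ids are not blended); otherwise the eval preset.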
def get_transform(is_train, args):
    if is_train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)
    elif args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img, target):
            img = trans(img)
            size = F.get_dimensions(img)[1:]
            target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
            return img, F.pil_to_tensor(target)

        return preprocessing
    else:
        return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)


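# Cross-entropy per output head, ignoring pixels labeled 255 (void/unlabeled).
# When an auxiliary head is present, its loss is down-weighted by 0.5.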
def criterion(inputs, target):
    losses = {}
    for name, x in inputs.items():
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]


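# Runs inference over the full loader and accumulates a confusion matrix, from
# which global accuracy and per-class IoU are derived.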
def evaluate(model, data_loader, device, num_classes):
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            output = output["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]

        confmat.reduce_from_all_processes()

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    return confmat


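# One epoch of training. Autocast is enabled only when a GradScaler is passed
# (--amp), and the LR scheduler is stepped once per iteration, not per epoch.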
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])


def main(args):
    if args.backend.lower() != "pil" and not args.use_v2:
        # TODO: Support tensor backend in V1?
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    if args.use_v2 and args.dataset != "coco":
        raise ValueError("v2 is only supported for the coco dataset for now.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, num_classes = get_dataset(args, is_train=True)
    dataset_test, _ = get_dataset(args, is_train=False)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    model = torchvision.models.get_model(
        args.model,
        weights=args.weights,
        weights_backbone=args.weights_backbone,
        num_classes=num_classes,
        aux_loss=args.aux_loss,
    )
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

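    # LR schedule: polynomial decay (power 0.9) over all post-warmup iterations,
    # optionally preceded by a linear or constant warmup chained via SequentialLR.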
    iters_per_epoch = len(data_loader)
    main_lr_scheduler = PolynomialLR(
        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
    )

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

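    # When resuming with --test-only, the weights are loaded non-strictly and the
    # optimizer/scheduler/epoch/scaler state is intentionally skipped.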
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
    parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
    parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu; default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

    return parser


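# Typical launch, following the torchvision references convention (flags shown
# are illustrative, not exhaustive):
#   torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss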
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)