engine.py

import math
import sys
import time

import torch
import torchvision.models.detection.mask_rcnn
import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"

    lr_scheduler = None
    if epoch == 0:
        # linear learning-rate warmup over (at most) the first 1000 iterations of epoch 0
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
        # forward pass; autocast is active only when mixed-precision training is enabled
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        # backward pass; scale gradients when training with mixed precision
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger
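

# --- Illustrative usage (not part of the original engine.py) ----------------
# A minimal sketch of how train_one_epoch is typically driven from a training
# script. The model choice, optimizer hyperparameters, and step schedule below
# are assumptions for illustration only; the only thing taken from this file
# is train_one_epoch's signature. weights="DEFAULT" assumes torchvision >= 0.13.
def _example_training_loop(data_loader, device, num_epochs=2, use_amp=False):
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    for epoch in range(num_epochs):
        # one pass over the data, with the first-epoch warmup handled above
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10, scaler=scaler)
        lr_scheduler.step()
    return model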


def _get_iou_types(model):
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types


@torch.inference_mode()
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
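

# --- Illustrative usage (not part of the original engine.py) ----------------
# A small sketch of reading a headline metric out of the CocoEvaluator that
# evaluate() returns. This assumes the CocoEvaluator from the companion
# coco_eval.py, which keeps one pycocotools COCOeval object per IoU type in
# its .coco_eval dict; stats[0] is the mAP @ IoU=0.50:0.95 figure printed by
# summarize() above.
def _example_evaluation(model, data_loader_test, device):
    coco_evaluator = evaluate(model, data_loader_test, device)
    bbox_map = coco_evaluator.coco_eval["bbox"].stats[0]
    return bbox_map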