# train.py (6.5 KB)
# Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. import torchvision
  4. from ultralytics.data import ClassificationDataset, build_dataloader
  5. from ultralytics.engine.trainer import BaseTrainer
  6. from ultralytics.models import yolo
  7. from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
  8. from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
  9. from ultralytics.utils.plotting import plot_images, plot_results
  10. from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
  11. class ClassificationTrainer(BaseTrainer):
  12. """
  13. A class extending the BaseTrainer class for training based on a classification model.
  14. Notes:
  15. - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
  16. Example:
  17. ```python
  18. from ultralytics.models.yolo.classify import ClassificationTrainer
  19. args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
  20. trainer = ClassificationTrainer(overrides=args)
  21. trainer.train()
  22. ```
  23. """
  24. def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  25. """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
  26. if overrides is None:
  27. overrides = {}
  28. overrides['task'] = 'classify'
  29. if overrides.get('imgsz') is None:
  30. overrides['imgsz'] = 224
  31. super().__init__(cfg, overrides, _callbacks)
  32. def set_model_attributes(self):
  33. """Set the YOLO model's class names from the loaded dataset."""
  34. self.model.names = self.data['names']
  35. def get_model(self, cfg=None, weights=None, verbose=True):
  36. """Returns a modified PyTorch model configured for training YOLO."""
  37. model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
  38. if weights:
  39. model.load(weights)
  40. for m in model.modules():
  41. if not self.args.pretrained and hasattr(m, 'reset_parameters'):
  42. m.reset_parameters()
  43. if isinstance(m, torch.nn.Dropout) and self.args.dropout:
  44. m.p = self.args.dropout # set dropout
  45. for p in model.parameters():
  46. p.requires_grad = True # for training
  47. return model
  48. def setup_model(self):
  49. """load/create/download model for any task"""
  50. if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
  51. return
  52. model, ckpt = str(self.model), None
  53. # Load a YOLO model locally, from torchvision, or from Ultralytics assets
  54. if model.endswith('.pt'):
  55. self.model, ckpt = attempt_load_one_weight(model, device='cpu')
  56. for p in self.model.parameters():
  57. p.requires_grad = True # for training
  58. elif model.split('.')[-1] in ('yaml', 'yml'):
  59. self.model = self.get_model(cfg=model)
  60. elif model in torchvision.models.__dict__:
  61. self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
  62. else:
  63. FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
  64. ClassificationModel.reshape_outputs(self.model, self.data['nc'])
  65. return ckpt
  66. def build_dataset(self, img_path, mode='train', batch=None):
  67. return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
  68. def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
  69. """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
  70. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  71. dataset = self.build_dataset(dataset_path, mode)
  72. loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
  73. # Attach inference transforms
  74. if mode != 'train':
  75. if is_parallel(self.model):
  76. self.model.module.transforms = loader.dataset.torch_transforms
  77. else:
  78. self.model.transforms = loader.dataset.torch_transforms
  79. return loader
  80. def preprocess_batch(self, batch):
  81. """Preprocesses a batch of images and classes."""
  82. batch['img'] = batch['img'].to(self.device)
  83. batch['cls'] = batch['cls'].to(self.device)
  84. return batch
  85. def progress_string(self):
  86. """Returns a formatted string showing training progress."""
  87. return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
  88. ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
  89. def get_validator(self):
  90. """Returns an instance of ClassificationValidator for validation."""
  91. self.loss_names = ['loss']
  92. return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
  93. def label_loss_items(self, loss_items=None, prefix='train'):
  94. """
  95. Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
  96. segmentation & detection
  97. """
  98. keys = [f'{prefix}/{x}' for x in self.loss_names]
  99. if loss_items is None:
  100. return keys
  101. loss_items = [round(float(loss_items), 5)]
  102. return dict(zip(keys, loss_items))
  103. def plot_metrics(self):
  104. """Plots metrics from a CSV file."""
  105. plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
  106. def final_eval(self):
  107. """Evaluate trained model and save validation results."""
  108. for f in self.last, self.best:
  109. if f.exists():
  110. strip_optimizer(f) # strip optimizers
  111. # TODO: validate best.pt after training completes
  112. # if f is self.best:
  113. # LOGGER.info(f'\nValidating {f}...')
  114. # self.validator.args.save_json = True
  115. # self.metrics = self.validator(model=f)
  116. # self.metrics.pop('fitness', None)
  117. # self.run_callbacks('on_fit_epoch_end')
  118. LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
  119. def plot_training_samples(self, batch, ni):
  120. """Plots training samples with their annotations."""
  121. plot_images(
  122. images=batch['img'],
  123. batch_idx=torch.arange(len(batch['img'])),
  124. cls=batch['cls'].view(-1), # warning: use .view(), not .squeeze() for Classify models
  125. fname=self.save_dir / f'train_batch{ni}.jpg',
  126. on_plot=self.on_plot)