train.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from copy import copy
  3. from ultralytics.models import yolo
  4. from ultralytics.nn.tasks import SegmentationModel
  5. from ultralytics.utils import DEFAULT_CFG, RANK
  6. from ultralytics.utils.plotting import plot_images, plot_results
  7. class SegmentationTrainer(yolo.detect.DetectionTrainer):
  8. """
  9. A class extending the DetectionTrainer class for training based on a segmentation model.
  10. Example:
  11. ```python
  12. from ultralytics.models.yolo.segment import SegmentationTrainer
  13. args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
  14. trainer = SegmentationTrainer(overrides=args)
  15. trainer.train()
  16. ```
  17. """
  18. def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  19. """Initialize a SegmentationTrainer object with given arguments."""
  20. if overrides is None:
  21. overrides = {}
  22. overrides['task'] = 'segment'
  23. super().__init__(cfg, overrides, _callbacks)
  24. def get_model(self, cfg=None, weights=None, verbose=True):
  25. """Return SegmentationModel initialized with specified config and weights."""
  26. model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
  27. if weights:
  28. model.load(weights)
  29. return model
  30. def get_validator(self):
  31. """Return an instance of SegmentationValidator for validation of YOLO model."""
  32. self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
  33. return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
  34. def plot_training_samples(self, batch, ni):
  35. """Creates a plot of training sample images with labels and box coordinates."""
  36. plot_images(batch['img'],
  37. batch['batch_idx'],
  38. batch['cls'].squeeze(-1),
  39. batch['bboxes'],
  40. batch['masks'],
  41. paths=batch['im_file'],
  42. fname=self.save_dir / f'train_batch{ni}.jpg',
  43. on_plot=self.on_plot)
  44. def plot_metrics(self):
  45. """Plots training/val metrics."""
  46. plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png