train.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from copy import copy
  3. from ultralytics.models import yolo
  4. from ultralytics.nn.tasks import PoseModel
  5. from ultralytics.utils import DEFAULT_CFG, LOGGER
  6. from ultralytics.utils.plotting import plot_images, plot_results
  7. class PoseTrainer(yolo.detect.DetectionTrainer):
  8. """
  9. A class extending the DetectionTrainer class for training based on a pose model.
  10. Example:
  11. ```python
  12. from ultralytics.models.yolo.pose import PoseTrainer
  13. args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
  14. trainer = PoseTrainer(overrides=args)
  15. trainer.train()
  16. ```
  17. """
  18. def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  19. """Initialize a PoseTrainer object with specified configurations and overrides."""
  20. if overrides is None:
  21. overrides = {}
  22. overrides['task'] = 'pose'
  23. super().__init__(cfg, overrides, _callbacks)
  24. if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
  25. LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
  26. 'See https://github.com/ultralytics/ultralytics/issues/4031.')
  27. def get_model(self, cfg=None, weights=None, verbose=True):
  28. """Get pose estimation model with specified configuration and weights."""
  29. model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
  30. if weights:
  31. model.load(weights)
  32. return model
  33. def set_model_attributes(self):
  34. """Sets keypoints shape attribute of PoseModel."""
  35. super().set_model_attributes()
  36. self.model.kpt_shape = self.data['kpt_shape']
  37. def get_validator(self):
  38. """Returns an instance of the PoseValidator class for validation."""
  39. self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
  40. return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
  41. def plot_training_samples(self, batch, ni):
  42. """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
  43. images = batch['img']
  44. kpts = batch['keypoints']
  45. cls = batch['cls'].squeeze(-1)
  46. bboxes = batch['bboxes']
  47. paths = batch['im_file']
  48. batch_idx = batch['batch_idx']
  49. plot_images(images,
  50. batch_idx,
  51. cls,
  52. bboxes,
  53. kpts=kpts,
  54. paths=paths,
  55. fname=self.save_dir / f'train_batch{ni}.jpg',
  56. on_plot=self.on_plot)
  57. def plot_metrics(self):
  58. """Plots training/val metrics."""
  59. plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png