test_engine.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from pathlib import Path
  3. from ultralytics import YOLO
  4. from ultralytics.cfg import get_cfg
  5. from ultralytics.engine.exporter import Exporter
  6. from ultralytics.models.yolo import classify, detect, segment
  7. from ultralytics.utils import ASSETS, DEFAULT_CFG, SETTINGS
  8. CFG_DET = 'yolov8n.yaml'
  9. CFG_SEG = 'yolov8n-seg.yaml'
  10. CFG_CLS = 'yolov8n-cls.yaml' # or 'squeezenet1_0'
  11. CFG = get_cfg(DEFAULT_CFG)
  12. MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
  13. def test_func(*args): # noqa
  14. print('callback test passed')
  15. def test_export():
  16. exporter = Exporter()
  17. exporter.add_callback('on_export_start', test_func)
  18. assert test_func in exporter.callbacks['on_export_start'], 'callback test failed'
  19. f = exporter(model=YOLO(CFG_DET).model)
  20. YOLO(f)(ASSETS) # exported model inference
  21. def test_detect():
  22. overrides = {'data': 'coco8.yaml', 'model': CFG_DET, 'imgsz': 32, 'epochs': 1, 'save': False}
  23. CFG.data = 'coco8.yaml'
  24. CFG.imgsz = 32
  25. # Trainer
  26. trainer = detect.DetectionTrainer(overrides=overrides)
  27. trainer.add_callback('on_train_start', test_func)
  28. assert test_func in trainer.callbacks['on_train_start'], 'callback test failed'
  29. trainer.train()
  30. # Validator
  31. val = detect.DetectionValidator(args=CFG)
  32. val.add_callback('on_val_start', test_func)
  33. assert test_func in val.callbacks['on_val_start'], 'callback test failed'
  34. val(model=trainer.best) # validate best.pt
  35. # Predictor
  36. pred = detect.DetectionPredictor(overrides={'imgsz': [64, 64]})
  37. pred.add_callback('on_predict_start', test_func)
  38. assert test_func in pred.callbacks['on_predict_start'], 'callback test failed'
  39. result = pred(source=ASSETS, model=f'{MODEL}.pt')
  40. assert len(result), 'predictor test failed'
  41. overrides['resume'] = trainer.last
  42. trainer = detect.DetectionTrainer(overrides=overrides)
  43. try:
  44. trainer.train()
  45. except Exception as e:
  46. print(f'Expected exception caught: {e}')
  47. return
  48. Exception('Resume test failed!')
  49. def test_segment():
  50. overrides = {'data': 'coco8-seg.yaml', 'model': CFG_SEG, 'imgsz': 32, 'epochs': 1, 'save': False}
  51. CFG.data = 'coco8-seg.yaml'
  52. CFG.imgsz = 32
  53. # YOLO(CFG_SEG).train(**overrides) # works
  54. # trainer
  55. trainer = segment.SegmentationTrainer(overrides=overrides)
  56. trainer.add_callback('on_train_start', test_func)
  57. assert test_func in trainer.callbacks['on_train_start'], 'callback test failed'
  58. trainer.train()
  59. # Validator
  60. val = segment.SegmentationValidator(args=CFG)
  61. val.add_callback('on_val_start', test_func)
  62. assert test_func in val.callbacks['on_val_start'], 'callback test failed'
  63. val(model=trainer.best) # validate best.pt
  64. # Predictor
  65. pred = segment.SegmentationPredictor(overrides={'imgsz': [64, 64]})
  66. pred.add_callback('on_predict_start', test_func)
  67. assert test_func in pred.callbacks['on_predict_start'], 'callback test failed'
  68. result = pred(source=ASSETS, model=f'{MODEL}-seg.pt')
  69. assert len(result), 'predictor test failed'
  70. # Test resume
  71. overrides['resume'] = trainer.last
  72. trainer = segment.SegmentationTrainer(overrides=overrides)
  73. try:
  74. trainer.train()
  75. except Exception as e:
  76. print(f'Expected exception caught: {e}')
  77. return
  78. Exception('Resume test failed!')
  79. def test_classify():
  80. overrides = {'data': 'imagenet10', 'model': CFG_CLS, 'imgsz': 32, 'epochs': 1, 'save': False}
  81. CFG.data = 'imagenet10'
  82. CFG.imgsz = 32
  83. # YOLO(CFG_SEG).train(**overrides) # works
  84. # Trainer
  85. trainer = classify.ClassificationTrainer(overrides=overrides)
  86. trainer.add_callback('on_train_start', test_func)
  87. assert test_func in trainer.callbacks['on_train_start'], 'callback test failed'
  88. trainer.train()
  89. # Validator
  90. val = classify.ClassificationValidator(args=CFG)
  91. val.add_callback('on_val_start', test_func)
  92. assert test_func in val.callbacks['on_val_start'], 'callback test failed'
  93. val(model=trainer.best)
  94. # Predictor
  95. pred = classify.ClassificationPredictor(overrides={'imgsz': [64, 64]})
  96. pred.add_callback('on_predict_start', test_func)
  97. assert test_func in pred.callbacks['on_predict_start'], 'callback test failed'
  98. result = pred(source=ASSETS, model=trainer.best)
  99. assert len(result), 'predictor test failed'