val.py 907 B

123456789101112131415161718192021222324
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. from ultralytics.models.yolo.detect import DetectionValidator
  4. from ultralytics.utils import ops
  5. __all__ = ['NASValidator']
  class NASValidator(DetectionValidator):
      """Detection validator whose post-processing step is adapted to NAS-style raw predictions.

      Only ``postprocess`` is overridden; everything else is inherited from
      ``DetectionValidator``.
      """

      def postprocess(self, preds_in):
          """Apply Non-maximum suppression to prediction outputs.

          The first element of ``preds_in`` is read as a pair: item 0 is passed
          through ``ops.xyxy2xywh`` and item 1 is appended to it along the last
          dimension. The joined tensor then has its last two axes swapped before
          being handed to ``ops.non_max_suppression``.

          Args:
              preds_in: Raw model output. ``preds_in[0][0]`` is presumably the
                  box tensor in xyxy format and ``preds_in[0][1]`` the per-class
                  scores -- TODO confirm against the NAS model head.

          Returns:
              The result of ``ops.non_max_suppression`` for the given predictions.
          """
          head = preds_in[0]

          # Convert the box coordinates before joining them with the second tensor.
          xywh = ops.xyxy2xywh(head[0])
          merged = torch.cat((xywh, head[1]), dim=-1)

          # Swap the last two axes of the 3-D tensor so it matches what the NMS
          # routine is given in the original implementation.
          nms_input = merged.permute(0, 2, 1)

          return ops.non_max_suppression(
              nms_input,
              self.args.conf,
              self.args.iou,
              labels=self.lb,
              multi_label=False,
              agnostic=self.args.single_cls,
              max_det=self.args.max_det,
              max_time_img=0.5,
          )