# AIDetector_pytorch.py (3.3 KB)
  1. import cv2
  2. import torch
  3. import numpy as np
  4. import onnxruntime as ort
  5. from utils.general import non_max_suppression, scale_coords
  6. from utils.BaseDetector import baseDet
  7. from utils.datasets import letterbox
  8. import logging
  9. logging.basicConfig(filename='detection_log.txt', level=logging.INFO,
  10. format='%(asctime)s - %(message)s')
  11. class Detector(baseDet):
  12. def __init__(self):
  13. super(Detector, self).__init__()
  14. self.device = None
  15. self.weights = None
  16. self.session = None
  17. self.names = None
  18. self.img_size = 640
  19. self.init_model()
  20. self.build_config()
  21. def init_model(self):
  22. self.weights = 'weights/yolov5s.onnx'
  23. self.device = '0' if torch.cuda.is_available() else 'cpu'
  24. #self.session = ort.InferenceSession(self.weights, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
  25. self.session = ort.InferenceSession(self.weights, providers=['CUDAExecutionProvider'])
  26. self.names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
  27. 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  28. 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
  29. 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
  30. 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
  31. 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
  32. 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  33. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
  34. 'hair drier', 'toothbrush']
  35. def preprocess(self, img):
  36. img0 = img.copy()
  37. img = cv2.resize(img, (640, 640))
  38. img = letterbox(img, new_shape=self.img_size)[0]
  39. img = img[:, :, ::-1].transpose(2, 0, 1)
  40. img = np.ascontiguousarray(img)
  41. img = img.astype(np.float32)
  42. img /= 255.0
  43. if img.ndim == 3:
  44. img = np.expand_dims(img, axis=0)
  45. return img0, img
  46. def detect(self, im):
  47. im0, img = self.preprocess(im)
  48. # Prepare input for ONNX model
  49. input_name = self.session.get_inputs()[0].name
  50. pred = self.session.run(None, {input_name: img})[0] # Run inference
  51. pred = pred.astype(np.float32)
  52. pred = non_max_suppression(torch.from_numpy(pred), self.threshold, 0.4)
  53. pred_boxes = []
  54. for det in pred:
  55. if det is not None and len(det):
  56. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  57. for *x, conf, cls_id in det:
  58. lbl = self.names[int(cls_id)]
  59. if lbl not in ['person', 'car', 'truck']: # Filter unwanted labels
  60. continue
  61. x1, y1 = int(x[0]), int(x[1])
  62. x2, y2 = int(x[2]), int(x[3])
  63. pred_boxes.append((x1, y1, x2, y2, lbl, conf))
  64. return im0, pred_boxes