# testonnxvideo-1.py
# Detects people in camera frames with a YOLOv5 ONNX model and saves annotated snapshots.

  1. import os
  2. import sys
  3. import onnx
  4. import onnxruntime as ort
  5. import cv2
  6. import numpy as np
  7. import time
  8. CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
  9. 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  10. 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
  11. 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
  12. 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
  13. 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
  14. 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  15. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
  16. 'hair drier', 'toothbrush']
  17. class Yolov5ONNX(object):
  18. def __init__(self, onnx_path):
  19. onnx_model = onnx.load(onnx_path)
  20. try:
  21. onnx.checker.check_model(onnx_model)
  22. except Exception:
  23. print("Model incorrect")
  24. else:
  25. print("Model correct")
  26. self.onnx_session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
  27. providers = self.onnx_session.get_providers()
  28. if 'CUDAExecutionProvider' in providers:
  29. print("Using CUDA for inference.")
  30. else:
  31. print("CUDA is not available, using CPU for inference.")
  32. self.input_name = self.get_input_name()
  33. self.output_name = self.get_output_name()
  34. self.input_size = (640, 640)
  35. def get_input_name(self):
  36. input_name = []
  37. for node in self.onnx_session.get_inputs():
  38. input_name.append(node.name)
  39. return input_name
  40. def get_output_name(self):
  41. output_name = []
  42. for node in self.onnx_session.get_outputs():
  43. output_name.append(node.name)
  44. return output_name
  45. def get_input_feed(self, image_numpy):
  46. input_feed = {}
  47. for name in self.input_name:
  48. input_feed[name] = image_numpy
  49. return input_feed
  50. def inference(self, img):
  51. h, w, _ = img.shape
  52. new_w, new_h = self.input_size
  53. scale = min(new_w / w, new_h / h)
  54. new_w = int(w * scale)
  55. new_h = int(h * scale)
  56. img_resized = cv2.resize(img, (new_w, new_h))
  57. padded_img = np.zeros((self.input_size[1], self.input_size[0], 3), dtype=np.uint8)
  58. pad_x = (self.input_size[0] - new_w) // 2
  59. pad_y = (self.input_size[1] - new_h) // 2
  60. padded_img[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = img_resized
  61. img_rgb = cv2.cvtColor(padded_img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
  62. img_rgb = img_rgb.astype(np.float32)
  63. img_rgb /= 255.0
  64. img_rgb = np.expand_dims(img_rgb, axis=0)
  65. input_feed = self.get_input_feed(img_rgb)
  66. start_time = time.time()
  67. pred = self.onnx_session.run(None, input_feed)[0]
  68. end_time = time.time()
  69. inference_time = end_time - start_time
  70. print(f"Inference time: {inference_time:.4f} seconds")
  71. return pred, padded_img
  72. def nms(dets, thresh):
  73. x1 = dets[:, 0]
  74. y1 = dets[:, 1]
  75. x2 = dets[:, 2]
  76. y2 = dets[:, 3]
  77. areas = (y2 - y1 + 1) * (x2 - x1 + 1)
  78. scores = dets[:, 4]
  79. keep = []
  80. index = scores.argsort()[::-1]
  81. while index.size > 0:
  82. i = index[0]
  83. keep.append(i)
  84. x11 = np.maximum(x1[i], x1[index[1:]])
  85. y11 = np.maximum(y1[i], y1[index[1:]])
  86. x22 = np.minimum(x2[i], x2[index[1:]])
  87. y22 = np.minimum(y2[i], y2[index[1:]])
  88. w = np.maximum(0, x22 - x11 + 1)
  89. h = np.maximum(0, y22 - y11 + 1)
  90. overlaps = w * h
  91. ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
  92. idx = np.where(ious <= thresh)[0]
  93. index = index[idx + 1]
  94. return keep
  95. def xywh2xyxy(x):
  96. y = np.copy(x)
  97. y[:, 0] = x[:, 0] - x[:, 2] / 2
  98. y[:, 1] = x[:, 1] - x[:, 3] / 2
  99. y[:, 2] = x[:, 0] + x[:, 2] / 2
  100. y[:, 3] = x[:, 1] + x[:, 3] / 2
  101. return y
  102. def filter_box(org_box, conf_thres, iou_thres):
  103. org_box = np.squeeze(org_box)
  104. conf = org_box[..., 4] > conf_thres
  105. box = org_box[conf == True]
  106. if box.size == 0:
  107. return np.array([])
  108. cls_cinf = box[..., 5:]
  109. cls = [int(np.argmax(cls_cinf[i])) for i in range(len(cls_cinf))]
  110. person_boxes = [box[i] for i in range(len(cls)) if cls[i] == 0]
  111. if len(person_boxes) == 0:
  112. return np.array([])
  113. person_boxes = np.array(person_boxes)
  114. person_boxes = xywh2xyxy(person_boxes)
  115. person_out_box = nms(person_boxes, iou_thres)
  116. output = [person_boxes[k] for k in person_out_box]
  117. return np.array(output)
  118. def draw(image, box_data):
  119. if box_data.size == 0:
  120. return image
  121. boxes = box_data[..., :4].astype(np.int32)
  122. scores = box_data[..., 4]
  123. classes = box_data[..., 5].astype(np.int32)
  124. for box, score, cl in zip(boxes, scores, classes):
  125. top, left, right, bottom = box
  126. cv2.rectangle(image, (top, left), (right, bottom), (255, 0, 0), 2)
  127. cv2.putText(image, '{0} {1:.2f}'.format(CLASSES[cl], score),
  128. (top, left),
  129. cv2.FONT_HERSHEY_SIMPLEX,
  130. 0.6, (0, 0, 255), 2)
  131. return image
  132. def main():
  133. onnx_path = 'yolov5s.onnx'
  134. model = Yolov5ONNX(onnx_path)
  135. cap = cv2.VideoCapture(7, cv2.CAP_V4L2)
  136. cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
  137. cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
  138. if not cap.isOpened():
  139. print("无法打开摄像头")
  140. sys.exit(0)
  141. last_save_time = time.time()
  142. frame_count = 0
  143. output_dir = "saved_images"
  144. if not os.path.exists(output_dir):
  145. os.makedirs(output_dir)
  146. # 控制帧率
  147. frame_rate = 1 # 每秒最多处理10帧
  148. prev_time = time.time()
  149. while True:
  150. ret, frame = cap.read()
  151. if not ret:
  152. print("无法读取摄像头图像")
  153. break
  154. current_time = time.time()
  155. if current_time - prev_time >= 1.0 / frame_rate:
  156. prev_time = current_time
  157. output, org_img = model.inference(frame)
  158. outbox = filter_box(output, 0.5, 0.5)
  159. org_img = draw(org_img, outbox)
  160. if time.time() - last_save_time >= 2:
  161. frame_count += 1
  162. image_path = os.path.join(output_dir, f'result_{frame_count}.jpg')
  163. cv2.imwrite(image_path, org_img)
  164. print(f"Image saved: {image_path}")
  165. last_save_time = time.time()
  166. if frame_count >= 5:
  167. print("保存了5张图片,程序退出")
  168. break
  169. cap.release()
  170. cv2.destroyAllWindows()
# Run the capture/detect/save loop only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()