# testonnxvideo.py
import os
import sys
import time  # timing of inference and of the interval between saved images

import cv2
import numpy as np
import onnx
import onnxruntime as ort
  8. # CLASSES = ['person']
  9. CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
  10. 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  11. 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
  12. 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
  13. 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
  14. 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
  15. 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  16. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
  17. 'hair drier', 'toothbrush'] # coco80类别
  18. class Yolov5ONNX(object):
  19. def __init__(self, onnx_path):
  20. onnx_model = onnx.load(onnx_path)
  21. try:
  22. onnx.checker.check_model(onnx_model)
  23. except Exception:
  24. print("Model incorrect")
  25. else:
  26. print("Model correct")
  27. # 设置 ONNX Runtime 会尝试优先使用 CUDAExecutionProvider,如果不可用则使用 CPUExecutionProvider
  28. self.onnx_session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
  29. # 检查是否使用了 CUDA
  30. providers = self.onnx_session.get_providers()
  31. if 'CUDAExecutionProvider' in providers:
  32. print("Using CUDA for inference.")
  33. else:
  34. print("CUDA is not available, using CPU for inference.")
  35. self.input_name = self.get_input_name()
  36. self.output_name = self.get_output_name()
  37. self.input_size = (640, 640)
  38. def get_input_name(self):
  39. input_name = []
  40. for node in self.onnx_session.get_inputs():
  41. input_name.append(node.name)
  42. return input_name
  43. def get_output_name(self):
  44. output_name = []
  45. for node in self.onnx_session.get_outputs():
  46. output_name.append(node.name)
  47. return output_name
  48. def get_input_feed(self, image_numpy):
  49. input_feed = {}
  50. for name in self.input_name:
  51. input_feed[name] = image_numpy
  52. return input_feed
  53. def inference(self, img):
  54. h, w, _ = img.shape
  55. new_w, new_h = self.input_size
  56. scale = min(new_w / w, new_h / h)
  57. new_w = int(w * scale)
  58. new_h = int(h * scale)
  59. img_resized = cv2.resize(img, (new_w, new_h))
  60. padded_img = np.zeros((self.input_size[1], self.input_size[0], 3), dtype=np.uint8)
  61. pad_x = (self.input_size[0] - new_w) // 2
  62. pad_y = (self.input_size[1] - new_h) // 2
  63. padded_img[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = img_resized
  64. img_rgb = cv2.cvtColor(padded_img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
  65. img_rgb = img_rgb.astype(np.float32)
  66. img_rgb /= 255.0
  67. img_rgb = np.expand_dims(img_rgb, axis=0)
  68. input_feed = self.get_input_feed(img_rgb)
  69. start_time = time.time() # 记录开始时间
  70. pred = self.onnx_session.run(None, input_feed)[0]
  71. end_time = time.time() # 记录结束时间
  72. inference_time = end_time - start_time # 计算推理时间
  73. print(f"Inference time: {inference_time:.4f} seconds") # 输出推理时间
  74. return pred, padded_img
  75. def nms(dets, thresh):
  76. x1 = dets[:, 0]
  77. y1 = dets[:, 1]
  78. x2 = dets[:, 2]
  79. y2 = dets[:, 3]
  80. areas = (y2 - y1 + 1) * (x2 - x1 + 1)
  81. scores = dets[:, 4]
  82. keep = []
  83. index = scores.argsort()[::-1]
  84. while index.size > 0:
  85. i = index[0]
  86. keep.append(i)
  87. x11 = np.maximum(x1[i], x1[index[1:]])
  88. y11 = np.maximum(y1[i], y1[index[1:]])
  89. x22 = np.minimum(x2[i], x2[index[1:]])
  90. y22 = np.minimum(y2[i], y2[index[1:]])
  91. w = np.maximum(0, x22 - x11 + 1)
  92. h = np.maximum(0, y22 - y11 + 1)
  93. overlaps = w * h
  94. ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
  95. idx = np.where(ious <= thresh)[0]
  96. index = index[idx + 1]
  97. return keep
  98. def xywh2xyxy(x):
  99. y = np.copy(x)
  100. y[:, 0] = x[:, 0] - x[:, 2] / 2
  101. y[:, 1] = x[:, 1] - x[:, 3] / 2
  102. y[:, 2] = x[:, 0] + x[:, 2] / 2
  103. y[:, 3] = x[:, 1] + x[:, 3] / 2
  104. return y
  105. def filter_box(org_box, conf_thres, iou_thres):
  106. org_box = np.squeeze(org_box)
  107. conf = org_box[..., 4] > conf_thres
  108. box = org_box[conf == True]
  109. if box.size == 0:
  110. return np.array([]) # 确保返回一个空的numpy数组
  111. cls_cinf = box[..., 5:]
  112. cls = [int(np.argmax(cls_cinf[i])) for i in range(len(cls_cinf))]
  113. # 只保留类别为“person”的框
  114. person_boxes = [box[i] for i in range(len(cls)) if cls[i] == 0] # 类别“person”对应的索引是0
  115. if len(person_boxes) == 0:
  116. return np.array([]) # 如果没有“person”框,返回一个空的numpy数组
  117. person_boxes = np.array(person_boxes)
  118. person_boxes = xywh2xyxy(person_boxes)
  119. person_out_box = nms(person_boxes, iou_thres)
  120. output = [person_boxes[k] for k in person_out_box]
  121. return np.array(output)
  122. def draw(image, box_data):
  123. if box_data.size == 0:
  124. return image
  125. boxes = box_data[..., :4].astype(np.int32)
  126. scores = box_data[..., 4]
  127. classes = box_data[..., 5].astype(np.int32)
  128. for box, score, cl in zip(boxes, scores, classes):
  129. top, left, right, bottom = box
  130. cv2.rectangle(image, (top, left), (right, bottom), (255, 0, 0), 2)
  131. cv2.putText(image, '{0} {1:.2f}'.format(CLASSES[cl], score),
  132. (top, left),
  133. cv2.FONT_HERSHEY_SIMPLEX,
  134. 0.6, (0, 0, 255), 2)
  135. return image
# Earlier variant of main(), kept for reference: it shows each annotated frame
# in a window until 'q' is pressed.
# def main():
#     onnx_path = 'D:/Thework/yolov5-5.0/yolov5-5.0/runs/train/exp/weights/yolov5s.onnx'
#     model = Yolov5ONNX(onnx_path)
#
#     cap = cv2.VideoCapture(0)  # open the camera
#     if not cap.isOpened():
#         print("无法打开摄像头")
#         sys.exit(0)
#
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             print("无法读取摄像头图像")
#             break
#
#         start_time = time.time()  # record the start time
#         output, org_img = model.inference(frame)
#         outbox = filter_box(output, 0.5, 0.5)
#         org_img = draw(org_img, outbox)
#         end_time = time.time()  # record the end time
#
#         inference_time = end_time - start_time  # total processing time for this frame
#         print(f"Total processing time: {inference_time:.4f} seconds")  # print the total processing time
#
#         cv2.imshow('result', org_img)
#
#         if cv2.waitKey(1) & 0xFF == ord('q'):
#             break
#
#     cap.release()
#     cv2.destroyAllWindows()
#
#
# if __name__ == "__main__":
#     main()
  171. def main():
  172. onnx_path = 'yolov5s.onnx'
  173. model = Yolov5ONNX(onnx_path)
  174. # cap = cv2.VideoCapture(0) # 打开摄像头
  175. cap = cv2.VideoCapture(7, cv2.CAP_V4L2)
  176. cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
  177. cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
  178. if not cap.isOpened():
  179. print("无法打开摄像头")
  180. sys.exit(0)
  181. last_save_time = time.time() # 记录最后一次保存的时间
  182. frame_count = 0 # 用于命名保存的图像
  183. # 创建保存图片的文件夹
  184. output_dir = "saved_images"
  185. if not os.path.exists(output_dir):
  186. os.makedirs(output_dir)
  187. while True:
  188. ret, frame = cap.read()
  189. if not ret:
  190. print("无法读取摄像头图像")
  191. break
  192. output, org_img = model.inference(frame)
  193. outbox = filter_box(output, 0.5, 0.5)
  194. org_img = draw(org_img, outbox)
  195. # 如果距离上次保存超过2秒,则保存图像
  196. if time.time() - last_save_time >= 2:
  197. frame_count += 1
  198. image_path = os.path.join(output_dir, f'result_{frame_count}.jpg')
  199. cv2.imwrite(image_path, org_img) # 保存图像
  200. print(f"Image saved: {image_path}") # 输出保存路径
  201. last_save_time = time.time() # 更新最后一次保存的时间
  202. # 如果保存了5张图片,退出循环
  203. if frame_count >= 5:
  204. print("保存了5张图片,程序退出")
  205. break
  206. cap.release()
  207. cv2.destroyAllWindows()
  208. if __name__ == "__main__":
  209. main()