// yolov5onnxC++.cpp
  #include <algorithm>
  #include <array>
  #include <chrono>
  #include <cstddef>
  #include <filesystem>
  #include <iostream>
  #include <iterator>
  #include <string>
  #include <thread>
  #include <vector>

  #include <onnxruntime_cxx_api.h>
  #include <opencv2/opencv.hpp>

  using cv::Mat;
  using std::cout;
  using std::endl;
  using std::string;
  using std::vector;
  namespace fs = std::filesystem;
  // Display names for the 80 class scores of a detection row, indexed by
  // class id (id 0 is "person", the only class this program draws).
  static const vector<string> class_name{
      // 0 - 7
      "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
      // 8 - 15
      "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
      // 16 - 23
      "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
      // 24 - 31
      "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
      // 32 - 39
      "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
      // 40 - 47
      "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
      // 48 - 55
      "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
      // 56 - 63
      "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
      // 64 - 71
      "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
      // 72 - 79
      "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
  };
  // Drops low-confidence candidates from a flat detection buffer.
  //
  // pdata    : buffer of `total` floats, read as consecutive rows of `len_data`
  //            floats each; element 4 of every row is its confidence value.
  // total    : number of floats available in `pdata`. A trailing partial row
  //            (total % len_data floats) is ignored.
  // conf     : inclusive lower bound on row[4] for a row to be kept.
  // len_data : floats per row (85 by default).
  //
  // Returns a copy of every kept row, in buffer order. Returns an empty result
  // when `pdata` is null, `total` is not positive, or `len_data` is too small
  // for a row to contain element 4 (this also rules out division by zero).
  vector<vector<float>> get_info(const float* pdata, int total, float conf = 0.5, int len_data = 85)
  {
      constexpr int kConfidenceIndex = 4;

      vector<vector<float>> info;
      if (pdata == nullptr || total <= 0 || len_data <= kConfidenceIndex)
      {
          return info;
      }

      const int row_count = total / len_data;
      for (int row = 0; row < row_count; ++row, pdata += len_data)
      {
          if (pdata[kConfidenceIndex] >= conf)
          {
              // Construct the row copy directly inside the result vector.
              info.emplace_back(pdata, pdata + len_data);
          }
      }
      return info;
  }
  // Keeps only "person" detections and rewrites each surviving row in place.
  //
  // On entry every row is laid out as
  //   [center_x, center_y, width, height, confidence, class scores...].
  // A row survives when its first class score (class id 0) is the maximum of
  // its class scores; on ties the lowest class id wins, so a tie with class 0
  // still counts as a person. Rows too short to hold a class score are dropped.
  //
  // On exit every surviving row has exactly 6 elements:
  //   [left, top, right, bottom, confidence, class_id (always 0)].
  // The relative order of surviving rows is unchanged.
  void info_simplify(vector<vector<float>>& info)
  {
      static constexpr std::ptrdiff_t kFirstClassIndex = 5;

      const auto not_person = [](const vector<float>& row) {
          if (static_cast<std::ptrdiff_t>(row.size()) <= kFirstClassIndex)
          {
              return true;
          }
          const auto first_class = row.cbegin() + kFirstClassIndex;
          // max_element returns the first maximum, so class id 0 wins exactly
          // when the result is the first class score itself.
          return std::max_element(first_class, row.cend()) != first_class;
      };

      // Single linear pass instead of erasing from the middle on every miss.
      info.erase(std::remove_if(info.begin(), info.end(), not_person), info.end());

      for (auto& row : info)
      {
          const float center_x = row[0];
          const float center_y = row[1];
          const float half_w = row[2] / 2.0f;
          const float half_h = row[3] / 2.0f;

          row[0] = center_x - half_w;  // left
          row[1] = center_y - half_h;  // top
          row[2] = center_x + half_w;  // right
          row[3] = center_y + half_h;  // bottom
          row[5] = 0.0f;               // class id of "person"
          row.resize(6);               // discard the remaining class scores
      }
  }
  // Greedy non-maximum suppression, performed in place.
  //
  // info : rows laid out as [left, top, right, bottom, confidence, ...]. On
  //        return the rows are ordered by descending confidence (ties keep
  //        their input order) and every row whose overlap ratio
  //        (intersection / union) with a higher-ranked surviving row is
  //        >= `iou` has been removed.
  // iou  : overlap ratio at or above which the lower-ranked row is suppressed.
  //
  // Boxes that merely touch (zero-area intersection) are kept. A pair whose
  // union area is not positive (degenerate boxes) is treated as a duplicate
  // and the lower-ranked row is suppressed, which also avoids dividing by zero.
  void nms(vector<vector<float>>& info, float iou = 0.4)
  {
      // Sort once: removing rows never disturbs the order of the rest, and the
      // comparator takes references so no row is copied while sorting.
      std::stable_sort(info.begin(), info.end(),
                       [](const vector<float>& lhs, const vector<float>& rhs) { return lhs[4] > rhs[4]; });

      for (std::size_t kept = 0; kept < info.size(); ++kept)
      {
          // Copy the reference box; only the rows after it are modified below.
          const float x1 = info[kept][0];
          const float y1 = info[kept][1];
          const float x2 = info[kept][2];
          const float y2 = info[kept][3];
          const float ref_area = (x2 - x1) * (y2 - y1);

          const auto suppressed = [=](const vector<float>& cand) {
              // Disjoint boxes can never exceed the threshold.
              if (cand[0] > x2 || cand[2] < x1 || cand[1] > y2 || cand[3] < y1)
              {
                  return false;
              }
              const float over_w = std::min(x2, cand[2]) - std::max(x1, cand[0]);
              const float over_h = std::min(y2, cand[3]) - std::max(y1, cand[1]);
              const float s_over = over_w * over_h;
              const float s_total = ref_area + (cand[2] - cand[0]) * (cand[3] - cand[1]) - s_over;
              if (s_total <= 0.0f)
              {
                  return true;
              }
              return !(s_over / s_total < iou);
          };

          const auto tail = info.begin() + static_cast<std::ptrdiff_t>(kept + 1);
          info.erase(std::remove_if(tail, info.end(), suppressed), info.end());
      }
  }
  121. void draw_box(Mat& img, const vector<vector<float>>& info)
  122. {
  123. for (int i = 0; i < info.size(); i++)
  124. {
  125. if (static_cast<int>(info[i][5]) == 0)
  126. {
  127. cv::Point topLeft(info[i][0], info[i][1]);
  128. cv::Point bottomRight(info[i][2], info[i][3]);
  129. int thickness = 2;
  130. cv::Scalar color(0, 255, 0);
  131. int lineType = cv::LINE_8;
  132. const int cornerRadius = 5;
  133. cv::rectangle(img, topLeft, bottomRight, color, thickness, lineType);
  134. string label = class_name[0] + " " + std::to_string(info[i][4]); // 仅显示 "person" 标签
  135. cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.6, 1, nullptr);
  136. cv::Rect textBgRect(topLeft.x, topLeft.y - textSize.height - 5, textSize.width + 10, textSize.height + 5);
  137. cv::rectangle(img, textBgRect, color, cv::FILLED);
  138. cv::putText(img, label, cv::Point(topLeft.x + 5, topLeft.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.6, CV_RGB(255, 255, 255), 2);
  139. }
  140. }
  141. }
  142. int main()
  143. {
  144. // 定义ONNX模型路径
  145. const wchar_t* model_path = L"D://Thework//testrelease//yolov5s.onnx";
  146. // 初始化 ONNX 运行环境和内存信息
  147. Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXRuntime");
  148. auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  149. // 配置会话选项
  150. Ort::SessionOptions session_options;
  151. session_options.SetIntraOpNumThreads(5);
  152. session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  153. // 定义输入输出的名称
  154. const char* input_names[] = { "images" }; // 根据你的模型输入名调整
  155. const char* output_names[] = { "output" }; // 根据你的模型输出名调整
  156. try {
  157. // 创建 ONNX 会话
  158. Ort::Session session(env, model_path, session_options);
  159. // 打开摄像头
  160. //cv::VideoCapture cap(0);
  161. //if (!cap.isOpened())
  162. //{
  163. // cout << "Error: Cannot open the camera" << endl;
  164. // return -1;
  165. //}
  166. cv::VideoCapture cap(7, cv::CAP_V4L2);
  167. if (!cap.isOpened())
  168. {
  169. cout << "Error: Cannot open the camera" << endl;
  170. return -1;
  171. }
  172. cap.set(cv::CAP_PROP_FRAME_WIDTH, 1280);
  173. cap.set(cv::CAP_PROP_FRAME_HEIGHT, 720);
  174. // 创建保存图像的文件夹
  175. std::filesystem::path save_dir = "saved_images";
  176. std::filesystem::create_directory(save_dir);
  177. int image_count = 0; // 用于记录已保存的图像数量
  178. auto start_time = std::chrono::steady_clock::now();
  179. while (image_count < 5)
  180. {
  181. cv::Mat img;
  182. cap >> img;
  183. if (img.empty()) break;
  184. double start = cv::getTickCount();
  185. // 获取原始图像尺寸
  186. cv::Size originalSize = img.size();
  187. // 将图像调整为模型要求的尺寸
  188. cv::Mat resized_img;
  189. cv::resize(img, resized_img, cv::Size(640, 640));
  190. cv::Mat blob = cv::dnn::blobFromImage(resized_img, 1.0 / 255.0, cv::Size(640, 640), cv::Scalar(), true);
  191. // 定义输入Tensor
  192. std::array<int64_t, 4> input_shape = { 1, 3, 640, 640 };
  193. Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, blob.ptr<float>(), blob.total(), input_shape.data(), input_shape.size());
  194. // 定义输出Tensor的形状
  195. std::array<int64_t, 3> output_shape{ 1, 25200, 85 }; // 根据实际模型输出调整
  196. std::vector<float> output_tensor_values(1 * 25200 * 85); // 根据实际模型输出调整
  197. Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_tensor_values.data(), output_tensor_values.size(), output_shape.data(), output_shape.size());
  198. // 推理
  199. session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
  200. // 处理输出
  201. float* output_data = output_tensor.GetTensorMutableData<float>();
  202. int output_size = 25200 * 85; // 根据实际模型输出调整
  203. vector<vector<float>> info = get_info(output_data, output_size, 0.5);
  204. info_simplify(info);
  205. nms(info);
  206. // 将检测框重新映射回原始图像尺寸
  207. for (auto& box : info)
  208. {
  209. box[0] *= (float)originalSize.width / 640.0;
  210. box[1] *= (float)originalSize.height / 640.0;
  211. box[2] *= (float)originalSize.width / 640.0;
  212. box[3] *= (float)originalSize.height / 640.0;
  213. }
  214. // 在原始图像上绘制检测框
  215. draw_box(img, info);
  216. double end = cv::getTickCount();
  217. double timeSec = (end - start) / cv::getTickFrequency();
  218. cout << "Frame time: " << timeSec << " seconds" << endl;
  219. // 每隔3秒保存一张图像
  220. auto now = std::chrono::steady_clock::now();
  221. if (std::chrono::duration_cast<std::chrono::seconds>(now - start_time).count() >= 3)
  222. {
  223. string filename = (save_dir / ("image_" + std::to_string(image_count) + ".jpg")).string();
  224. cv::imwrite(filename, img);
  225. cout << "Saved " << filename << endl;
  226. start_time = now; // 重置计时器
  227. image_count++; // 增加计数器
  228. }
  229. // 等待1毫秒
  230. std::this_thread::sleep_for(std::chrono::milliseconds(1));
  231. }
  232. // 释放摄像头
  233. cap.release();
  234. cv::destroyAllWindows();
  235. }
  236. catch (const Ort::Exception& e) {
  237. std::cerr << "ONNX Runtime 异常: " << e.what() << std::endl;
  238. }
  239. catch (const std::exception& e) {
  240. std::cerr << "标准异常: " << e.what() << std::endl;
  241. }
  242. catch (...) {
  243. std::cerr << "未知异常." << std::endl;
  244. }
  245. return 0;
  246. }