YOLOv5Detector.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #include "YOLOv5Detector.h"
  // Label table indexed by class id (draw_boxes looks up entry 0 for class id 0).
  // NOTE(review): the 80 entries look like the usual COCO label list used by
  // stock YOLOv5 exports - confirm against the model that is actually loaded.
  static const vector<string> class_name = {
      // 0 - 7
      "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
      // 8 - 15
      "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
      // 16 - 23
      "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
      // 24 - 31
      "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
      // 32 - 39
      "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
      // 40 - 47
      "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
      // 48 - 55
      "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
      // 56 - 63
      "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop",
      // 64 - 71
      "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
      // 72 - 79
      "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
  };
  11. YOLOv5Detector::YOLOv5Detector(const wchar_t* model_path)
  12. : env(ORT_LOGGING_LEVEL_WARNING, "ONNXRuntime"),
  13. memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)),
  14. session_options(),
  15. session(env, model_path, session_options)
  16. {
  17. // 配置会话选项
  18. session_options.SetIntraOpNumThreads(5);
  19. session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
  20. // 尝试添加 CUDA 提供程序
  21. try {
  22. OrtCUDAProviderOptions cuda_options;
  23. cuda_options.device_id = 0; // 设置为 GPU 0
  24. session_options.AppendExecutionProvider_CUDA(cuda_options);
  25. use_cuda = true;
  26. }
  27. catch (const Ort::Exception& e) {
  28. cout << "CUDA 不可用, 使用 CPU." << endl;
  29. use_cuda = false;
  30. }
  31. // 输出是否使用 CUDA
  32. cout << (use_cuda ? "正在使用 CUDA 进行推理" : "正在使用 CPU 进行推理") << endl;
  33. }
  34. vector<vector<float>> YOLOv5Detector::detect(const cv::Mat& img)
  35. {
  36. cv::Mat resized_img;
  37. cv::resize(img, resized_img, cv::Size(640, 640));
  38. cv::Mat blob = cv::dnn::blobFromImage(resized_img, 1.0 / 255.0, cv::Size(640, 640), cv::Scalar(), true);
  39. // 定义输入Tensor
  40. std::array<int64_t, 4> input_shape = { 1, 3, 640, 640 };
  41. Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, blob.ptr<float>(), blob.total(), input_shape.data(), input_shape.size());
  42. // 定义输出Tensor的形状
  43. std::array<int64_t, 3> output_shape{ 1, 25200, 85 };
  44. std::vector<float> output_tensor_values(1 * 25200 * 85);
  45. Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_tensor_values.data(), output_tensor_values.size(), output_shape.data(), output_shape.size());
  46. // 推理
  47. session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
  48. // 处理输出
  49. float* output_data = output_tensor.GetTensorMutableData<float>();
  50. int output_size = 25200 * 85;
  51. vector<vector<float>> info = get_info(output_data, output_size);
  52. info_simplify(info);
  53. nms(info);
  54. // 将检测框重新映射回原始图像尺寸
  55. for (auto& box : info)
  56. {
  57. box[0] *= (float)img.cols / 640.0;
  58. box[1] *= (float)img.rows / 640.0;
  59. box[2] *= (float)img.cols / 640.0;
  60. box[3] *= (float)img.rows / 640.0;
  61. }
  62. return info;
  63. }
  64. void YOLOv5Detector::draw_boxes(Mat& img, const vector<vector<float>>& info)
  65. {
  66. for (const auto& box : info)
  67. {
  68. if (static_cast<int>(box[5]) == 0)
  69. {
  70. cv::Point topLeft(box[0], box[1]);
  71. cv::Point bottomRight(box[2], box[3]);
  72. cv::rectangle(img, topLeft, bottomRight, cv::Scalar(0, 255, 0), 2);
  73. string label = class_name[0] + " " + std::to_string(box[4]);
  74. cv::putText(img, label, topLeft, cv::FONT_HERSHEY_SIMPLEX, 0.6, CV_RGB(255, 255, 255), 2);
  75. }
  76. }
  77. }
  78. cv::Mat YOLOv5Detector::QImageToMat(const QImage& image)
  79. {
  80. // 处理 RGB 格式的 QImage
  81. switch (image.format())
  82. {
  83. case QImage::Format_RGB32:
  84. {
  85. cv::Mat mat(image.height(), image.width(), CV_8UC4, (void*)image.bits(), image.bytesPerLine());
  86. cv::Mat bgr;
  87. cv::cvtColor(mat, bgr, cv::COLOR_RGBA2BGR); // 转换为 BGR
  88. return bgr.clone(); // 复制以避免修改原始数据
  89. //return mat.clone(); // 复制以避免修改原始数据
  90. }
  91. case QImage::Format_RGB888:
  92. {
  93. QImage swapped = image.rgbSwapped(); // OpenCV 使用 BGR 排列
  94. return cv::Mat(swapped.height(), swapped.width(), CV_8UC3, (void*)swapped.bits(), swapped.bytesPerLine()).clone();
  95. }
  96. case QImage::Format_Grayscale8:
  97. {
  98. return cv::Mat(image.height(), image.width(), CV_8UC1, (void*)image.bits(), image.bytesPerLine()).clone();
  99. }
  100. default:
  101. break;
  102. }
  103. // 如果无法识别格式,返回空的 cv::Mat
  104. return cv::Mat();
  105. }
  106. vector<vector<float>> YOLOv5Detector::get_info(const float* pdata, int total)
  107. {
  108. float conf = 0.5;
  109. int len_data = 85;
  110. vector<vector<float>> info;
  111. for (int i = 0; i < total / len_data; i++)
  112. {
  113. if (pdata[4] >= conf)
  114. {
  115. vector<float> line_data(pdata, pdata + len_data);
  116. info.push_back(line_data);
  117. }
  118. pdata += len_data;
  119. }
  120. return info;
  121. }
  122. void YOLOv5Detector::info_simplify(vector<vector<float>>& info)
  123. {
  124. for (int i = 0; i < info.size(); i++)
  125. {
  126. auto max_pos = std::max_element(info[i].cbegin() + 5, info[i].cend());
  127. int class_id = std::distance(info[i].cbegin() + 5, max_pos);
  128. info[i][5] = class_id;
  129. info[i].resize(6);
  130. float x = info[i][0];
  131. float y = info[i][1];
  132. float w = info[i][2];
  133. float h = info[i][3];
  134. info[i][0] = x - w / 2.0;
  135. info[i][1] = y - h / 2.0;
  136. info[i][2] = x + w / 2.0;
  137. info[i][3] = y + h / 2.0;
  138. }
  139. }
  140. void YOLOv5Detector::nms(vector<vector<float>>& info)
  141. {
  142. float iou = 0.4;
  143. int counter = 0;
  144. vector<vector<float>> return_info;
  145. while (counter < info.size())
  146. {
  147. return_info.clear();
  148. float x1 = 0;
  149. float x2 = 0;
  150. float y1 = 0;
  151. float y2 = 0;
  152. // 按照置信度排序
  153. std::sort(info.begin(), info.end(), [](vector<float> p1, vector<float> p2)
  154. {
  155. return p1[4] > p2[4];
  156. });
  157. for (auto i = 0; i < info.size(); i++)
  158. {
  159. if (i < counter)
  160. {
  161. return_info.push_back(info[i]);
  162. continue;
  163. }
  164. if (i == counter)
  165. {
  166. x1 = info[i][0];
  167. y1 = info[i][1];
  168. x2 = info[i][2];
  169. y2 = info[i][3];
  170. return_info.push_back(info[i]);
  171. continue;
  172. }
  173. if (info[i][0] > x2 || info[i][2] < x1 || info[i][1] > y2 || info[i][3] < y1)
  174. {
  175. return_info.push_back(info[i]);
  176. }
  177. else
  178. {
  179. float over_x1 = std::max(x1, info[i][0]);
  180. float over_y1 = std::max(y1, info[i][1]);
  181. float over_x2 = std::min(x2, info[i][2]);
  182. float over_y2 = std::min(y2, info[i][3]);
  183. float s_over = (over_x2 - over_x1) * (over_y2 - over_y1);
  184. float s_total = (x2 - x1) * (y2 - y1) + (info[i][2] - info[i][0]) * (info[i][3] - info[i][1]);
  185. if (s_over / s_total < iou)
  186. {
  187. return_info.push_back(info[i]);
  188. }
  189. }
  190. }
  191. info = return_info;
  192. counter++;
  193. }
  194. }