yolov5.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  #include <algorithm>
  #include <cstddef>
  #include <iostream>
  #include <sstream>
  #include <string>
  #include <vector>

  #include <opencv2/dnn/dnn.hpp>
  #include <opencv2/opencv.hpp>
  6. using cv::Mat;
  7. using std::cout;
  8. using std::endl;
  9. using std::string;
  10. using std::vector;
  // Human-readable labels; draw_box() looks a label up by the class index
  // stored at position 5 of a detection row.
  static const std::vector<std::string> class_name{
      "cat",
      "chicken",
      "cow",
      "dog",
      "fox",
      "goat",
      "horse",
      "person",
      "racoon",
      "skunk",
  };
  12. // void print_result(const Mat &result, float conf = 0.7, int len_data = 15)
  13. // {
  14. // float *pdata = (float *)result.data;
  15. // for (int i = 0; i < result.total() / len_data; i++)
  16. // {
  17. // if (pdata[4] > conf)
  18. // {
  19. // for (int j = 0; j < len_data; j++)
  20. // {
  21. // cout << pdata[j] << " ";
  22. // }
  23. // cout << endl;
  24. // }
  25. // pdata += len_data;
  26. // }
  27. // return;
  28. // }
  29. vector<vector<float>> get_info(const Mat &result, float conf = 0.7, int len_data = 15)
  30. {
  31. float *pdata = (float *)result.data;
  32. vector<vector<float>> info;
  33. for (int i = 0; i < result.total() / len_data; i++)
  34. {
  35. if (pdata[4] > conf)
  36. {
  37. vector<float> info_line;
  38. for (int j = 0; j < len_data; j++)
  39. {
  40. // cout << pdata[j] << " ";
  41. info_line.push_back(pdata[j]);
  42. }
  43. // cout << endl;
  44. info.push_back(info_line);
  45. }
  46. pdata += len_data;
  47. }
  48. return info;
  49. }
  /**
   * Reduce every detection row in place to six values:
   * [x_min, y_min, x_max, y_max, row[4], class index].
   *
   * On input each row is expected to start with a centre-based box
   * (x, y, w, h), followed by the value at index 4 (left untouched) and then
   * one score per class from index 5 onwards. The class index written to
   * position 5 is the position of the largest of those scores (the first one
   * on ties).
   *
   * Rows with fewer than six values carry no class score at all; they are
   * removed so that later stages can index position 5 safely.
   *
   * @param info Detection rows, modified in place.
   */
  void info_simplify(vector<vector<float>> &info)
  {
      info.erase(std::remove_if(info.begin(), info.end(),
                                [](const vector<float> &row)
                                { return row.size() < 6; }),
                 info.end());

      for (vector<float> &row : info)
      {
          // Pick the best class before the scores are cut off by resize().
          const auto scores_begin = row.cbegin() + 5;
          const auto best = std::max_element(scores_begin, row.cend());
          row[5] = static_cast<float>(best - scores_begin);
          row.resize(6);

          // Centre/size -> corner coordinates.
          const float x = row[0];
          const float y = row[1];
          const float w = row[2];
          const float h = row[3];
          row[0] = static_cast<float>(x - w / 2.0);
          row[1] = static_cast<float>(y - h / 2.0);
          row[2] = static_cast<float>(x + w / 2.0);
          row[3] = static_cast<float>(y + h / 2.0);
      }
  }
  /**
   * Partition detection rows by class.
   *
   * The class of a row is the value at position 5 truncated to int. Groups
   * appear in the order their class is first encountered, and rows keep their
   * relative order inside each group.
   *
   * @param info Detection rows; every row is read at position 5.
   * @return One vector of copied rows per distinct class.
   */
  vector<vector<vector<float>>> split_info(vector<vector<float>> &info)
  {
      vector<int> seen_ids;
      vector<vector<vector<float>>> groups;

      for (const vector<float> &row : info)
      {
          const int id = (int)row[5];
          const auto hit = std::find(seen_ids.begin(), seen_ids.end(), id);
          // Taken before any push_back, so it stays valid; for a new class it
          // equals the index the new group is about to receive.
          const auto slot = hit - seen_ids.begin();

          if (hit == seen_ids.end())
          {
              seen_ids.push_back(id);
              groups.emplace_back();
          }
          groups[slot].push_back(row);
      }
      return groups;
  }
  /**
   * Non-maximum suppression, in place.
   *
   * Rows are expected as [x_min, y_min, x_max, y_max, score, ...]. After the
   * call `info` is ordered by descending score and contains only boxes whose
   * intersection-over-union with every higher-scoring surviving box is below
   * `iou`.
   *
   * The rows are sorted once (stably, so equal scores keep their input order)
   * and suppressed boxes are erased from the tail in place, instead of
   * re-sorting and rebuilding the whole list on every pass.
   *
   * @param info Candidate boxes, normally of a single class; modified in place.
   * @param iou  Overlap threshold: a box is dropped when its IoU with a kept
   *             box is not below this value.
   */
  void nms(vector<vector<float>> &info, float iou = 0.4)
  {
      std::stable_sort(info.begin(), info.end(),
                       [](const vector<float> &lhs, const vector<float> &rhs)
                       { return lhs[4] > rhs[4]; });

      for (std::size_t keep = 0; keep < info.size(); ++keep)
      {
          // Copied out so the predicate does not hold a reference into the
          // vector that is being compacted.
          const float x1 = info[keep][0];
          const float y1 = info[keep][1];
          const float x2 = info[keep][2];
          const float y2 = info[keep][3];
          const float keep_area = (x2 - x1) * (y2 - y1);

          const auto suppressed = [=](const vector<float> &box)
          {
              // Disjoint boxes can never be suppressed.
              if (box[0] > x2 || box[2] < x1 || box[1] > y2 || box[3] < y1)
              {
                  return false;
              }
              const float over_w = std::min(x2, box[2]) - std::max(x1, box[0]);
              const float over_h = std::min(y2, box[3]) - std::max(y1, box[1]);
              const float s_over = over_w * over_h;
              const float box_area = (box[2] - box[0]) * (box[3] - box[1]);
              const float s_total = keep_area + box_area - s_over;
              if (s_total <= 0.0f)
              {
                  // Degenerate (zero-area) boxes that touch each other:
                  // treat them as duplicates rather than dividing by zero.
                  return true;
              }
              return !(s_over / s_total < iou);
          };

          const auto tail = info.begin() + static_cast<std::ptrdiff_t>(keep) + 1;
          info.erase(std::remove_if(tail, info.end(), suppressed), info.end());
      }
  }
  133. // void print_info(const vector<vector<float>> &info)
  134. // {
  135. // for (auto i = 0; i < info.size(); i++)
  136. // {
  137. // for (auto j = 0; j < info[i].size(); j++)
  138. // {
  139. // cout << info[i][j] << " ";
  140. // }
  141. // cout << endl;
  142. // }
  143. // }
  144. void draw_box(Mat &img, const vector<vector<float>> &info)
  145. {
  146. for (int i = 0; i < info.size(); i++)
  147. {
  148. cv::rectangle(img, cv::Point(info[i][0], info[i][1]), cv::Point(info[i][2], info[i][3]), cv::Scalar(0, 255, 0));
  149. string label;
  150. label += class_name[info[i][5]];
  151. label += " ";
  152. std::stringstream oss;
  153. oss << info[i][4];
  154. label += oss.str();
  155. cv::putText(img, label, cv::Point(info[i][0], info[i][1]), 1, 2, cv::Scalar(0, 255, 0), 2);
  156. }
  157. }
  158. int main()
  159. {
  160. cv::dnn::Net net = cv::dnn::readNetFromONNX("best.onnx");
  161. Mat img = cv::imread("fox.jpg");
  162. cv::resize(img, img, cv::Size(640, 640));
  163. Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255.0, cv::Size(640, 640), cv::Scalar(), true);
  164. net.setInput(blob);
  165. vector<Mat> netoutput;
  166. vector<string> out_name = {"output"};
  167. net.forward(netoutput, out_name);
  168. Mat result = netoutput[0];
  169. // print_result(result);
  170. vector<vector<float>> info = get_info(result);
  171. info_simplify(info);
  172. vector<vector<vector<float>>> info_split = split_info(info);
  173. // cout << " split info" << endl;
  174. // print_info(info_split[0]);
  175. // cout << info.size() << " " << info[0].size() << endl;
  176. for(auto i=0; i < info_split.size(); i++)
  177. {
  178. nms(info_split[i]);
  179. draw_box(img, info_split[i]);
  180. }
  181. // nms(info_split[0]);
  182. // cout << "nms" << endl;
  183. // print_info(info_split[0]);
  184. // draw_box(img, info_split[0]);
  185. cv::imshow("test", img);
  186. cv::waitKey(0);
  187. return 0;
  188. }