#include #include #include #include #include using cv::Mat; using std::cout; using std::endl; using std::string; using std::vector; static const vector class_name = { "cat", "chicken", "cow", "dog", "fox", "goat", "horse", "person", "racoon", "skunk" }; vector> get_info(const Mat& result, float conf = 0.7, int len_data = 15) { float* pdata = (float*)result.data; vector> info; for (int i = 0; i < result.total() / len_data; i++) { if (pdata[4] > conf) { vector info_line; for (int j = 0; j < len_data; j++) { // cout << pdata[j] << " "; info_line.push_back(pdata[j]); } // cout << endl; info.push_back(info_line); } pdata += len_data; } return info; } void info_simplify(vector>& info) { for (auto i = 0; i < info.size(); i++) { info[i][5] = std::max_element(info[i].cbegin() + 5, info[i].cend()) - (info[i].cbegin() + 5); info[i].resize(6); float x = info[i][0]; float y = info[i][1]; float w = info[i][2]; float h = info[i][3]; info[i][0] = x - w / 2.0; info[i][1] = y - h / 2.0; info[i][2] = x + w / 2.0; info[i][3] = y + h / 2.0; } } vector>> split_info(vector>& info) { vector>> info_split; vector class_id; for (auto i = 0; i < info.size(); i++) { if (std::find(class_id.begin(), class_id.end(), (int)info[i][5]) == class_id.end()) { class_id.push_back((int)info[i][5]); vector> info_; info_split.push_back(info_); } info_split[std::find(class_id.begin(), class_id.end(), (int)info[i][5]) - class_id.begin()].push_back(info[i]); } return info_split; } void nms(vector>& info, float iou = 0.4) { int counter = 0; vector> return_info; while (counter < info.size()) { return_info.clear(); float x1 = 0; float x2 = 0; float y1 = 0; float y2 = 0; std::sort(info.begin(), info.end(), [](vector p1, vector p2) { return p1[4] > p2[4]; }); for (auto i = 0; i < info.size(); i++) { if (i < counter) { return_info.push_back(info[i]); continue; } if (i == counter) { x1 = info[i][0]; y1 = info[i][1]; x2 = info[i][2]; y2 = info[i][3]; return_info.push_back(info[i]); continue; } if (info[i][0] > x2 or info[i][2] < x1 or info[i][1] > y2 or info[i][3] < y1) { return_info.push_back(info[i]); } else { float over_x1 = std::max(x1, info[i][0]); float over_y1 = std::max(y1, info[i][1]); float over_x2 = std::min(x2, info[i][2]); float over_y2 = std::min(y2, info[i][3]); float s_over = (over_x2 - over_x1) * (over_y2 - over_y1); float s_total = (x2 - x1) * (y2 - y1) + (info[i][0] - info[i][2]) * (info[i][1] - info[i][3]) - s_over; if (s_over / s_total < iou) { return_info.push_back(info[i]); } } } info = return_info; counter += 1; } } void draw_box(Mat& img, const vector>& info) { for (int i = 0; i < info.size(); i++) { cv::Point topLeft(info[i][0], info[i][1]); cv::Point bottomRight(info[i][2], info[i][3]); int thickness = 2; cv::Scalar color(0, 255, 0); int lineType = cv::LINE_8; const int cornerRadius = 5; cv::rectangle(img, topLeft, bottomRight, color, thickness, lineType); string label = class_name[static_cast(info[i][5])] + " " + std::to_string(info[i][4]); cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.6, 1, nullptr); cv::Rect textBgRect(topLeft.x, topLeft.y - textSize.height - 5, textSize.width + 10, textSize.height + 5); cv::rectangle(img, textBgRect, color, cv::FILLED); cv::putText(img, label, cv::Point(topLeft.x + 5, topLeft.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.6, CV_RGB(255, 255, 255), 2); } } int main() { // Load the network and set the backend to CUDA cv::dnn::Net net = cv::dnn::readNetFromONNX("best.onnx"); // Set the DNN backend to CUDA and target to CUDA net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA); net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA); Mat img = cv::imread("fox.jpg"); cv::resize(img, img, cv::Size(640, 640)); Mat blob = cv::dnn::blobFromImage(img, 1.0 / 255.0, cv::Size(640, 640), cv::Scalar(), true); net.setInput(blob); vector netoutput; vector out_name = { "output" }; net.forward(netoutput, out_name); Mat result = netoutput[0]; vector> info = get_info(result); info_simplify(info); vector>> info_split = split_info(info); for (auto i = 0; i < info_split.size(); i++) { nms(info_split[i]); draw_box(img, info_split[i]); } cv::imshow("test", img); cv::waitKey(0); return 0; }