123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
#include <algorithm>
#include <cstring>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
#include <opencv2/opencv.hpp>

using cv::Mat;
using std::cout;
using std::endl;
using std::string;
using std::vector;
// The 80 object category names (COCO dataset order). A detection's class id
// (the index info_simplify writes into element 5) selects its label here.
static const vector<string> class_name = {
    /*  0 */ "person", "bicycle", "car", "motorcycle", "airplane",
             "bus", "train", "truck", "boat", "traffic light",
    /* 10 */ "fire hydrant", "stop sign", "parking meter", "bench", "bird",
             "cat", "dog", "horse", "sheep", "cow",
    /* 20 */ "elephant", "bear", "zebra", "giraffe", "backpack",
             "umbrella", "handbag", "tie", "suitcase", "frisbee",
    /* 30 */ "skis", "snowboard", "sports ball", "kite", "baseball bat",
             "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    /* 40 */ "wine glass", "cup", "fork", "knife", "spoon",
             "bowl", "banana", "apple", "sandwich", "orange",
    /* 50 */ "broccoli", "carrot", "hot dog", "pizza", "donut",
             "cake", "chair", "couch", "potted plant", "bed",
    /* 60 */ "dining table", "toilet", "tv", "laptop", "mouse",
             "remote", "keyboard", "cell phone", "microwave", "oven",
    /* 70 */ "toaster", "sink", "refrigerator", "book", "clock",
             "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
};
- vector<vector<float>> get_info(const Mat& result, float conf = 0.5, int len_data = 85)
- {
- float* pdata = (float*)result.data;
- vector<vector<float>> info;
- for (int i = 0; i < result.total() / len_data; i++)
- {
- if (pdata[4] >= conf)
- {
- vector<float> line_data(pdata, pdata + len_data);
- info.push_back(line_data);
- }
- pdata += len_data;
- }
- return info;
- }
/// Rewrites each record in place from {cx, cy, w, h, score, class scores...}
/// to {x1, y1, x2, y2, score, class_id}.
///
/// class_id is the offset of the largest value among elements 5 onward
/// (the first one wins on ties), stored as a float. Every entry is assumed
/// to hold at least 6 values.
void info_simplify(vector<vector<float>>& info)
{
    for (auto it = info.begin(); it != info.end(); ++it)
    {
        vector<float>& row = *it;

        const auto scores_begin = row.begin() + 5;
        const auto best = std::max_element(scores_begin, row.end());
        row[5] = static_cast<float>(best - scores_begin);
        row.resize(6);

        // Read everything before overwriting the first four slots.
        const float cx = row[0];
        const float cy = row[1];
        const double half_w = row[2] / 2.0;  // double arithmetic, as in `w / 2.0`
        const double half_h = row[3] / 2.0;

        row[0] = static_cast<float>(cx - half_w);
        row[1] = static_cast<float>(cy - half_h);
        row[2] = static_cast<float>(cx + half_w);
        row[3] = static_cast<float>(cy + half_h);
    }
}
/// Groups detections by class id (element 5, truncated to int).
///
/// @param info Detections in the {x1, y1, x2, y2, score, class_id} layout; only read.
/// @return One group per distinct class id. Groups appear in order of each class's
///         first occurrence, and entries keep their relative order within a group.
vector<vector<vector<float>>> split_info(const vector<vector<float>>& info)
{
    vector<vector<vector<float>>> info_split;
    vector<int> class_id;  // class_id[k] is the class collected in info_split[k]
    for (const auto& entry : info)
    {
        const int classIndex = static_cast<int>(entry[5]);
        // A single lookup. When the class is new, `group` equals the index the new group gets.
        const auto found = std::find(class_id.begin(), class_id.end(), classIndex);
        const size_t group = static_cast<size_t>(found - class_id.begin());
        if (found == class_id.end())
        {
            class_id.push_back(classIndex);
            info_split.emplace_back();
        }
        info_split[group].push_back(entry);
    }
    return info_split;
}
/// Greedy non-maximum suppression, in place.
///
/// Sorts @p info by score (element 4, descending). It then keeps each box that is
/// not yet suppressed and suppresses every later box whose IoU with it exceeds @p iou.
///
/// @param info Boxes as {x1, y1, x2, y2, score, ...}. On return it holds only the
///             surviving boxes, highest score first.
/// @param iou  Overlap threshold; a pair with IoU strictly greater is suppressed.
void nms(vector<vector<float>>& info, float iou = 0.5)
{
    // Compare by reference: by-value parameters would copy two vectors per comparison.
    std::sort(info.begin(), info.end(), [](const vector<float>& p1, const vector<float>& p2) {
        return p1[4] > p2[4];
    });

    vector<vector<float>> return_info;
    return_info.reserve(info.size());
    vector<bool> suppressed(info.size(), false);
    for (size_t i = 0; i < info.size(); i++)
    {
        if (suppressed[i]) continue;
        const float x1 = info[i][0], y1 = info[i][1], x2 = info[i][2], y2 = info[i][3];
        const float area_i = (x2 - x1) * (y2 - y1);
        for (size_t j = i + 1; j < info.size(); j++)
        {
            if (suppressed[j]) continue;
            const float interX1 = std::max(x1, info[j][0]);
            const float interY1 = std::max(y1, info[j][1]);
            const float interX2 = std::min(x2, info[j][2]);
            const float interY2 = std::min(y2, info[j][3]);
            const float interArea = std::max(0.0f, interX2 - interX1) * std::max(0.0f, interY2 - interY1);
            const float area_j = (info[j][2] - info[j][0]) * (info[j][3] - info[j][1]);
            const float unionArea = area_i + area_j - interArea;
            // Degenerate (zero-area) pairs have no defined IoU; never suppress on them
            // instead of comparing a 0/0 NaN.
            if (unionArea > 0.0f && interArea / unionArea > iou)
            {
                suppressed[j] = true;
            }
        }
        // info[i] is not read again (inner loops only look at j > i), so it can be moved.
        return_info.push_back(std::move(info[i]));
    }
    info = std::move(return_info);
}
- void draw_box(Mat& img, const vector<vector<float>>& info)
- {
- for (int i = 0; i < info.size(); i++)
- {
- cv::Point topLeft(info[i][0], info[i][1]);
- cv::Point bottomRight(info[i][2], info[i][3]);
- int thickness = 2;
- cv::Scalar color(0, 255, 0);
- int lineType = cv::LINE_8;
- const int cornerRadius = 5;
- cv::rectangle(img, topLeft, bottomRight, color, thickness, lineType);
- string label = class_name[static_cast<int>(info[i][5])] + " " + std::to_string(info[i][4]);
- cv::Size textSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.6, 1, nullptr);
- cv::Rect textBgRect(topLeft.x, topLeft.y - textSize.height - 5, textSize.width + 10, textSize.height + 5);
- cv::rectangle(img, textBgRect, color, cv::FILLED);
- cv::putText(img, label, cv::Point(topLeft.x + 5, topLeft.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.6, CV_RGB(255, 255, 255), 2);
- }
- }
- int main() {
- try {
- Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ONNXRuntime");
- Ort::SessionOptions session_options;
- OrtCUDAProviderOptions cuda_options;
- session_options.AppendExecutionProvider_CUDA(cuda_options);
- std::wstring model_path = L"D:/VisualStudio/test/Project1/Project1/yolov5s.onnx";
- Ort::Session session(env, model_path.c_str(), session_options);
- std::vector<std::string> input_names = { "input" };
- std::vector<std::string> output_names = { "output" };
- std::vector<const char*> input_names_cstr;
- std::vector<const char*> output_names_cstr;
- for (const auto& name : input_names) {
- input_names_cstr.push_back(name.c_str());
- }
- for (const auto& name : output_names) {
- output_names_cstr.push_back(name.c_str());
- }
- cv::VideoCapture cap(0);
- if (!cap.isOpened()) {
- cout << "Error: Cannot open the camera" << endl;
- return -1;
- }
- Ort::MemoryInfo memory_info("Cpu", OrtArenaAllocator, 0, OrtMemTypeDefault);
- while (true) {
- Mat img;
- cap >> img;
- if (img.empty()) break;
- cv::resize(img, img, cv::Size(640, 640));
- // Convert image to float and normalize
- Mat blob;
- img.convertTo(blob, CV_32F, 1.0 / 255.0);
- std::vector<float> input_tensor_values(blob.total());
- std::memcpy(input_tensor_values.data(), blob.data, blob.total() * sizeof(float));
- // Create ONNX Runtime tensor from input data
- std::vector<int64_t> input_node_dims = { 1, 3, 640, 640 };
- Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
- memory_info, input_tensor_values.data(), input_tensor_values.size(),
- input_node_dims.data(), input_node_dims.size());
- // Run inference
- auto output_tensors = session.Run(Ort::RunOptions{ nullptr },
- input_names_cstr.data(), &input_tensor, 1,
- output_names_cstr.data(), 1);
- // Get the result
- Ort::Value& output_tensor = output_tensors.front();
- auto output_shape = output_tensor.GetTensorTypeAndShapeInfo().GetShape();
- std::vector<float> output_data(output_tensor.GetTensorTypeAndShapeInfo().GetElementCount());
- // Print output shape for debugging
- cout << "Output Shape: ";
- for (const auto& dim : output_shape) {
- cout << dim << " ";
- }
- cout << endl;
- size_t expected_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<size_t>());
- if (output_data.size() != expected_size) {
- std::cerr << "Error: Output data size mismatch." << endl;
- return -1;
- }
- std::memcpy(output_data.data(), output_tensor.GetTensorData<float>(),
- output_data.size() * sizeof(float));
- // Print a portion of the output data for debugging
- cout << "Output Data (first 10 values): ";
- for (size_t i = 0; i < std::min(output_data.size(), static_cast<size_t>(10)); ++i) {
- cout << output_data[i] << " ";
- }
- cout << endl;
- Mat result(output_shape[2], output_shape[3], CV_32FC1, output_data.data());
- vector<vector<float>> info = get_info(result);
- info_simplify(info);
- vector<vector<vector<float>>> info_split = split_info(info);
- for (auto& split_info : info_split) {
- nms(split_info);
- draw_box(img, split_info);
- }
- cv::imshow("test", img);
- if (cv::waitKey(1) == 'q') break;
- }
- cap.release();
- cv::destroyAllWindows();
- }
- catch (const Ort::Exception& e) {
- cout << "ONNX Runtime error: " << e.what() << endl;
- return -1;
- }
- catch (const std::exception& e) {
- cout << "Standard exception: " << e.what() << endl;
- return -1;
- }
- return 0;
- }
|