inference.h 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
#pragma once
#define RET_OK nullptr
#ifdef _WIN32
#include <Windows.h>
#include <direct.h>
#include <io.h>
#endif
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "onnxruntime_cxx_api.h"
#ifdef USE_CUDA
#include <cuda_fp16.h>
#endif
  16. enum MODEL_TYPE {
  17. //FLOAT32 MODEL
  18. YOLO_ORIGIN_V5 = 0,
  19. YOLO_ORIGIN_V8 = 1,//only support v8 detector currently
  20. YOLO_POSE_V8 = 2,
  21. YOLO_CLS_V8 = 3,
  22. YOLO_ORIGIN_V8_HALF = 4,
  23. YOLO_POSE_V8_HALF = 5,
  24. YOLO_CLS_V8_HALF = 6
  25. };
  26. typedef struct _DCSP_INIT_PARAM {
  27. std::string ModelPath;
  28. MODEL_TYPE ModelType = YOLO_ORIGIN_V8;
  29. std::vector<int> imgSize = {640, 640};
  30. float RectConfidenceThreshold = 0.6;
  31. float iouThreshold = 0.5;
  32. bool CudaEnable = false;
  33. int LogSeverityLevel = 3;
  34. int IntraOpNumThreads = 1;
  35. } DCSP_INIT_PARAM;
  36. typedef struct _DCSP_RESULT {
  37. int classId;
  38. float confidence;
  39. cv::Rect box;
  40. } DCSP_RESULT;
  41. class DCSP_CORE {
  42. public:
  43. DCSP_CORE();
  44. ~DCSP_CORE();
  45. public:
  46. char *CreateSession(DCSP_INIT_PARAM &iParams);
  47. char *RunSession(cv::Mat &iImg, std::vector<DCSP_RESULT> &oResult);
  48. char *WarmUpSession();
  49. template<typename N>
  50. char *TensorProcess(clock_t &starttime_1, cv::Mat &iImg, N &blob, std::vector<int64_t> &inputNodeDims,
  51. std::vector<DCSP_RESULT> &oResult);
  52. std::vector<std::string> classes{};
  53. private:
  54. Ort::Env env;
  55. Ort::Session *session;
  56. bool cudaEnable;
  57. Ort::RunOptions options;
  58. std::vector<const char *> inputNodeNames;
  59. std::vector<const char *> outputNodeNames;
  60. MODEL_TYPE modelType;
  61. std::vector<int> imgSize;
  62. float rectConfidenceThreshold;
  63. float iouThreshold;
  64. };