trt_inference.h

/*
 * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
#ifndef TRT_INFERENCE_H_
#define TRT_INFERENCE_H_

#include <fstream>
#include <queue>

#include "NvInfer.h"
#include "NvCaffeParser.h"
#include "NvOnnxParser.h"

#include <opencv2/objdetect/objdetect.hpp>

using namespace nvinfer1;
using namespace nvcaffeparser1;
using namespace nvonnxparser;
using namespace std;

// Model Index
#define GOOGLENET_SINGLE_CLASS 0
#define GOOGLENET_THREE_CLASS  1
#define RESNET_THREE_CLASS     2

class Logger;
class Profiler;
class TRT_Context
{
public:
    // Net related parameters
    int getNetWidth() const;
    int getNetHeight() const;
    uint32_t getBatchSize() const;
    int getChannel() const;
    int getModelClassCnt() const;
    void* getScales() const;
    void* getOffsets() const;

    // The buffer is allocated inside TRT_Context;
    // expose this interface for inputting data.
    void*& getBuffer(const int& index);
    float*& getInputBuf();
    uint32_t getNumTrtInstances() const;

    // 0: fp16, 1: fp32, 2: int8
    void setMode(const int& mode);
    void setBatchSize(const uint32_t& batchsize);
    void setDumpResult(const bool& dump_result);
    void setTrtProfilerEnabled(const bool& enable_trt_profiler);
    int getFilterNum() const;
    void setFilterNum(const unsigned int& filter_num);

    TRT_Context();
    void setModelIndex(int modelIndex);
    void buildTrtContext(const string& deployfile,
            const string& modelfile, bool bUseCPUBuf = false, bool isOnnxModel = false);
    void doInference(
            queue< vector<cv::Rect> >* rectList_queue,
            float *input = NULL);
    void destroyTrtContext(bool bUseCPUBuf = false);
    ~TRT_Context();
private:
    int net_width;
    int net_height;
    int filter_num;
    void **buffers;
    float *input_buf;
    float *output_cov_buf;
    float *output_bbox_buf;
    void* offset_gpu;
    void* scales_gpu;
    float helnet_scale[4];
    IRuntime *runtime;
    ICudaEngine *engine;
    IExecutionContext *context;
    uint32_t *pResultArray;
    int channel;              // input file's channel
    int num_bindings;
    int trtinstance_num;      // inference channel num
    int batch_size;
    int mode;
    bool dump_result;
    ofstream fstream;
    bool enable_trt_profiler;
    bool is_onnx_model;
    IHostMemory *trtModelStream{nullptr};
    vector<string> outputs;
    string result_file;
    Logger *pLogger;
    Profiler *pProfiler;
    int frame_num;
    uint64_t elapsed_frame_num;
    uint64_t elapsed_time;
    int inputIndex;
    int outputIndex;
    int outputIndexBBOX;
    Dims3 inputDims;
    Dims3 outputDims;
    Dims3 outputDimsBBOX;
    size_t inputSize;
    size_t outputSize;
    size_t outputSizeBBOX;
    struct {
        const int classCnt;
        float THRESHOLD[3];
        const char *INPUT_BLOB_NAME;
        const char *OUTPUT_BLOB_NAME;
        const char *OUTPUT_BBOX_NAME;
        const int STRIDE;
        const int WORKSPACE_SIZE;
        int offsets[3];
        float input_scale[3];
        float bbox_output_scales[4];
        const int ParseFunc_ID;
    } *g_pModelNetAttr, gModelNetAttr[4] = {
        {
            // GOOGLENET_SINGLE_CLASS
            1,
            {0.8, 0, 0},
            "data",
            "coverage",
            "bboxes",
            4,
            450 * 1024 * 1024,
            {0, 0, 0},
            {1.0f, 1.0f, 1.0f},
            {1, 1, 1, 1},
            0
        },

        {
            // GOOGLENET_THREE_CLASS
            3,
            {0.6, 0.6, 1.0},    // People, Motorbike, Car
            "data",
            "Layer16_cov",
            "Layer16_bbox",
            16,
            110 * 1024 * 1024,
            {124, 117, 104},
            {1.0f, 1.0f, 1.0f},
            {-640, -368, 640, 368},
            0
        },

        {
            // RESNET_THREE_CLASS
            4,
            {0.1, 0.1, 0.1},    // People, Motorbike, Car
            "data",
            "Layer7_cov",
            "Layer7_bbox",
            16,
            110 * 1024 * 1024,
            {0, 0, 0},
            {0.0039215697906911373, 0.0039215697906911373, 0.0039215697906911373},
            {-640, -368, 640, 368},
            1
        },
    };
    enum Mode_type {
        MODE_FP16 = 0,
        MODE_FP32 = 1,
        MODE_INT8 = 2
    };

    int parseNet(const string& deployfile);
    void parseBbox(vector<cv::Rect>* rectList, int batch_th);
    void ParseResnet10Bbox(vector<cv::Rect>* rectList, int batch_th);
    void allocateMemory(bool bUseCPUBuf);
    void releaseMemory(bool bUseCPUBuf);
    void caffeToTRTModel(const string& deployfile, const string& modelfile);
    void onnxToTRTModel(const string& modelfile);
};
#endif
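
Taken together, the declarations above imply a typical call sequence: construct a TRT_Context, select a model with setModelIndex(), build the engine with buildTrtContext(), fill the input buffer exposed by getInputBuf(), and call doInference(), which pushes one vector of detected rectangles per frame onto the supplied queue. The minimal sketch below illustrates that sequence as inferred from this header alone; the deploy/model file paths, the zero-filled input standing in for real preprocessed frames, and the single-shot flow are assumptions, not part of the header.

// Minimal usage sketch based only on the interface declared in trt_inference.h.
// The file paths below are placeholders, and the zero-filled buffer stands in
// for real preprocessed frame data.
#include <algorithm>
#include <queue>
#include <vector>

#include <opencv2/objdetect/objdetect.hpp>

#include "trt_inference.h"

int main()
{
    TRT_Context ctx;
    ctx.setModelIndex(GOOGLENET_THREE_CLASS);   // one of the Model Index defines above
    ctx.setMode(1);                             // 0: fp16, 1: fp32, 2: int8
    ctx.setBatchSize(1);
    ctx.buildTrtContext("deploy.prototxt", "model.caffemodel");   // placeholder paths

    // The input buffer is owned by TRT_Context; the application copies
    // batch_size * channel * net_height * net_width floats into it.
    float *input = ctx.getInputBuf();
    size_t count = ctx.getBatchSize() * ctx.getChannel() *
                   ctx.getNetHeight() * ctx.getNetWidth();
    std::fill(input, input + count, 0.0f);      // stand-in for a preprocessed frame

    // doInference() pushes one vector of detected rectangles per frame.
    std::queue< std::vector<cv::Rect> > rect_queue;
    ctx.doInference(&rect_queue);

    ctx.destroyTrtContext();
    return 0;
}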