onnxruntime_cxx_inline.h 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125
  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
  4. // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
  5. //
  6. // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
  7. // the main C++ file with implementation details.
  8. #include <algorithm>
  9. #include <functional>
  10. #include <iterator>
  11. #include <type_traits>
  12. // Convert OrtStatus to Ort::Status and return
  13. // instead of throwing
  14. #define ORT_CXX_RETURN_ON_API_FAIL(expression) \
  15. { \
  16. auto ort_status = (expression); \
  17. if (ort_status) { \
  18. return Ort::Status(ort_status); \
  19. } \
  20. }
  21. #ifdef __cpp_if_constexpr
  22. #define ORT_CXX_IF_CONSTEXPR if constexpr
  23. #else
  24. #define ORT_CXX_IF_CONSTEXPR if
  25. #endif
  26. namespace Ort {
  27. namespace detail {
  28. inline void ThrowStatus(const Status& st) {
  29. std::string error_message = st.GetErrorMessage();
  30. OrtErrorCode error_code = st.GetErrorCode();
  31. ORT_CXX_API_THROW(std::move(error_message), error_code);
  32. }
  33. } // namespace detail
  34. inline void ThrowOnError(OrtStatus* ort_status) {
  35. if (ort_status) {
  36. Ort::Status st(ort_status);
  37. detail::ThrowStatus(st);
  38. }
  39. }
  40. inline void ThrowOnError(const Status& st) {
  41. if (st) {
  42. detail::ThrowStatus(st);
  43. }
  44. }
  45. inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
  46. }
  47. inline Status::Status(const std::exception& e) noexcept {
  48. p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
  49. }
  50. inline Status::Status(const Exception& e) noexcept {
  51. p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
  52. }
  53. inline Status::Status(const char* message, OrtErrorCode code) noexcept {
  54. p_ = GetApi().CreateStatus(code, message);
  55. }
  56. inline std::string Status::GetErrorMessage() const {
  57. std::string message(GetApi().GetErrorMessage(p_));
  58. return message;
  59. }
  60. inline OrtErrorCode Status::GetErrorCode() const {
  61. return GetApi().GetErrorCode(p_);
  62. }
  63. inline bool Status::IsOK() const noexcept {
  64. return (p_ == nullptr);
  65. }
  66. // This template converts a C++ type into it's ONNXTensorElementDataType
  67. template <typename T>
  68. struct TypeToTensorType;
  69. template <>
  70. struct TypeToTensorType<float> {
  71. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  72. };
  73. template <>
  74. struct TypeToTensorType<Float16_t> {
  75. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  76. };
  77. template <>
  78. struct TypeToTensorType<BFloat16_t> {
  79. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
  80. };
  81. template <>
  82. struct TypeToTensorType<double> {
  83. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
  84. };
  85. template <>
  86. struct TypeToTensorType<int8_t> {
  87. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
  88. };
  89. template <>
  90. struct TypeToTensorType<int16_t> {
  91. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
  92. };
  93. template <>
  94. struct TypeToTensorType<int32_t> {
  95. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  96. };
  97. template <>
  98. struct TypeToTensorType<int64_t> {
  99. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  100. };
  101. template <>
  102. struct TypeToTensorType<uint8_t> {
  103. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  104. };
  105. template <>
  106. struct TypeToTensorType<uint16_t> {
  107. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
  108. };
  109. template <>
  110. struct TypeToTensorType<uint32_t> {
  111. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
  112. };
  113. template <>
  114. struct TypeToTensorType<uint64_t> {
  115. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
  116. };
  117. template <>
  118. struct TypeToTensorType<bool> {
  119. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
  120. };
  121. template <>
  122. struct TypeToTensorType<Float8E4M3FN_t> {
  123. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
  124. };
  125. template <>
  126. struct TypeToTensorType<Float8E4M3FNUZ_t> {
  127. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
  128. };
  129. template <>
  130. struct TypeToTensorType<Float8E5M2_t> {
  131. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
  132. };
  133. template <>
  134. struct TypeToTensorType<Float8E5M2FNUZ_t> {
  135. static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
  136. };
  137. inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
  138. if (IsNaN() || rhs.IsNaN()) {
  139. // IEEE defines that NaN is not equal to anything, including itself.
  140. return false;
  141. }
  142. return val == rhs.val;
  143. }
  144. inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
  145. if (IsNaN() || rhs.IsNaN()) {
  146. // IEEE defines that NaN is unordered with respect to everything, including itself.
  147. return false;
  148. }
  149. const bool left_is_negative = IsNegative();
  150. if (left_is_negative != rhs.IsNegative()) {
  151. // When the signs of left and right differ, we know that left is less than right if it is
  152. // the negative value. The exception to this is if both values are zero, in which case IEEE
  153. // says they should be equal, even if the signs differ.
  154. return left_is_negative && !AreZero(*this, rhs);
  155. }
  156. return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  157. }
  158. inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
  159. : allocator_(allocator), p_(p), size_(size) {
  160. }
  161. inline MemoryAllocation::~MemoryAllocation() {
  162. if (p_ != nullptr) {
  163. // We do not throw out of destructor
  164. auto ret = GetApi().AllocatorFree(allocator_, p_);
  165. static_cast<void>(ret);
  166. }
  167. }
  168. inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
  169. *this = std::move(o);
  170. }
  171. inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
  172. OrtAllocator* alloc = nullptr;
  173. void* p = nullptr;
  174. size_t sz = 0;
  175. // Swap out this
  176. std::swap(alloc, allocator_);
  177. std::swap(p, p_);
  178. std::swap(sz, size_);
  179. // Swap with incoming
  180. std::swap(allocator_, o.allocator_);
  181. std::swap(p_, o.p_);
  182. std::swap(size_, o.size_);
  183. // Destroy this instance if needed
  184. MemoryAllocation this_alloc(alloc, p, sz);
  185. return *this;
  186. }
  187. namespace detail {
  188. template <typename T>
  189. inline void* AllocatorImpl<T>::Alloc(size_t size) {
  190. void* out;
  191. ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
  192. return out;
  193. }
  194. template <typename T>
  195. inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
  196. void* out;
  197. ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
  198. MemoryAllocation result(this->p_, out, size);
  199. return result;
  200. }
  201. template <typename T>
  202. inline void AllocatorImpl<T>::Free(void* p) {
  203. ThrowOnError(GetApi().AllocatorFree(this->p_, p));
  204. }
  205. template <typename T>
  206. inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
  207. const OrtMemoryInfo* out;
  208. ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
  209. return ConstMemoryInfo{out};
  210. }
  211. } // namespace detail
  212. inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
  213. ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
  214. }
  215. inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
  216. ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
  217. }
  218. namespace detail {
  219. template <typename T>
  220. inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
  221. const char* name = nullptr;
  222. ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
  223. return std::string(name);
  224. }
  225. template <typename T>
  226. inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
  227. OrtAllocatorType type;
  228. ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
  229. return type;
  230. }
  231. template <typename T>
  232. inline int MemoryInfoImpl<T>::GetDeviceId() const {
  233. int id = 0;
  234. ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
  235. return id;
  236. }
  237. template <typename T>
  238. inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
  239. OrtMemoryInfoDeviceType type;
  240. GetApi().MemoryInfoGetDeviceType(this->p_, &type);
  241. return type;
  242. }
  243. template <typename T>
  244. inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
  245. OrtMemType type;
  246. ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
  247. return type;
  248. }
  249. template <typename T>
  250. template <typename U>
  251. inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
  252. int comp_result = 0;
  253. ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
  254. return comp_result == 0;
  255. }
  256. } // namespace detail
  257. inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
  258. OrtMemoryInfo* p;
  259. ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
  260. return MemoryInfo(p);
  261. }
  262. inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
  263. ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
  264. }
  265. namespace detail {
  266. template <typename T>
  267. inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
  268. AllocatorWithDefaultOptions allocator;
  269. return binding_utils::GetOutputNamesHelper(this->p_, allocator);
  270. }
  271. template <typename T>
  272. inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
  273. return binding_utils::GetOutputNamesHelper(this->p_, allocator);
  274. }
  275. template <typename T>
  276. inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
  277. AllocatorWithDefaultOptions allocator;
  278. return binding_utils::GetOutputValuesHelper(this->p_, allocator);
  279. }
  280. template <typename T>
  281. inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
  282. return binding_utils::GetOutputValuesHelper(this->p_, allocator);
  283. }
  284. template <typename T>
  285. inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
  286. ThrowOnError(GetApi().BindInput(this->p_, name, value));
  287. }
  288. template <typename T>
  289. inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
  290. ThrowOnError(GetApi().BindOutput(this->p_, name, value));
  291. }
  292. template <typename T>
  293. inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
  294. ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
  295. }
  296. template <typename T>
  297. inline void IoBindingImpl<T>::ClearBoundInputs() {
  298. GetApi().ClearBoundInputs(this->p_);
  299. }
  300. template <typename T>
  301. inline void IoBindingImpl<T>::ClearBoundOutputs() {
  302. GetApi().ClearBoundOutputs(this->p_);
  303. }
  304. template <typename T>
  305. inline void IoBindingImpl<T>::SynchronizeInputs() {
  306. ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
  307. }
  308. template <typename T>
  309. inline void IoBindingImpl<T>::SynchronizeOutputs() {
  310. ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
  311. }
  312. namespace binding_utils {
  313. inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
  314. std::vector<std::string> result;
  315. auto free_fn = detail::AllocatedFree(allocator);
  316. using Ptr = std::unique_ptr<void, decltype(free_fn)>;
  317. char* buffer = nullptr;
  318. size_t* lengths = nullptr;
  319. size_t count = 0;
  320. ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
  321. if (count == 0) {
  322. return result;
  323. }
  324. Ptr buffer_g(buffer, free_fn);
  325. Ptr lengths_g(lengths, free_fn);
  326. result.reserve(count);
  327. for (size_t i = 0; i < count; ++i) {
  328. auto sz = *lengths;
  329. result.emplace_back(buffer, sz);
  330. buffer += sz;
  331. ++lengths;
  332. }
  333. return result;
  334. }
  335. inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
  336. std::vector<Value> result;
  337. size_t owned = 0;
  338. size_t output_count = 0;
  339. // Lambda to release the buffer when no longer needed and
  340. // make sure that we destroy all instances on exception
  341. auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
  342. if (buffer) {
  343. while (owned < output_count) {
  344. auto* p = buffer + owned++;
  345. GetApi().ReleaseValue(*p);
  346. }
  347. allocator->Free(allocator, buffer);
  348. }
  349. };
  350. using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
  351. OrtValue** output_buffer = nullptr;
  352. ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
  353. if (output_count == 0) {
  354. return result;
  355. }
  356. Ptr buffer_g(output_buffer, free_fn);
  357. result.reserve(output_count);
  358. for (size_t i = 0; i < output_count; ++i) {
  359. result.emplace_back(output_buffer[i]);
  360. ++owned;
  361. }
  362. return result;
  363. }
  364. } // namespace binding_utils
  365. } // namespace detail
  366. inline IoBinding::IoBinding(Session& session) {
  367. ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
  368. }
  369. inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
  370. ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
  371. }
  372. inline ThreadingOptions::ThreadingOptions() {
  373. ThrowOnError(GetApi().CreateThreadingOptions(&p_));
  374. }
  375. inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
  376. ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
  377. return *this;
  378. }
  379. inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
  380. ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
  381. return *this;
  382. }
  383. inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
  384. ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
  385. return *this;
  386. }
  387. inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
  388. ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
  389. return *this;
  390. }
  391. inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
  392. ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
  393. return *this;
  394. }
  395. inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
  396. ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
  397. return *this;
  398. }
  399. inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
  400. ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
  401. return *this;
  402. }
  403. inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
  404. ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
  405. if (strcmp(logid, "onnxruntime-node") == 0) {
  406. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
  407. } else {
  408. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
  409. }
  410. }
  411. inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
  412. ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
  413. if (strcmp(logid, "onnxruntime-node") == 0) {
  414. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
  415. } else {
  416. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
  417. }
  418. }
  419. inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
  420. ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
  421. if (strcmp(logid, "onnxruntime-node") == 0) {
  422. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
  423. } else {
  424. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
  425. }
  426. }
  427. inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
  428. OrtLoggingLevel logging_level, _In_ const char* logid) {
  429. ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
  430. if (strcmp(logid, "onnxruntime-node") == 0) {
  431. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
  432. } else {
  433. ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
  434. }
  435. }
  436. inline Env& Env::EnableTelemetryEvents() {
  437. ThrowOnError(GetApi().EnableTelemetryEvents(p_));
  438. return *this;
  439. }
  440. inline Env& Env::DisableTelemetryEvents() {
  441. ThrowOnError(GetApi().DisableTelemetryEvents(p_));
  442. return *this;
  443. }
  444. inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
  445. ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
  446. return *this;
  447. }
  448. inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
  449. ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
  450. return *this;
  451. }
  452. inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
  453. std::vector<const char*> keys, values;
  454. auto num_entries = options.size();
  455. if (num_entries > 0) {
  456. keys.reserve(num_entries);
  457. values.reserve(num_entries);
  458. for (const auto& entry : options) {
  459. keys.push_back(entry.first.c_str());
  460. values.push_back(entry.second.c_str());
  461. }
  462. }
  463. ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
  464. return *this;
  465. }
  466. inline CustomOpDomain::CustomOpDomain(const char* domain) {
  467. ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
  468. }
  469. inline void CustomOpDomain::Add(const OrtCustomOp* op) {
  470. ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
  471. }
  472. inline RunOptions::RunOptions() {
  473. ThrowOnError(GetApi().CreateRunOptions(&p_));
  474. }
  475. inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
  476. ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
  477. return *this;
  478. }
  479. inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
  480. ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
  481. return *this;
  482. }
  483. inline int RunOptions::GetRunLogVerbosityLevel() const {
  484. int out;
  485. ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
  486. return out;
  487. }
  488. inline int RunOptions::GetRunLogSeverityLevel() const {
  489. int out;
  490. ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
  491. return out;
  492. }
  493. inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
  494. ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
  495. return *this;
  496. }
  497. inline const char* RunOptions::GetRunTag() const {
  498. const char* out;
  499. ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
  500. return out;
  501. }
  502. inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
  503. ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
  504. return *this;
  505. }
  506. inline RunOptions& RunOptions::SetTerminate() {
  507. ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
  508. return *this;
  509. }
  510. inline RunOptions& RunOptions::UnsetTerminate() {
  511. ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
  512. return *this;
  513. }
  514. namespace detail {
  515. template <typename T>
  516. inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
  517. OrtSessionOptions* out;
  518. ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
  519. return SessionOptions{out};
  520. }
  521. template <typename T>
  522. inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
  523. size_t size = 0;
  524. // Feed nullptr for the data buffer to query the true size of the string value
  525. Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
  526. std::string out;
  527. out.resize(size);
  528. Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
  529. out.resize(size - 1); // remove the terminating character '\0'
  530. return out;
  531. }
  532. template <typename T>
  533. inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
  534. int out = 0;
  535. Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
  536. return static_cast<bool>(out);
  537. }
  538. template <typename T>
  539. inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
  540. if (!this->HasConfigEntry(config_key)) {
  541. return def;
  542. }
  543. return this->GetConfigEntry(config_key);
  544. }
  545. template <typename T>
  546. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
  547. ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
  548. return *this;
  549. }
  550. template <typename T>
  551. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
  552. ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
  553. return *this;
  554. }
  555. template <typename T>
  556. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
  557. ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
  558. return *this;
  559. }
  560. template <typename T>
  561. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
  562. ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
  563. return *this;
  564. }
  565. template <typename T>
  566. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
  567. ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
  568. return *this;
  569. }
  570. template <typename T>
  571. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
  572. ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
  573. return *this;
  574. }
  575. template <typename T>
  576. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
  577. ThrowOnError(GetApi().DisableProfiling(this->p_));
  578. return *this;
  579. }
  580. template <typename T>
  581. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
  582. ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
  583. return *this;
  584. }
  585. template <typename T>
  586. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
  587. ThrowOnError(GetApi().EnableMemPattern(this->p_));
  588. return *this;
  589. }
  590. template <typename T>
  591. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
  592. ThrowOnError(GetApi().DisableMemPattern(this->p_));
  593. return *this;
  594. }
  595. template <typename T>
  596. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
  597. ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
  598. return *this;
  599. }
  600. template <typename T>
  601. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
  602. ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
  603. return *this;
  604. }
  605. template <typename T>
  606. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
  607. ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
  608. return *this;
  609. }
  610. template <typename T>
  611. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
  612. ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
  613. return *this;
  614. }
  615. template <typename T>
  616. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
  617. ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
  618. return *this;
  619. }
  620. template <typename T>
  621. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
  622. ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
  623. return *this;
  624. }
  625. template <typename T>
  626. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
  627. ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
  628. return *this;
  629. }
  630. template <typename T>
  631. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
  632. ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
  633. return *this;
  634. }
  635. template <typename T>
  636. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
  637. ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
  638. return *this;
  639. }
  640. template <typename T>
  641. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
  642. const std::vector<Value>& ort_values) {
  643. const size_t inputs_num = names.size();
  644. if (inputs_num != ort_values.size()) {
  645. ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
  646. }
  647. std::vector<const char*> names_ptr;
  648. std::vector<const OrtValue*> ort_values_ptrs;
  649. names_ptr.reserve(inputs_num);
  650. ort_values_ptrs.reserve(inputs_num);
  651. for (size_t i = 0; i < inputs_num; ++i) {
  652. names_ptr.push_back(names[i].c_str());
  653. ort_values_ptrs.push_back(ort_values[i]);
  654. }
  655. ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
  656. return *this;
  657. }
  658. template <typename T>
  659. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& file_names,
  660. const std::vector<char*>& buffer_array,
  661. const std::vector<size_t>& file_lengths) {
  662. const size_t inputs_num = file_names.size();
  663. if (inputs_num != buffer_array.size()) {
  664. ORT_CXX_API_THROW("Expecting names and buffer_array to have the same length", ORT_INVALID_ARGUMENT);
  665. }
  666. if (inputs_num != file_lengths.size()) {
  667. ORT_CXX_API_THROW("Expecting names and file_lengths to have the same length", ORT_INVALID_ARGUMENT);
  668. }
  669. std::vector<const ORTCHAR_T*> names_ptr;
  670. names_ptr.reserve(inputs_num);
  671. for (size_t i = 0; i < inputs_num; ++i) {
  672. names_ptr.push_back(file_names[i].c_str());
  673. }
  674. ThrowOnError(GetApi().AddExternalInitializersFromFilesInMemory(this->p_, names_ptr.data(), buffer_array.data(),
  675. file_lengths.data(), inputs_num));
  676. return *this;
  677. }
  678. template <typename T>
  679. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
  680. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
  681. return *this;
  682. }
  683. template <typename T>
  684. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
  685. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
  686. return *this;
  687. }
  688. template <typename T>
  689. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
  690. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
  691. return *this;
  692. }
  693. template <typename T>
  694. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
  695. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
  696. return *this;
  697. }
  698. template <typename T>
  699. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
  700. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
  701. return *this;
  702. }
  703. template <typename T>
  704. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
  705. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
  706. return *this;
  707. }
  708. template <typename T>
  709. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
  710. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
  711. return *this;
  712. }
  713. template <typename T>
  714. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
  715. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
  716. return *this;
  717. }
  718. template <typename T>
  719. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
  720. const std::string& provider_name,
  721. const std::unordered_map<std::string, std::string>& provider_options) {
  722. auto num_entries = provider_options.size();
  723. std::vector<const char*> keys, values;
  724. if (num_entries > 0) {
  725. keys.reserve(num_entries);
  726. values.reserve(num_entries);
  727. for (const auto& entry : provider_options) {
  728. keys.push_back(entry.first.c_str());
  729. values.push_back(entry.second.c_str());
  730. }
  731. }
  732. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
  733. keys.data(), values.data(), num_entries));
  734. return *this;
  735. }
  736. template <typename T>
  737. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
  738. ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
  739. return *this;
  740. }
  741. template <typename T>
  742. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
  743. ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
  744. return *this;
  745. }
  746. template <typename T>
  747. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
  748. ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
  749. return *this;
  750. }
  751. template <typename T>
  752. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
  753. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
  754. return *this;
  755. }
  756. template <typename T>
  757. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
  758. auto num_entries = provider_options.size();
  759. std::vector<const char*> keys, values;
  760. if (num_entries > 0) {
  761. keys.reserve(num_entries);
  762. values.reserve(num_entries);
  763. for (const auto& entry : provider_options) {
  764. keys.push_back(entry.first.c_str());
  765. values.push_back(entry.second.c_str());
  766. }
  767. }
  768. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
  769. keys.data(), values.data(), num_entries));
  770. return *this;
  771. }
  772. template <typename T>
  773. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options) {
  774. auto num_entries = provider_options.size();
  775. std::vector<const char*> keys, values;
  776. if (num_entries > 0) {
  777. keys.reserve(num_entries);
  778. values.reserve(num_entries);
  779. for (const auto& entry : provider_options) {
  780. keys.push_back(entry.first.c_str());
  781. values.push_back(entry.second.c_str());
  782. }
  783. }
  784. ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries));
  785. return *this;
  786. }
  787. template <typename T>
  788. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
  789. const CustomOpConfigs& custom_op_configs) {
  790. // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
  791. // the custom op library.
  792. for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
  793. AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
  794. }
  795. ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
  796. return *this;
  797. }
  798. template <typename T>
  799. inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
  800. ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
  801. return *this;
  802. }
  803. /// Session
  804. template <typename T>
  805. inline size_t ConstSessionImpl<T>::GetInputCount() const {
  806. size_t out;
  807. ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
  808. return out;
  809. }
  810. template <typename T>
  811. inline size_t ConstSessionImpl<T>::GetOutputCount() const {
  812. size_t out;
  813. ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
  814. return out;
  815. }
  816. template <typename T>
  817. inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
  818. size_t out;
  819. ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
  820. return out;
  821. }
  822. template <typename T>
  823. inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
  824. char* out;
  825. ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
  826. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  827. }
  828. template <typename T>
  829. inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
  830. char* out;
  831. ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
  832. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  833. }
  834. template <typename T>
  835. inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
  836. char* out;
  837. ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
  838. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  839. }
  840. template <typename T>
  841. inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
  842. uint64_t out;
  843. ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
  844. return out;
  845. }
  846. template <typename T>
  847. inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
  848. OrtModelMetadata* out;
  849. ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
  850. return ModelMetadata{out};
  851. }
  852. template <typename T>
  853. inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
  854. OrtTypeInfo* out;
  855. ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
  856. return TypeInfo{out};
  857. }
  858. template <typename T>
  859. inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
  860. OrtTypeInfo* out;
  861. ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
  862. return TypeInfo{out};
  863. }
  864. template <typename T>
  865. inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
  866. OrtTypeInfo* out;
  867. ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
  868. return TypeInfo{out};
  869. }
  870. template <typename T>
  871. inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  872. const char* const* output_names, size_t output_count) {
  873. std::vector<Value> output_values;
  874. output_values.reserve(output_count);
  875. for (size_t i = 0; i < output_count; i++)
  876. output_values.emplace_back(nullptr);
  877. Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
  878. return output_values;
  879. }
  880. template <typename T>
  881. inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  882. const char* const* output_names, Value* output_values, size_t output_count) {
  883. static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
  884. auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  885. auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
  886. ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
  887. }
  888. template <typename T>
  889. inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
  890. ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
  891. }
  892. template <typename T>
  893. inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  894. const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
  895. auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  896. auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
  897. ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
  898. ort_input_values, input_count, output_names, output_count,
  899. ort_output_values, callback, user_data));
  900. }
  901. template <typename T>
  902. inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
  903. char* out = nullptr;
  904. ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
  905. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  906. }
  907. } // namespace detail
  908. inline SessionOptions::SessionOptions() {
  909. ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
  910. }
  911. /// CustomOpConfigs
  912. inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
  913. std::string config_key = "custom_op.";
  914. config_key += custom_op_name;
  915. config_key += ".";
  916. config_key += config;
  917. return config_key;
  918. }
  919. inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
  920. const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
  921. flat_configs_[full_flat_key] = config_value;
  922. return *this;
  923. }
  924. inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
  925. return flat_configs_;
  926. }
  927. inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
  928. ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
  929. }
  930. inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
  931. OrtPrepackedWeightsContainer* prepacked_weights_container) {
  932. ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
  933. }
  934. inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
  935. ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
  936. }
  937. inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
  938. const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
  939. ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
  940. prepacked_weights_container, &this->p_));
  941. }
  942. inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
  943. char* out;
  944. ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
  945. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  946. }
  947. inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
  948. char* out;
  949. ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
  950. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  951. }
  952. inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
  953. char* out;
  954. ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
  955. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  956. }
  957. inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
  958. char* out;
  959. ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
  960. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  961. }
  962. inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
  963. char* out;
  964. ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
  965. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  966. }
  967. inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
  968. char* out;
  969. ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
  970. return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
  971. }
/// Returns every key of the model's custom metadata map as an AllocatedStringPtr.
/// The C API hands back an allocator-owned array of allocator-owned strings. The guards
/// below release both the array and the strings if reserve() throws, before ownership of
/// each string has been transferred into `result`.
inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
  auto deletor = detail::AllocatedFree(allocator);
  std::vector<AllocatedStringPtr> result;
  char** out = nullptr;
  int64_t num_keys = 0;
  ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
  // No keys: nothing to wrap. NOTE(review): presumably nothing was allocated in this case
  // either — confirm against the C API contract.
  if (num_keys <= 0) {
    return result;
  }
  // array of pointers will be freed
  // (array_guard owns only the outer char** array, on every exit path including the normal
  //  return; the individual strings are handed to `result` below.)
  std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
  // reserve may throw
  // (if it does, strings_guard frees each string; it is declared after array_guard, so it is
  //  destroyed first and still reads a live array.)
  auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
  std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
  result.reserve(static_cast<size_t>(num_keys));
  // Capacity for all num_keys entries is now in place, so the push_backs below do not
  // reallocate; from here on each string is owned by its AllocatedStringPtr in `result`.
  strings_guard.release();
  for (int64_t i = 0; i < num_keys; ++i) {
    result.push_back(AllocatedStringPtr(out[i], deletor));
  }
  return result;
}
  993. inline int64_t ModelMetadata::GetVersion() const {
  994. int64_t out;
  995. ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
  996. return out;
  997. }
  998. namespace detail {
  999. template <typename T>
  1000. inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
  1001. ONNXTensorElementDataType out;
  1002. ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
  1003. return out;
  1004. }
  1005. template <typename T>
  1006. inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
  1007. size_t out;
  1008. ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
  1009. return static_cast<size_t>(out);
  1010. }
  1011. template <typename T>
  1012. inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
  1013. size_t out;
  1014. ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
  1015. return out;
  1016. }
  1017. template <typename T>
  1018. inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
  1019. ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
  1020. }
  1021. template <typename T>
  1022. inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
  1023. ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
  1024. }
  1025. template <typename T>
  1026. inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
  1027. std::vector<int64_t> out(GetDimensionsCount(), 0);
  1028. ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
  1029. return out;
  1030. }
  1031. template <typename T>
  1032. inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
  1033. const OrtTensorTypeAndShapeInfo* out;
  1034. ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
  1035. return ConstTensorTypeAndShapeInfo{out};
  1036. }
  1037. template <typename T>
  1038. inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
  1039. const OrtSequenceTypeInfo* out;
  1040. ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
  1041. return ConstSequenceTypeInfo{out};
  1042. }
  1043. template <typename T>
  1044. inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
  1045. const OrtMapTypeInfo* out;
  1046. ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
  1047. return ConstMapTypeInfo{out};
  1048. }
  1049. template <typename T>
  1050. inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
  1051. ONNXType out;
  1052. ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
  1053. return out;
  1054. }
  1055. template <typename T>
  1056. inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
  1057. OrtTypeInfo* output;
  1058. ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
  1059. return TypeInfo{output};
  1060. }
  1061. template <typename T>
  1062. inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
  1063. OrtTypeInfo* info;
  1064. ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
  1065. return TypeInfo{info};
  1066. }
  1067. template <typename T>
  1068. inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
  1069. ONNXTensorElementDataType out;
  1070. ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
  1071. return out;
  1072. }
  1073. template <typename T>
  1074. inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
  1075. OrtTypeInfo* output;
  1076. ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
  1077. return TypeInfo{output};
  1078. }
  1079. template <typename T>
  1080. inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
  1081. const OrtOptionalTypeInfo* info;
  1082. ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
  1083. return ConstOptionalTypeInfo{info};
  1084. }
  1085. } // namespace detail
  1086. namespace detail {
  1087. template <typename T>
  1088. template <typename R>
  1089. inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
  1090. ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
  1091. }
  1092. template <typename T>
  1093. inline bool ConstValueImpl<T>::IsTensor() const {
  1094. int out;
  1095. ThrowOnError(GetApi().IsTensor(this->p_, &out));
  1096. return out != 0;
  1097. }
  1098. template <typename T>
  1099. inline bool ConstValueImpl<T>::HasValue() const {
  1100. int out;
  1101. ThrowOnError(GetApi().HasValue(this->p_, &out));
  1102. return out != 0;
  1103. }
  1104. template <typename T>
  1105. inline size_t ConstValueImpl<T>::GetCount() const {
  1106. size_t out;
  1107. ThrowOnError(GetApi().GetValueCount(this->p_, &out));
  1108. return out;
  1109. }
  1110. template <typename T>
  1111. inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
  1112. OrtValue* out;
  1113. ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
  1114. return Value{out};
  1115. }
  1116. template <typename T>
  1117. inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
  1118. size_t out;
  1119. ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
  1120. return out;
  1121. }
  1122. template <typename T>
  1123. inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
  1124. size_t out;
  1125. ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
  1126. return out;
  1127. }
  1128. template <typename T>
  1129. template <typename R>
  1130. inline const R* ConstValueImpl<T>::GetTensorData() const {
  1131. R* out;
  1132. ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
  1133. return out;
  1134. }
  1135. template <typename T>
  1136. inline const void* ConstValueImpl<T>::GetTensorRawData() const {
  1137. void* out;
  1138. ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
  1139. return out;
  1140. }
  1141. template <typename T>
  1142. inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
  1143. OrtTypeInfo* output;
  1144. ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
  1145. return TypeInfo{output};
  1146. }
  1147. template <typename T>
  1148. inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
  1149. OrtTensorTypeAndShapeInfo* output;
  1150. ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
  1151. return TensorTypeAndShapeInfo{output};
  1152. }
  1153. template <typename T>
  1154. inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
  1155. const OrtMemoryInfo* mem_info;
  1156. ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
  1157. return ConstMemoryInfo(mem_info);
  1158. }
  1159. template <typename T>
  1160. inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
  1161. ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
  1162. }
  1163. template <typename T>
  1164. inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
  1165. size_t buffer_length;
  1166. ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
  1167. std::string s;
  1168. s.resize(buffer_length);
  1169. ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
  1170. return s;
  1171. }
  1172. template <typename T>
  1173. inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
  1174. ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
  1175. }
  1176. #if !defined(DISABLE_SPARSE_TENSORS)
  1177. template <typename T>
  1178. inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
  1179. OrtSparseFormat format;
  1180. ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
  1181. return format;
  1182. }
  1183. template <typename T>
  1184. inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
  1185. OrtTensorTypeAndShapeInfo* output;
  1186. ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
  1187. return TensorTypeAndShapeInfo{output};
  1188. }
  1189. template <typename T>
  1190. inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
  1191. OrtTensorTypeAndShapeInfo* output;
  1192. ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
  1193. return TensorTypeAndShapeInfo{output};
  1194. }
  1195. template <typename T>
  1196. template <typename R>
  1197. inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
  1198. const void* out;
  1199. ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
  1200. return reinterpret_cast<const R*>(out);
  1201. }
  1202. template <typename T>
  1203. inline bool ConstValueImpl<T>::IsSparseTensor() const {
  1204. int out;
  1205. ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
  1206. return out != 0;
  1207. }
  1208. template <typename T>
  1209. template <typename R>
  1210. inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
  1211. const void* out;
  1212. ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
  1213. return reinterpret_cast<const R*>(out);
  1214. }
  1215. #endif
  1216. template <typename T>
  1217. void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
  1218. ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
  1219. }
  1220. template <typename T>
  1221. void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
  1222. ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
  1223. }
  1224. template <typename T>
  1225. inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
  1226. char* result;
  1227. ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
  1228. return result;
  1229. }
  1230. template <typename T>
  1231. void* ValueImpl<T>::GetTensorMutableRawData() {
  1232. void* out;
  1233. ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
  1234. return out;
  1235. }
  1236. template <typename T>
  1237. template <typename R>
  1238. R* ValueImpl<T>::GetTensorMutableData() {
  1239. R* out;
  1240. ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
  1241. return out;
  1242. }
  1243. template <typename T>
  1244. template <typename R>
  1245. R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
  1246. static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
  1247. R* out;
  1248. ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
  1249. return *out;
  1250. }
  1251. #if !defined(DISABLE_SPARSE_TENSORS)
  1252. template <typename T>
  1253. void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
  1254. ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
  1255. }
  1256. template <typename T>
  1257. void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
  1258. ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
  1259. }
  1260. template <typename T>
  1261. void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
  1262. ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
  1263. }
  1264. template <typename T>
  1265. void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
  1266. const int64_t* indices_data, size_t indices_num) {
  1267. ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
  1268. values_param.values_shape_len, values_param.data.p_data,
  1269. indices_data, indices_num));
  1270. }
  1271. template <typename T>
  1272. void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
  1273. const OrtSparseValuesParam& values,
  1274. const int64_t* inner_indices_data, size_t inner_indices_num,
  1275. const int64_t* outer_indices_data, size_t outer_indices_num) {
  1276. ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
  1277. inner_indices_data, inner_indices_num,
  1278. outer_indices_data, outer_indices_num));
  1279. }
  1280. template <typename T>
  1281. void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
  1282. const OrtSparseValuesParam& values,
  1283. const Shape& indices_shape,
  1284. const int32_t* indices_data) {
  1285. ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
  1286. indices_shape.shape, indices_shape.shape_len,
  1287. indices_data));
  1288. }
  1289. #endif // !defined(DISABLE_SPARSE_TENSORS)
  1290. } // namespace detail
  1291. template <typename T>
  1292. inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
  1293. return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
  1294. }
  1295. inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
  1296. ONNXTensorElementDataType type) {
  1297. OrtValue* out;
  1298. ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
  1299. return Value{out};
  1300. }
  1301. template <typename T>
  1302. inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
  1303. return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
  1304. }
  1305. inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
  1306. OrtValue* out;
  1307. ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
  1308. return Value{out};
  1309. }
  1310. #if !defined(DISABLE_SPARSE_TENSORS)
  1311. template <typename T>
  1312. inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
  1313. const Shape& values_shape) {
  1314. return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
  1315. }
  1316. inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
  1317. const Shape& values_shape, ONNXTensorElementDataType type) {
  1318. OrtValue* out;
  1319. ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
  1320. values_shape.shape, values_shape.shape_len, type, &out));
  1321. return Value{out};
  1322. }
  1323. template <typename T>
  1324. inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
  1325. return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
  1326. }
  1327. inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
  1328. ONNXTensorElementDataType type) {
  1329. OrtValue* out;
  1330. ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
  1331. return Value{out};
  1332. }
  1333. #endif // !defined(DISABLE_SPARSE_TENSORS)
  1334. inline Value Value::CreateMap(const Value& keys, const Value& values) {
  1335. OrtValue* out;
  1336. const OrtValue* inputs[2] = {keys, values};
  1337. ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
  1338. return Value{out};
  1339. }
  1340. inline Value Value::CreateSequence(const std::vector<Value>& values) {
  1341. OrtValue* out;
  1342. std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
  1343. ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
  1344. return Value{out};
  1345. }
  1346. template <typename T>
  1347. inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
  1348. OrtValue* out;
  1349. ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
  1350. return Value{out};
  1351. }
  1352. //
  1353. // Custom OP Inlines
  1354. //
  1355. inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
  1356. Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
  1357. }
  1358. inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
  1359. return cached_severity_level_;
  1360. }
  1361. inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
  1362. const char* func_name, const char* message) const noexcept {
  1363. OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
  1364. func_name);
  1365. return Status{status};
  1366. }
  1367. // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
  1368. // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
  1369. // __attribute__(format(printf...)), which does not work with variadic templates.
  1370. #if defined(__GNUC__)
  1371. #pragma GCC diagnostic push
  1372. #pragma GCC diagnostic ignored "-Wformat-nonliteral"
  1373. #pragma GCC diagnostic ignored "-Wformat-security"
  1374. #elif defined(__clang__)
  1375. #pragma clang diagnostic push
  1376. #pragma clang diagnostic ignored "-Wformat-nonliteral"
  1377. #pragma clang diagnostic ignored "-Wformat-security"
  1378. #endif
  1379. template <typename... Args>
  1380. inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
  1381. int line_number, const char* func_name, const char* format,
  1382. Args&&... args) const noexcept {
  1383. int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
  1384. if (msg_len < 0) { // Formatting error
  1385. return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
  1386. }
  1387. OrtStatus* status = nullptr;
  1388. const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
  1389. constexpr size_t kStackBufferSize = 1024;
  1390. if (buffer_size < kStackBufferSize) {
  1391. char buffer[kStackBufferSize];
  1392. snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
  1393. status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
  1394. } else {
  1395. // std::make_unique is only supported starting at C++14.
  1396. #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
  1397. auto buffer = std::make_unique<char[]>(buffer_size);
  1398. #else
  1399. std::unique_ptr<char[]> buffer(new char[buffer_size]);
  1400. #endif
  1401. std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
  1402. status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
  1403. }
  1404. return Status{status};
  1405. }
  1406. // Re-enable -Wformat-nonliteral and -Wformat-security
  1407. #if defined(__GNUC__)
  1408. #pragma GCC diagnostic pop
  1409. #elif defined(__clang__)
  1410. #pragma clang diagnostic pop
  1411. #endif
  1412. inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
  1413. }
  1414. inline size_t KernelContext::GetInputCount() const {
  1415. size_t out = 0;
  1416. Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
  1417. return out;
  1418. }
  1419. inline size_t KernelContext::GetOutputCount() const {
  1420. size_t out = 0;
  1421. Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
  1422. return out;
  1423. }
  1424. inline ConstValue KernelContext::GetInput(size_t index) const {
  1425. const OrtValue* out = nullptr;
  1426. Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
  1427. return ConstValue{out};
  1428. }
  1429. inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
  1430. OrtValue* out = nullptr;
  1431. Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
  1432. return UnownedValue(out);
  1433. }
  1434. inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
  1435. OrtValue* out = nullptr;
  1436. Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
  1437. return UnownedValue(out);
  1438. }
  1439. inline void* KernelContext::GetGPUComputeStream() const {
  1440. void* out = nullptr;
  1441. Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
  1442. return out;
  1443. }
  1444. inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
  1445. OrtAllocator* out = nullptr;
  1446. Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
  1447. return out;
  1448. }
  1449. inline Logger KernelContext::GetLogger() const {
  1450. const OrtLogger* out = nullptr;
  1451. ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
  1452. return Logger{out};
  1453. }
  1454. inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
  1455. ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
  1456. }
  1457. inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
  1458. Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
  1459. }
  1460. namespace detail {
  1461. template <typename T>
  1462. inline KernelInfo KernelInfoImpl<T>::Copy() const {
  1463. OrtKernelInfo* info_copy = nullptr;
  1464. Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
  1465. return KernelInfo{info_copy};
  1466. }
  1467. template <typename T>
  1468. inline size_t KernelInfoImpl<T>::GetInputCount() const {
  1469. size_t out = 0;
  1470. ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
  1471. return out;
  1472. }
  1473. template <typename T>
  1474. inline size_t KernelInfoImpl<T>::GetOutputCount() const {
  1475. size_t out = 0;
  1476. ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
  1477. return out;
  1478. }
  1479. template <typename T>
  1480. inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
  1481. size_t size = 0;
  1482. // Feed nullptr for the data buffer to query the true size of the string value
  1483. Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
  1484. std::string out;
  1485. out.resize(size);
  1486. Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
  1487. out.resize(size - 1); // remove the terminating character '\0'
  1488. return out;
  1489. }
  1490. template <typename T>
  1491. inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
  1492. size_t size = 0;
  1493. // Feed nullptr for the data buffer to query the true size of the string value
  1494. Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
  1495. std::string out;
  1496. out.resize(size);
  1497. Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
  1498. out.resize(size - 1); // remove the terminating character '\0'
  1499. return out;
  1500. }
  1501. template <typename T>
  1502. inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
  1503. OrtTypeInfo* out = nullptr;
  1504. ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
  1505. return TypeInfo{out};
  1506. }
  1507. template <typename T>
  1508. inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
  1509. OrtTypeInfo* out = nullptr;
  1510. ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
  1511. return TypeInfo{out};
  1512. }
  1513. template <typename T>
  1514. inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
  1515. OrtValue* out = nullptr;
  1516. ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
  1517. return Value{out};
  1518. }
  1519. template <typename T>
  1520. inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
  1521. const OrtValue* out = nullptr;
  1522. ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
  1523. return ConstValue{out};
  1524. }
  1525. template <typename T>
  1526. inline std::string KernelInfoImpl<T>::GetNodeName() const {
  1527. size_t size = 0;
  1528. // Feed nullptr for the data buffer to query the true size of the string value
  1529. Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
  1530. std::string out;
  1531. out.resize(size);
  1532. Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
  1533. out.resize(size - 1); // remove the terminating character '\0'
  1534. return out;
  1535. }
  1536. template <typename T>
  1537. inline Logger KernelInfoImpl<T>::GetLogger() const {
  1538. const OrtLogger* out = nullptr;
  1539. ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
  1540. return Logger{out};
  1541. }
  1542. inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
  1543. Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
  1544. }
  1545. inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
  1546. Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
  1547. }
  1548. inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
  1549. size_t size = 0;
  1550. // Feed nullptr for the data buffer to query the true size of the string attribute
  1551. Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
  1552. std::string out;
  1553. out.resize(size);
  1554. Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
  1555. out.resize(size - 1); // remove the terminating character '\0'
  1556. out.swap(result);
  1557. }
  1558. inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
  1559. size_t size = 0;
  1560. // Feed nullptr for the data buffer to query the true size of the attribute
  1561. Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
  1562. std::vector<float> out;
  1563. out.resize(size);
  1564. Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
  1565. out.swap(result);
  1566. }
  1567. inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
  1568. size_t size = 0;
  1569. // Feed nullptr for the data buffer to query the true size of the attribute
  1570. Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
  1571. std::vector<int64_t> out;
  1572. out.resize(size);
  1573. Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
  1574. out.swap(result);
  1575. }
  1576. } // namespace detail
  1577. inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
  1578. inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
  1579. inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
  1580. const char** type_constraint_names,
  1581. const ONNXTensorElementDataType* type_constraint_values,
  1582. size_t type_constraint_count,
  1583. const OpAttr* attr_values, size_t attr_count,
  1584. size_t input_count, size_t output_count) {
  1585. static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
  1586. "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
  1587. auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
  1588. OrtOp* op;
  1589. Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
  1590. static_cast<int>(type_constraint_count),
  1591. attr_input_values,
  1592. static_cast<int>(attr_count),
  1593. static_cast<int>(input_count),
  1594. static_cast<int>(output_count), &op));
  1595. return Op{op};
  1596. }
  1597. inline void Op::Invoke(const OrtKernelContext* context,
  1598. const Value* input_values,
  1599. size_t input_count,
  1600. Value* output_values,
  1601. size_t output_count) {
  1602. static_assert(sizeof(Value) == sizeof(OrtValue*),
  1603. "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
  1604. auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  1605. auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
  1606. Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
  1607. ort_output_values, static_cast<int>(output_count)));
  1608. }
  1609. inline void Op::Invoke(const OrtKernelContext* context,
  1610. const OrtValue* const* input_values,
  1611. size_t input_count,
  1612. OrtValue* const* output_values,
  1613. size_t output_count) {
  1614. Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
  1615. output_values, static_cast<int>(output_count)));
  1616. }
  1617. inline std::string GetVersionString() {
  1618. return OrtGetApiBase()->GetVersionString();
  1619. }
  1620. inline std::string GetBuildInfoString() {
  1621. return GetApi().GetBuildInfoString();
  1622. }
  1623. inline std::vector<std::string> GetAvailableProviders() {
  1624. char** providers;
  1625. int len;
  1626. auto release_fn = [&len](char** providers) {
  1627. // This should always return nullptr.
  1628. ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
  1629. };
  1630. ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
  1631. std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
  1632. std::vector<std::string> available_providers;
  1633. available_providers.reserve(static_cast<size_t>(len));
  1634. for (int i = 0; i < len; ++i) {
  1635. available_providers.emplace_back(providers[i]);
  1636. }
  1637. return available_providers;
  1638. }
  1639. template <typename TOp, typename TKernel, bool WithStatus>
  1640. void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
  1641. ConstSessionOptions options) const {
  1642. const TOp* derived = static_cast<const TOp*>(this);
  1643. std::vector<std::string> keys = derived->GetSessionConfigKeys();
  1644. out.reserve(keys.size());
  1645. std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
  1646. const size_t prefix_size = config_entry_key.length();
  1647. for (const auto& key : keys) {
  1648. config_entry_key.resize(prefix_size);
  1649. config_entry_key.append(key);
  1650. out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
  1651. }
  1652. }
  1653. inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
  1654. OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
  1655. size_t input_count = 0;
  1656. Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
  1657. for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
  1658. OrtTensorTypeAndShapeInfo* info{};
  1659. Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
  1660. TensorTypeAndShapeInfo type_shape_info(info);
  1661. auto integer_shape = type_shape_info.GetShape();
  1662. std::vector<const char*> symbolic_shape(integer_shape.size(), {});
  1663. type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
  1664. Shape shape;
  1665. for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
  1666. if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
  1667. shape.emplace_back(symbolic_shape[ith]);
  1668. } else {
  1669. shape.emplace_back(integer_shape[ith]);
  1670. }
  1671. }
  1672. input_shapes_.push_back(std::move(shape));
  1673. type_shape_info.release();
  1674. }
  1675. }
  1676. inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) {
  1677. OrtTensorTypeAndShapeInfo* info = {};
  1678. ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
  1679. using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
  1680. InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
  1681. ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
  1682. });
  1683. std::vector<int64_t> integer_dims;
  1684. std::vector<const char*> symbolic_dims;
  1685. for (const auto dim : shape) {
  1686. if (dim.IsInt()) {
  1687. integer_dims.push_back(dim.IsInt());
  1688. symbolic_dims.push_back("");
  1689. } else {
  1690. if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
  1691. ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
  1692. }
  1693. integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
  1694. symbolic_dims.push_back(dim.AsSym());
  1695. }
  1696. }
  1697. ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
  1698. ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
  1699. ORT_CXX_RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
  1700. return Status{nullptr};
  1701. }
  1702. inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
  1703. const auto* attr = GetAttrHdl(attr_name);
  1704. int64_t i = {};
  1705. size_t out = {};
  1706. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
  1707. return i;
  1708. }
  1709. inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
  1710. const auto* attr = GetAttrHdl(attr_name);
  1711. int64_t i = {};
  1712. size_t out = {};
  1713. // first call to get the bytes needed
  1714. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
  1715. if (status) {
  1716. size_t num_i = out / sizeof(int64_t);
  1717. ShapeInferContext::Ints ints(num_i, 0);
  1718. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
  1719. return ints;
  1720. } else {
  1721. return {i};
  1722. }
  1723. }
  1724. inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
  1725. const auto* attr = GetAttrHdl(attr_name);
  1726. float f = {};
  1727. size_t out = {};
  1728. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
  1729. return f;
  1730. }
  1731. inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
  1732. const auto* attr = GetAttrHdl(attr_name);
  1733. float f = {};
  1734. size_t out = {};
  1735. // first call to get the bytes needed
  1736. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
  1737. if (status) {
  1738. size_t num_f = out / sizeof(float);
  1739. ShapeInferContext::Floats floats(num_f, 0);
  1740. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
  1741. return floats;
  1742. } else {
  1743. return {f};
  1744. }
  1745. }
  1746. inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
  1747. const auto* attr = GetAttrHdl(attr_name);
  1748. char c = {};
  1749. size_t out = {};
  1750. // first call to get the bytes needed
  1751. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
  1752. if (status) {
  1753. std::vector<char> chars(out, '\0');
  1754. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
  1755. return {chars.data()};
  1756. } else {
  1757. return {c};
  1758. }
  1759. }
  1760. inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
  1761. const auto* attr = GetAttrHdl(attr_name);
  1762. char c = {};
  1763. size_t out = {};
  1764. // first call to get the bytes needed
  1765. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
  1766. if (status) {
  1767. std::vector<char> chars(out, '\0');
  1768. Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
  1769. ShapeInferContext::Strings strings;
  1770. char* char_st = chars.data();
  1771. char* char_ed = char_st + out;
  1772. while (char_st < char_ed) {
  1773. strings.emplace_back(char_st);
  1774. while (*char_st != '\0') {
  1775. char_st++;
  1776. }
  1777. char_st++;
  1778. }
  1779. return strings;
  1780. } else {
  1781. return {std::string{c}};
  1782. }
  1783. }
  1784. inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
  1785. const OrtOpAttr* attr_hdl = {};
  1786. Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
  1787. return attr_hdl;
  1788. }
  1789. } // namespace Ort