onnxruntime_cxx_api.h 106 KB


  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
  4. //
  5. // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
  6. // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
  7. // all the resources follow RAII and do not leak memory.
  8. //
  9. // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
  10. // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
  11. // until you assign an instance that actually holds an underlying object.
  12. //
  13. // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
  14. // Some objects have explicit 'Clone' methods for this purpose.
  15. //
  16. // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
  17. // by value or by reference. ConstXXXX types are restricted to const only interfaces.
  18. //
  19. // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
  20. //
  21. // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
  22. // have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
  23. #pragma once
  #include "onnxruntime_c_api.h"
  #include "onnxruntime_float16.h"
  #include <array>
  #include <cstddef>
  #include <cstdio>
  #include <memory>
  #include <stdexcept>
  #include <string>
  #include <type_traits>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #ifdef ORT_NO_EXCEPTIONS
  #include <cstdlib>
  #include <iostream>
  #endif
  39. /** \brief All C++ Onnxruntime APIs are defined inside this namespace
  40. *
  41. */
  42. namespace Ort {
  43. /** \brief All C++ methods that can fail will throw an exception of this type
  44. *
  45. * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
  46. */
  47. struct Exception : std::exception {
  48. Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
  49. OrtErrorCode GetOrtErrorCode() const { return code_; }
  50. const char* what() const noexcept override { return message_.c_str(); }
  51. private:
  52. std::string message_;
  53. OrtErrorCode code_;
  54. };
  // ORT_CXX_API_THROW(string, code): reports an error raised by the C++ API.
  //  - With exceptions enabled it throws Ort::Exception(string, code).
  //  - With ORT_NO_EXCEPTIONS defined it prints the message to std::cerr and terminates via std::abort().
  #ifdef ORT_NO_EXCEPTIONS
  // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
  // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
  #ifndef ORT_CXX_API_THROW
  #define ORT_CXX_API_THROW(string, code)       \
    do {                                        \
      std::cerr << Ort::Exception(string, code) \
                       .what()                  \
                << std::endl;                   \
      std::abort();                             \
    } while (false)
  #endif
  #else
  #define ORT_CXX_API_THROW(string, code) \
    throw Ort::Exception(string, code)
  #endif
  71. // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
  72. // it's in a template so that we can define a global variable in a header and make
  73. // it transparent to the users of the API.
  74. template <typename T>
  75. struct Global {
  76. static const OrtApi* api_;
  77. };
  78. // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
  79. template <typename T>
  80. #ifdef ORT_API_MANUAL_INIT
  81. const OrtApi* Global<T>::api_{};
  82. inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
  83. // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
  84. // required by C++ APIs.
  85. //
  86. // Example mycustomop.cc:
  87. //
  88. // #define ORT_API_MANUAL_INIT
  89. // #include <onnxruntime_cxx_api.h>
  90. // #undef ORT_API_MANUAL_INIT
  91. //
  92. // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
  93. // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
  94. // // ...
  95. // }
  96. //
  97. inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
  98. #else
  99. #if defined(_MSC_VER) && !defined(__clang__)
  100. #pragma warning(push)
  101. // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
  102. // Please define ORT_API_MANUAL_INIT if it conerns you.
  103. #pragma warning(disable : 26426)
  104. #endif
  105. const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  106. #if defined(_MSC_VER) && !defined(__clang__)
  107. #pragma warning(pop)
  108. #endif
  109. #endif
  110. /// This returns a reference to the OrtApi interface in use
  111. inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
  /** \brief Returns the onnxruntime version string
   *
   * \return version string major.minor.rev
   */
  std::string GetVersionString();
  /** \brief Returns the onnxruntime build information
   *
   * The information includes git branch, git commit id, build type (Debug/Release/RelWithDebInfo)
   * and cmake cpp flags.
   *
   * \return string
   */
  std::string GetBuildInfoString();
  /** \brief C++ wrapper for OrtApi::GetAvailableProviders()
   *
   * \return vector of strings representing the available execution providers
   */
  std::vector<std::string> GetAvailableProviders();
  129. /** \brief IEEE 754 half-precision floating point data type
  130. *
  131. * \details This struct is used for converting float to float16 and back
  132. * so the user could feed inputs and fetch outputs using these type.
  133. *
  134. * The size of the structure should align with uint16_t and one can freely cast
  135. * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
  136. *
  137. * \code{.unparsed}
  138. * // This example demonstrates converion from float to float16
  139. * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
  140. * std::vector<Ort::Float16_t> fp16_values;
  141. * fp16_values.reserve(std::size(values));
  142. * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
  143. * [](float value) { return Ort::Float16_t(value); });
  144. *
  145. * \endcode
  146. */
  147. struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
  148. private:
  149. /// <summary>
  150. /// Constructor from a 16-bit representation of a float16 value
  151. /// No conversion is done here.
  152. /// </summary>
  153. /// <param name="v">16-bit representation</param>
  154. constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
  155. public:
  156. using Base = onnxruntime_float16::Float16Impl<Float16_t>;
  157. /// <summary>
  158. /// Default constructor
  159. /// </summary>
  160. Float16_t() = default;
  161. /// <summary>
  162. /// Explicit conversion to uint16_t representation of float16.
  163. /// </summary>
  164. /// <param name="v">uint16_t bit representation of float16</param>
  165. /// <returns>new instance of Float16_t</returns>
  166. constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
  167. /// <summary>
  168. /// __ctor from float. Float is converted into float16 16-bit representation.
  169. /// </summary>
  170. /// <param name="v">float value</param>
  171. explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
  172. /// <summary>
  173. /// Converts float16 to float
  174. /// </summary>
  175. /// <returns>float representation of float16 value</returns>
  176. float ToFloat() const noexcept { return Base::ToFloatImpl(); }
  177. /// <summary>
  178. /// Checks if the value is negative
  179. /// </summary>
  180. /// <returns>true if negative</returns>
  181. using Base::IsNegative;
  182. /// <summary>
  183. /// Tests if the value is NaN
  184. /// </summary>
  185. /// <returns>true if NaN</returns>
  186. using Base::IsNaN;
  187. /// <summary>
  188. /// Tests if the value is finite
  189. /// </summary>
  190. /// <returns>true if finite</returns>
  191. using Base::IsFinite;
  192. /// <summary>
  193. /// Tests if the value represents positive infinity.
  194. /// </summary>
  195. /// <returns>true if positive infinity</returns>
  196. using Base::IsPositiveInfinity;
  197. /// <summary>
  198. /// Tests if the value represents negative infinity
  199. /// </summary>
  200. /// <returns>true if negative infinity</returns>
  201. using Base::IsNegativeInfinity;
  202. /// <summary>
  203. /// Tests if the value is either positive or negative infinity.
  204. /// </summary>
  205. /// <returns>True if absolute value is infinity</returns>
  206. using Base::IsInfinity;
  207. /// <summary>
  208. /// Tests if the value is NaN or zero. Useful for comparisons.
  209. /// </summary>
  210. /// <returns>True if NaN or zero.</returns>
  211. using Base::IsNaNOrZero;
  212. /// <summary>
  213. /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  214. /// </summary>
  215. /// <returns>True if so</returns>
  216. using Base::IsNormal;
  217. /// <summary>
  218. /// Tests if the value is subnormal (denormal).
  219. /// </summary>
  220. /// <returns>True if so</returns>
  221. using Base::IsSubnormal;
  222. /// <summary>
  223. /// Creates an instance that represents absolute value.
  224. /// </summary>
  225. /// <returns>Absolute value</returns>
  226. using Base::Abs;
  227. /// <summary>
  228. /// Creates a new instance with the sign flipped.
  229. /// </summary>
  230. /// <returns>Flipped sign instance</returns>
  231. using Base::Negate;
  232. /// <summary>
  233. /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  234. /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  235. /// and therefore equivalent, if the resulting value is still zero.
  236. /// </summary>
  237. /// <param name="lhs">first value</param>
  238. /// <param name="rhs">second value</param>
  239. /// <returns>True if both arguments represent zero</returns>
  240. using Base::AreZero;
  241. /// <summary>
  242. /// User defined conversion operator. Converts Float16_t to float.
  243. /// </summary>
  244. explicit operator float() const noexcept { return ToFloat(); }
  245. using Base::operator==;
  246. using Base::operator!=;
  247. using Base::operator<;
  248. };
  249. static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
  250. /** \brief bfloat16 (Brain Floating Point) data type
  251. *
  252. * \details This struct is used for converting float to bfloat16 and back
  253. * so the user could feed inputs and fetch outputs using these type.
  254. *
  255. * The size of the structure should align with uint16_t and one can freely cast
  256. * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
  257. *
  258. * \code{.unparsed}
  259. * // This example demonstrates converion from float to float16
  260. * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
  261. * std::vector<Ort::BFloat16_t> bfp16_values;
  262. * bfp16_values.reserve(std::size(values));
  263. * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
  264. * [](float value) { return Ort::BFloat16_t(value); });
  265. *
  266. * \endcode
  267. */
  268. struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
  269. private:
  270. /// <summary>
  271. /// Constructor from a uint16_t representation of bfloat16
  272. /// used in FromBits() to escape overload resolution issue with
  273. /// constructor from float.
  274. /// No conversion is done.
  275. /// </summary>
  276. /// <param name="v">16-bit bfloat16 value</param>
  277. constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
  278. public:
  279. using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
  280. BFloat16_t() = default;
  281. /// <summary>
  282. /// Explicit conversion to uint16_t representation of bfloat16.
  283. /// </summary>
  284. /// <param name="v">uint16_t bit representation of bfloat16</param>
  285. /// <returns>new instance of BFloat16_t</returns>
  286. static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
  287. /// <summary>
  288. /// __ctor from float. Float is converted into bfloat16 16-bit representation.
  289. /// </summary>
  290. /// <param name="v">float value</param>
  291. explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
  292. /// <summary>
  293. /// Converts bfloat16 to float
  294. /// </summary>
  295. /// <returns>float representation of bfloat16 value</returns>
  296. float ToFloat() const noexcept { return Base::ToFloatImpl(); }
  297. /// <summary>
  298. /// Checks if the value is negative
  299. /// </summary>
  300. /// <returns>true if negative</returns>
  301. using Base::IsNegative;
  302. /// <summary>
  303. /// Tests if the value is NaN
  304. /// </summary>
  305. /// <returns>true if NaN</returns>
  306. using Base::IsNaN;
  307. /// <summary>
  308. /// Tests if the value is finite
  309. /// </summary>
  310. /// <returns>true if finite</returns>
  311. using Base::IsFinite;
  312. /// <summary>
  313. /// Tests if the value represents positive infinity.
  314. /// </summary>
  315. /// <returns>true if positive infinity</returns>
  316. using Base::IsPositiveInfinity;
  317. /// <summary>
  318. /// Tests if the value represents negative infinity
  319. /// </summary>
  320. /// <returns>true if negative infinity</returns>
  321. using Base::IsNegativeInfinity;
  322. /// <summary>
  323. /// Tests if the value is either positive or negative infinity.
  324. /// </summary>
  325. /// <returns>True if absolute value is infinity</returns>
  326. using Base::IsInfinity;
  327. /// <summary>
  328. /// Tests if the value is NaN or zero. Useful for comparisons.
  329. /// </summary>
  330. /// <returns>True if NaN or zero.</returns>
  331. using Base::IsNaNOrZero;
  332. /// <summary>
  333. /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  334. /// </summary>
  335. /// <returns>True if so</returns>
  336. using Base::IsNormal;
  337. /// <summary>
  338. /// Tests if the value is subnormal (denormal).
  339. /// </summary>
  340. /// <returns>True if so</returns>
  341. using Base::IsSubnormal;
  342. /// <summary>
  343. /// Creates an instance that represents absolute value.
  344. /// </summary>
  345. /// <returns>Absolute value</returns>
  346. using Base::Abs;
  347. /// <summary>
  348. /// Creates a new instance with the sign flipped.
  349. /// </summary>
  350. /// <returns>Flipped sign instance</returns>
  351. using Base::Negate;
  352. /// <summary>
  353. /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  354. /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  355. /// and therefore equivalent, if the resulting value is still zero.
  356. /// </summary>
  357. /// <param name="lhs">first value</param>
  358. /// <param name="rhs">second value</param>
  359. /// <returns>True if both arguments represent zero</returns>
  360. using Base::AreZero;
  361. /// <summary>
  362. /// User defined conversion operator. Converts BFloat16_t to float.
  363. /// </summary>
  364. explicit operator float() const noexcept { return ToFloat(); }
  365. // We do not have an inherited impl for the below operators
  366. // as the internal class implements them a little differently
  367. bool operator==(const BFloat16_t& rhs) const noexcept;
  368. bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
  369. bool operator<(const BFloat16_t& rhs) const noexcept;
  370. };
  371. static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
  /** \brief float8e4m3fn (Float8 Floating Point) data type
   * \details Needed so that type dispatching can make use of the C++ API.
   * The type is implicitly convertible to/from uint8_t.
   * See https://onnx.ai/onnx/technical/float8.html for further details.
   */
  struct Float8E4M3FN_t {
    uint8_t value;

    /// Zero-initializes the stored byte.
    constexpr Float8E4M3FN_t() noexcept : value{0} {}
    /// Implicit conversion from the raw byte.
    constexpr Float8E4M3FN_t(uint8_t v) noexcept : value{v} {}
    /// Implicit conversion back to the raw byte.
    constexpr operator uint8_t() const noexcept { return value; }

    // NaN bit patterns get no special treatment in ==, !=: the stored bytes are compared as-is.
    constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return rhs.value == value; }
    constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return !(*this == rhs); }
  };

  static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
  /** \brief float8e4m3fnuz (Float8 Floating Point) data type
   * \details Needed so that type dispatching can make use of the C++ API.
   * The type is implicitly convertible to/from uint8_t.
   * See https://onnx.ai/onnx/technical/float8.html for further details.
   */
  struct Float8E4M3FNUZ_t {
    uint8_t value;

    /// Zero-initializes the stored byte.
    constexpr Float8E4M3FNUZ_t() noexcept : value{0} {}
    /// Implicit conversion from the raw byte.
    constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value{v} {}
    /// Implicit conversion back to the raw byte.
    constexpr operator uint8_t() const noexcept { return value; }

    // NaN bit patterns get no special treatment in ==, !=: the stored bytes are compared as-is.
    constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return rhs.value == value; }
    constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return !(*this == rhs); }
  };

  static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
  /** \brief float8e5m2 (Float8 Floating Point) data type
   * \details Needed so that type dispatching can make use of the C++ API.
   * The type is implicitly convertible to/from uint8_t.
   * See https://onnx.ai/onnx/technical/float8.html for further details.
   */
  struct Float8E5M2_t {
    uint8_t value;

    /// Zero-initializes the stored byte.
    constexpr Float8E5M2_t() noexcept : value{0} {}
    /// Implicit conversion from the raw byte.
    constexpr Float8E5M2_t(uint8_t v) noexcept : value{v} {}
    /// Implicit conversion back to the raw byte.
    constexpr operator uint8_t() const noexcept { return value; }

    // NaN bit patterns get no special treatment in ==, !=: the stored bytes are compared as-is.
    constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return rhs.value == value; }
    constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return !(*this == rhs); }
  };

  static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
  /** \brief float8e5m2fnuz (Float8 Floating Point) data type
   * \details Needed so that type dispatching can make use of the C++ API.
   * The type is implicitly convertible to/from uint8_t.
   * See https://onnx.ai/onnx/technical/float8.html for further details.
   */
  struct Float8E5M2FNUZ_t {
    uint8_t value;

    /// Zero-initializes the stored byte.
    constexpr Float8E5M2FNUZ_t() noexcept : value{0} {}
    /// Implicit conversion from the raw byte.
    constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value{v} {}
    /// Implicit conversion back to the raw byte.
    constexpr operator uint8_t() const noexcept { return value; }

    // NaN bit patterns get no special treatment in ==, !=: the stored bytes are compared as-is.
    constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return rhs.value == value; }
    constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return !(*this == rhs); }
  };

  static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
  432. namespace detail {
  433. // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
  434. // This can't be done in the C API since C doesn't have function overloading.
  435. #define ORT_DEFINE_RELEASE(NAME) \
  436. inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
  437. ORT_DEFINE_RELEASE(Allocator);
  438. ORT_DEFINE_RELEASE(MemoryInfo);
  439. ORT_DEFINE_RELEASE(CustomOpDomain);
  440. ORT_DEFINE_RELEASE(ThreadingOptions);
  441. ORT_DEFINE_RELEASE(Env);
  442. ORT_DEFINE_RELEASE(RunOptions);
  443. ORT_DEFINE_RELEASE(Session);
  444. ORT_DEFINE_RELEASE(SessionOptions);
  445. ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
  446. ORT_DEFINE_RELEASE(SequenceTypeInfo);
  447. ORT_DEFINE_RELEASE(MapTypeInfo);
  448. ORT_DEFINE_RELEASE(TypeInfo);
  449. ORT_DEFINE_RELEASE(Value);
  450. ORT_DEFINE_RELEASE(ModelMetadata);
  451. ORT_DEFINE_RELEASE(IoBinding);
  452. ORT_DEFINE_RELEASE(ArenaCfg);
  453. ORT_DEFINE_RELEASE(Status);
  454. ORT_DEFINE_RELEASE(OpAttr);
  455. ORT_DEFINE_RELEASE(Op);
  456. ORT_DEFINE_RELEASE(KernelInfo);
  457. #undef ORT_DEFINE_RELEASE
  /** \brief Tagging template type. Use it with Base<T> to indicate that the C++ interface object
   * has no ownership of the underlying C object.
   *
   * The wrapped type is exposed as the nested name Type.
   */
  template <typename T>
  struct Unowned {
    typedef T Type;
  };
  /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
   * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
   *
   * All of the C++ classes
   * a) serve as containers for pointers to objects that are created by the underlying C API.
   *    Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
   * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
   *    they would release objects owned automatically when going out of scope, they are move-only.
   * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
   *    ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
   *    such as Onnxruntime or instances of XXXX classes.
   * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
   *    in C++ code.
   *
   */
  /// <summary>
  /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
  /// </summary>
  template <typename T>
  struct Base {
    using contained_type = T;

    /// Creates an empty holder (null pointer).
    constexpr Base() = default;
    /// Takes the given C object pointer; it is passed to OrtRelease() when this holder is destroyed.
    constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
    /// Passes the held pointer to the matching OrtRelease() overload.
    ~Base() { OrtRelease(p_); }

    // Move-only: copying is disabled.
    Base(const Base&) = delete;
    Base& operator=(const Base&) = delete;

    /// Takes over the pointer held by v and leaves v empty.
    Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }

    /// Releases the currently held object, then takes over the pointer held by v and leaves v empty.
    Base& operator=(Base&& v) noexcept {
      // Self-assignment guard: without it the held object would be released and the same, now stale,
      // pointer stored right back, so the destructor would release it a second time.
      if (this != &v) {
        OrtRelease(p_);
        p_ = v.release();
      }
      return *this;
    }

    /// Implicit access to the contained C object pointer.
    constexpr operator contained_type*() const noexcept { return p_; }

    /// \brief Relinquishes ownership of the contained C object pointer
    /// The underlying object is not destroyed
    contained_type* release() {
      T* p = p_;
      p_ = nullptr;
      return p;
    }

   protected:
    contained_type* p_{};
  };

  // Undefined. For const types use Base<Unowned<const T>>
  template <typename T>
  struct Base<const T>;
  511. /// <summary>
  512. /// Covers unowned pointers owned by either the ORT
  513. /// or some other instance of CPP wrappers.
  514. /// Used for ConstXXX and UnownedXXXX types that are copyable.
  515. /// Also convenient to wrap raw OrtXX pointers .
  516. /// </summary>
  517. /// <typeparam name="T"></typeparam>
  518. template <typename T>
  519. struct Base<Unowned<T>> {
  520. using contained_type = typename Unowned<T>::Type;
  521. constexpr Base() = default;
  522. constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
  523. ~Base() = default;
  524. Base(const Base&) = default;
  525. Base& operator=(const Base&) = default;
  526. Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
  527. Base& operator=(Base&& v) noexcept {
  528. p_ = nullptr;
  529. std::swap(p_, v.p_);
  530. return *this;
  531. }
  532. constexpr operator contained_type*() const noexcept { return p_; }
  533. protected:
  534. contained_type* p_{};
  535. };
  536. // Light functor to release memory with OrtAllocator
  537. struct AllocatedFree {
  538. OrtAllocator* allocator_;
  539. explicit AllocatedFree(OrtAllocator* allocator)
  540. : allocator_(allocator) {}
  541. void operator()(void* ptr) const {
  542. if (ptr) allocator_->Free(allocator_, ptr);
  543. }
  544. };
  545. } // namespace detail
  546. struct AllocatorWithDefaultOptions;
  547. struct Env;
  548. struct TypeInfo;
  549. struct Value;
  550. struct ModelMetadata;
  551. /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
  552. * and release them at the end of the scope. The lifespan of the given allocator
  553. * must eclipse the lifespan of AllocatedStringPtr instance
  554. */
  555. using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
  556. /** \brief The Status that holds ownership of OrtStatus received from C API
  557. * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
  558. * constructors to construct an instance of a Status object from exceptions.
  559. */
  560. struct Status : detail::Base<OrtStatus> {
  561. explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
  562. explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
  563. explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
  564. explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
  565. Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
  566. std::string GetErrorMessage() const;
  567. OrtErrorCode GetErrorCode() const;
  568. bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status.
  569. };
  570. /** \brief The ThreadingOptions
  571. *
  572. * The ThreadingOptions used for set global threadpools' options of The Env.
  573. */
  574. struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
  575. /// \brief Wraps OrtApi::CreateThreadingOptions
  576. ThreadingOptions();
  577. /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
  578. ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
  579. /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
  580. ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
  581. /// \brief Wraps OrtApi::SetGlobalSpinControl
  582. ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
  583. /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
  584. ThreadingOptions& SetGlobalDenormalAsZero();
  585. /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
  586. ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
  587. /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
  588. ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
  589. /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
  590. ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
  591. };
  592. /** \brief The Env (Environment)
  593. *
  594. * The Env holds the logging state used by all other objects.
  595. * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
  596. */
  597. struct Env : detail::Base<OrtEnv> {
  598. explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
  599. /// \brief Wraps OrtApi::CreateEnv
  600. Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
  601. /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
  602. Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
  603. /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
  604. Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
  605. /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
  606. Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
  607. OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
  608. /// \brief C Interop Helper
  609. explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
  610. Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
  611. Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
  612. Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
  613. Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
  614. Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
  615. };
  616. /** \brief Custom Op Domain
  617. *
  618. */
  619. struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
  620. explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
  621. /// \brief Wraps OrtApi::CreateCustomOpDomain
  622. explicit CustomOpDomain(const char* domain);
  623. // This does not take ownership of the op, simply registers it.
  624. void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
  625. };
  626. /** \brief RunOptions: wrapper around ::OrtRunOptions, passed to Session::Run calls
  627. * Each setter returns a RunOptions reference (presumably *this, to allow call chaining -- confirm against the inline definitions).
  628. */
  629. struct RunOptions : detail::Base<OrtRunOptions> {
  630. explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
  631. RunOptions(); ///< Wraps OrtApi::CreateRunOptions
  632. RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
  633. int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
  634. RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
  635. int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
  636. RunOptions& SetRunTag(const char* run_tag); ///< Wraps OrtApi::RunOptionsSetRunTag
  637. const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
  638. RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
  639. /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
  640. *
  641. * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error.
  642. * Wraps OrtApi::RunOptionsSetTerminate
  643. */
  644. RunOptions& SetTerminate();
  645. /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
  646. *
  647. * Wraps OrtApi::RunOptionsUnsetTerminate
  648. */
  649. RunOptions& UnsetTerminate();
  650. };
  651. namespace detail {
  652. // Utility function that returns a SessionOptions config entry key for a specific custom operator.
  653. // Documented key format: custom_op.[custom_op_name].[config]
  654. std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); // declaration only; the definition is not in this part of the header
  655. } // namespace detail
  656. /// <summary>
  657. /// Class that represents session configuration entries for one or more custom operators.
  658. /// Unlike the Ort wrapper types, it holds a plain std::unordered_map and no C API object.
  659. /// Example:
  660. /// Ort::CustomOpConfigs op_configs;
  661. /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
  662. ///
  663. /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
  664. /// </summary>
  665. struct CustomOpConfigs {
  666. CustomOpConfigs() = default; // all special members are defaulted: default-constructible, copyable and movable
  667. ~CustomOpConfigs() = default;
  668. CustomOpConfigs(const CustomOpConfigs&) = default;
  669. CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
  670. CustomOpConfigs(CustomOpConfigs&& o) = default;
  671. CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
  672. /** \brief Adds a session configuration entry/value for a specific custom operator.
  673. *
  674. * \param custom_op_name The name of the custom operator for which to add a configuration entry.
  675. * Must match the name returned by the CustomOp's GetName() method.
  676. * \param config_key The name of the configuration entry.
  677. * \param config_value The value of the configuration entry.
  678. * \return A reference to this object to enable call chaining.
  679. */
  680. CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
  681. /** \brief Returns a flattened map of custom operator configuration entries and their values.
  682. *
  683. * The keys have been flattened to include both the custom operator name and the configuration entry key name.
  684. * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
  685. * {"my_op.key", "value"}.
  686. * NOTE(review): detail::MakeCustomOpConfigEntryKey documents keys with a "custom_op." prefix; confirm where that prefix is applied.
  687. * \return An unordered map of flattened configurations.
  688. */
  689. const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
  690. private:
  691. std::unordered_map<std::string, std::string> flat_configs_; ///< Flattened key/value entries (presumably what GetFlattenedConfigs() returns -- confirm)
  692. };
  693. /** \brief Options object used when creating a new Session object
  694. * Forward declaration only: detail::ConstSessionOptionsImpl::Clone() names SessionOptions as its return type before the struct is defined.
  695. * Wraps ::OrtSessionOptions object and methods
  696. */
  697. struct SessionOptions;
  698. namespace detail {
  699. // Read-only part of the session options API. We separate const-only methods because passing a const ptr
  700. // to non-const methods is only discovered when inline methods are compiled, which is counter-intuitive.
  701. template <typename T>
  702. struct ConstSessionOptionsImpl : Base<T> {
  703. using B = Base<T>;
  704. using B::B; // inherit the Base<T> constructors
  705. SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
  706. std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
  707. bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
  708. std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); ///< Presumably returns \p def when \p config_key is not set. NOTE(review): unlike the other accessors this is not const-qualified, so it cannot be called on a const object -- confirm whether that is intended
  709. };
  710. template <typename T>
  711. struct SessionOptionsImpl : ConstSessionOptionsImpl<T> { // mutating session options API, layered on top of the const-only interface
  712. using B = ConstSessionOptionsImpl<T>;
  713. using B::B; // inherit the base class constructors
  714. SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
  715. SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
  716. SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
  717. SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
  718. SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
  719. SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
  720. SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
  721. SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
  722. SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
  723. SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
  724. SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
  725. SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
  726. SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
  727. SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
  728. SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
  729. SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
  730. SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
  731. SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
  732. SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
  733. SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
  734. SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
  735. const std::vector<char*>& external_initializer_file_buffer_array,
  736. const std::vector<size_t>& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory
  737. SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
  738. SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
  739. SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
  740. SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
  741. /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
  742. SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
  743. SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
  744. SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
  745. SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
  746. /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
  747. SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
  748. /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
  749. SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
  750. /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
  751. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
  752. const std::unordered_map<std::string, std::string>& provider_options = {});
  753. SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
  754. SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
  755. SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
  756. /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
  757. /// The custom operator configurations are optional. If provided, custom operator configs are set via
  758. /// OrtApi::AddSessionConfigEntry.
  759. SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
  760. SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
  761. /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI
  762. SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
  763. };
  764. } // namespace detail
  765. using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>; ///< Non-owning holder of an ::OrtSessionOptions that also allows the non-const interfaces
  766. using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>; ///< Non-owning holder of a const ::OrtSessionOptions, restricted to the const-only interfaces
  767. /** \brief Wrapper around ::OrtSessionOptions
  768. * Adds constructors and non-owning accessors on top of the interface inherited from detail::SessionOptionsImpl.
  769. */
  770. struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
  771. explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
  772. SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
  773. explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
  774. UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ///< Returns a non-owning holder built from the same underlying pointer; this object must outlive it
  775. ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  776. };
  777. /** \brief Wrapper around ::OrtModelMetadata
  778. * The string getters return copies allocated with the OrtAllocator supplied by the caller.
  779. */
  780. struct ModelMetadata : detail::Base<OrtModelMetadata> {
  781. explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
  782. explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
  783. /** \brief Returns a copy of the producer name.
  784. *
  785. * \param allocator to allocate memory for the copy of the name returned
  786. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  787. * The OrtAllocator instance must be valid at the point of memory release.
  788. */
  789. AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
  790. /** \brief Returns a copy of the graph name.
  791. *
  792. * \param allocator to allocate memory for the copy of the name returned
  793. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  794. * The OrtAllocator instance must be valid at the point of memory release.
  795. */
  796. AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
  797. /** \brief Returns a copy of the domain name.
  798. *
  799. * \param allocator to allocate memory for the copy of the name returned
  800. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  801. * The OrtAllocator instance must be valid at the point of memory release.
  802. */
  803. AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
  804. /** \brief Returns a copy of the description.
  805. *
  806. * \param allocator to allocate memory for the copy of the string returned
  807. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  808. * The OrtAllocator instance must be valid at the point of memory release.
  809. */
  810. AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
  811. /** \brief Returns a copy of the graph description.
  812. *
  813. * \param allocator to allocate memory for the copy of the string returned
  814. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  815. * The OrtAllocator instance must be valid at the point of memory release.
  816. */
  817. AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
  818. /** \brief Returns a vector of copies of the custom metadata keys.
  819. *
  820. * \param allocator to allocate memory for the copies of the strings returned
  821. * \return a std::vector of smart pointers that deallocate the buffers when they go out of scope.
  822. * The OrtAllocator instance must be valid at the point of memory release.
  823. */
  824. std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
  825. /** \brief Looks up a value by a key in the Custom Metadata map
  826. *
  827. * \param key zero terminated string key to lookup
  828. * \param allocator to allocate memory for the copy of the string returned
  829. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  830. * May be nullptr if the key is not found.
  831. *
  832. * The OrtAllocator instance must be valid at the point of memory release.
  833. */
  834. AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
  835. int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
  836. };
  837. struct IoBinding; // forward declaration; detail::SessionImpl::Run(const RunOptions&, const IoBinding&) refers to it
  838. namespace detail {
  839. // Read-only part of the session API. We separate const-only methods because passing a const ptr
  840. // to non-const methods is only discovered when inline methods are compiled, which is counter-intuitive.
  841. template <typename T>
  842. struct ConstSessionImpl : Base<T> {
  843. using B = Base<T>;
  844. using B::B; // inherit the Base<T> constructors
  845. size_t GetInputCount() const; ///< Returns the number of model inputs
  846. size_t GetOutputCount() const; ///< Returns the number of model outputs
  847. size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
  848. /** \brief Returns a copy of the input name at the specified index.
  849. *
  850. * \param index must be less than the value returned by GetInputCount()
  851. * \param allocator to allocate memory for the copy of the name returned
  852. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  853. * The OrtAllocator instance must be valid at the point of memory release.
  854. */
  855. AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
  856. /** \brief Returns a copy of the output name at the specified index.
  857. *
  858. * \param index must be less than the value returned by GetOutputCount()
  859. * \param allocator to allocate memory for the copy of the name returned
  860. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  861. * The OrtAllocator instance must be valid at the point of memory release.
  862. */
  863. AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
  864. /** \brief Returns a copy of the overridable initializer name at the specified index.
  865. *
  866. * \param index must be less than the value returned by GetOverridableInitializerCount()
  867. * \param allocator to allocate memory for the copy of the name returned
  868. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  869. * The OrtAllocator instance must be valid at the point of memory release.
  870. */
  871. AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
  872. uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
  873. ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
  874. TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
  875. TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
  876. TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
  877. };
  878. template <typename T>
  879. struct SessionImpl : ConstSessionImpl<T> { // non-const session API (Run, RunAsync, EndProfilingAllocated), layered on top of the const-only interface
  880. using B = ConstSessionImpl<T>;
  881. using B::B; // inherit the base class constructors
  882. /** \brief Run the model returning results in an Ort allocated vector.
  883. *
  884. * Wraps OrtApi::Run
  885. *
  886. * The caller provides a list of inputs and a list of the desired outputs to return.
  887. *
  888. * See the output logs for more information on warnings/errors that occur while processing the model.
  889. * Common errors are.. (TODO)
  890. *
  891. * \param[in] run_options
  892. * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
  893. * \param[in] input_values Array of Value objects of length input_count that is the list of input values
  894. * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
  895. * \param[in] output_names Array of C style strings of length output_count that is the list of output names
  896. * \param[in] output_count Number of outputs (the size of the output_names array)
  897. * \return A std::vector of Value objects that directly maps to the output_names array (e.g. output_names[0] corresponds to the first entry of the returned vector)
  898. */
  899. std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  900. const char* const* output_names, size_t output_count);
  901. /** \brief Run the model returning results in user provided outputs
  902. * Same as Run(const RunOptions&, const char* const*, const Value*, size_t, const char* const*, size_t), except that the outputs go into the caller supplied output_values array of length output_count
  903. */
  904. void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  905. const char* const* output_names, Value* output_values, size_t output_count);
  906. void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
  907. /** \brief Run the model asynchronously in a thread owned by intra op thread pool
  908. *
  909. * Wraps OrtApi::RunAsync
  910. *
  911. * \param[in] run_options
  912. * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
  913. * \param[in] input_values Array of Value objects of length input_count
  914. * \param[in] input_count Number of elements in the input_names and inputs arrays
  915. * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
  916. * \param[out] output_values Array of provided Values to be filled with outputs.
  917. * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
  918. * Later, on invoking the callback, each output_values[i] that is null will be filled with an OrtValue* allocated by onnxruntime.
  919. * Then, an OrtValue** pointer will be cast from output_values and passed to the callback.
  920. * NOTE: it is the customer's duty to finally release output_values and each of its members,
  921. * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
  922. * \param[in] output_count Number of elements in the output_names and outputs array
  923. * \param[in] callback Callback function on model run completion
  924. * \param[in] user_data User data that is passed back to the callback
  925. */
  926. void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
  927. const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
  928. /** \brief End profiling and return a copy of the profiling file name.
  929. *
  930. * \param allocator to allocate memory for the copy of the string returned
  931. * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
  932. * The OrtAllocator instance must be valid at the point of memory release.
  933. */
  934. AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
  935. };
  936. } // namespace detail
  937. using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>; ///< Non-owning holder of a const ::OrtSession, restricted to the const-only interfaces
  938. using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>; ///< Non-owning holder of an ::OrtSession that also allows the non-const interfaces
  939. /** \brief Wrapper around ::OrtSession
  940. * A session can be created from a model file path or from an in-memory model buffer, each optionally with a prepacked weights container.
  941. */
  942. struct Session : detail::SessionImpl<OrtSession> {
  943. explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
  944. Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
  945. Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
  946. OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
  947. Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
  948. Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
  949. OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
  950. ConstSession GetConst() const { return ConstSession{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  951. UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } ///< Returns a non-owning holder built from the same underlying pointer; this object must outlive it
  952. };
  953. namespace detail {
  954. template <typename T>
  955. struct MemoryInfoImpl : Base<T> { // const accessors shared by MemoryInfo and ConstMemoryInfo
  956. using B = Base<T>;
  957. using B::B; // inherit the Base<T> constructors
  958. std::string GetAllocatorName() const; ///< Returns the allocator name as a std::string copy
  959. OrtAllocatorType GetAllocatorType() const;
  960. int GetDeviceId() const;
  961. OrtMemoryInfoDeviceType GetDeviceType() const;
  962. OrtMemType GetMemoryType() const;
  963. template <typename U>
  964. bool operator==(const MemoryInfoImpl<U>& o) const; ///< Templated on U so that differently instantiated holders (e.g. MemoryInfo and ConstMemoryInfo) can be compared
  965. };
  966. } // namespace detail
  967. // Const object holder that does not own the underlying ::OrtMemoryInfo; the owner must outlive it
  968. using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
  969. /** \brief Wrapper around ::OrtMemoryInfo
  970. * The accessors are inherited from detail::MemoryInfoImpl; this struct adds the ways to create or adopt an instance.
  971. */
  972. struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
  973. static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); ///< Factory for a CPU MemoryInfo (presumably wraps OrtApi::CreateCpuMemoryInfo -- confirm against the inline definition)
  974. explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
  975. explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
  976. MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ///< Creates a new instance from the given properties (presumably wraps OrtApi::CreateMemoryInfo -- confirm)
  977. ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  978. };
  979. namespace detail {
  980. template <typename T>
  981. struct TensorTypeAndShapeInfoImpl : Base<T> { // const accessors shared by TensorTypeAndShapeInfo and ConstTensorTypeAndShapeInfo
  982. using B = Base<T>;
  983. using B::B; // inherit the Base<T> constructors
  984. ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
  985. size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
  986. size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
  987. /** \deprecated use GetShape() returning std::vector
  988. * [[deprecated]]
  989. * This interface is unsafe to use: the caller has to supply the output buffer and its length (values, values_count) by hand.
  990. */
  991. [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
  992. void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
  993. std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
  994. };
  995. } // namespace detail
  996. using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>; ///< Non-owning holder of a const ::OrtTensorTypeAndShapeInfo
  997. /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
  998. * The accessors are inherited from detail::TensorTypeAndShapeInfoImpl.
  999. */
  1000. struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
  1001. explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
  1002. explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
  1003. ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  1004. };
  1005. namespace detail {
  1006. template <typename T>
  1007. struct SequenceTypeInfoImpl : Base<T> { // const accessors shared by SequenceTypeInfo and ConstSequenceTypeInfo
  1008. using B = Base<T>;
  1009. using B::B; // inherit the Base<T> constructors
  1010. TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
  1011. };
  1012. } // namespace detail
  1013. using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>; ///< Non-owning holder of a const ::OrtSequenceTypeInfo
  1014. /** \brief Wrapper around ::OrtSequenceTypeInfo
  1015. * The accessor is inherited from detail::SequenceTypeInfoImpl.
  1016. */
  1017. struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
  1018. explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
  1019. explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
  1020. ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  1021. };
  1022. namespace detail {
  1023. template <typename T>
  1024. struct OptionalTypeInfoImpl : Base<T> { // const accessor for optional type information; instantiated below as ConstOptionalTypeInfo
  1025. using B = Base<T>;
  1026. using B::B; // inherit the Base<T> constructors
  1027. TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
  1028. };
  1029. } // namespace detail
  1030. // Non-owning holder: the optional type info is always owned by the TypeInfo and can only be obtained from it.
  1031. using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl<detail::Unowned<const OrtOptionalTypeInfo>>;
  1032. namespace detail {
  1033. template <typename T>
  1034. struct MapTypeInfoImpl : detail::Base<T> { // const accessors shared by MapTypeInfo and ConstMapTypeInfo
  1035. using B = Base<T>;
  1036. using B::B; // inherit the Base<T> constructors
  1037. ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
  1038. TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
  1039. };
  1040. } // namespace detail
  1041. using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>; ///< Non-owning holder of a const ::OrtMapTypeInfo
  1042. /** \brief Wrapper around ::OrtMapTypeInfo
  1043. * The accessors are inherited from detail::MapTypeInfoImpl.
  1044. */
  1045. struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
  1046. explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
  1047. explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
  1048. ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  1049. };
  1050. namespace detail {
  1051. template <typename T>
  1052. struct TypeInfoImpl : detail::Base<T> { // const accessors shared by TypeInfo and ConstTypeInfo
  1053. using B = Base<T>;
  1054. using B::B; // inherit the Base<T> constructors
  1055. ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
  1056. ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
  1057. ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
  1058. ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToOptionalTypeInfo
  1059. ONNXType GetONNXType() const; ///< Returns the ONNXType of this type information
  1060. };
  1061. } // namespace detail
  1062. /// <summary>
  1063. /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
  1064. /// Provides access to the const OrtTypeInfo APIs declared in detail::TypeInfoImpl.
  1065. /// </summary>
  1066. using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
  1067. /// <summary>
  1068. /// Type information that may contain either TensorTypeAndShapeInfo or
  1069. /// the information about a contained sequence, map or optional, depending on the ONNXType.
  1070. /// </summary>
  1071. struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
  1072. explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
  1073. explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
  1074. ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } ///< Returns a const-only, non-owning holder built from the same underlying pointer; this object must outlive it
  1075. };
  1076. namespace detail {
  1077. // This structure is used to feed sparse tensor values
  1078. // information for use with the FillSparseTensor<Format>() API.
  1079. // If the data type for the sparse tensor values is numeric,
  1080. // use data.p_data; otherwise, use the data.str pointer to feed
  1081. // values. data.str is an array of const char* that are zero terminated.
  1082. // The number of strings in the array must match the shape size.
  1083. // For fully sparse tensors use shape {0} and set p_data/str
  1084. // to nullptr.
  1085. struct OrtSparseValuesParam {
  1086. const int64_t* values_shape; // shape of the values; not owned by this struct
  1087. size_t values_shape_len; // number of entries in values_shape
  1088. union { // only one member is meaningful, selected by the values' data type as described above
  1089. const void* p_data;
  1090. const char** str;
  1091. } data;
  1092. };
  1093. // Provides a way to pass a shape (pointer plus length) in a single
  1094. // argument. The struct only holds a raw pointer, so it does not own the dimension data.
  1095. struct Shape {
  1096. const int64_t* shape; // pointer to the dimension values
  1097. size_t shape_len; // number of dimensions pointed to by shape
  1098. };
  1099. template <typename T>
  1100. struct ConstValueImpl : Base<T> {
  1101. using B = Base<T>;
  1102. using B::B;
  1103. /// <summary>
  1104. /// Obtains a pointer to a user defined data for experimental purposes
  1105. /// </summary>
  1106. template <typename R>
  1107. void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
  1108. bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
  1109. bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None
  1110. size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
  1111. Value GetValue(int index, OrtAllocator* allocator) const;
  1112. /// <summary>
  1113. /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
  1114. /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
  1115. /// for allocating necessary memory and calling GetStringTensorContent().
  1116. /// </summary>
  1117. /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
  1118. size_t GetStringTensorDataLength() const;
  1119. /// <summary>
  1120. /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
  1121. /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
  1122. /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
  1123. /// strings.
  1124. ///
  1125. /// Strings are always assumed to be on CPU, no X-device copy.
  1126. /// </summary>
  1127. /// <param name="buffer">user allocated buffer</param>
  1128. /// <param name="buffer_length">length in bytes of the allocated buffer</param>
  1129. /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
  1130. /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
  1131. /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
  1132. /// for sparse tensors</param>
  1133. void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
  1134. /// <summary>
  1135. /// Returns a const typed pointer to the tensor contained data.
  1136. /// No type checking is performed, the caller must ensure the type matches the tensor type.
  1137. /// </summary>
  1138. /// <typeparam name="T"></typeparam>
  1139. /// <returns>const pointer to data, no copies made</returns>
  1140. template <typename R>
  1141. const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// <summary>
  1142. /// <summary>
  1143. /// Returns a non-typed pointer to a tensor contained data.
  1144. /// </summary>
  1145. /// <returns>const pointer to data, no copies made</returns>
  1146. const void* GetTensorRawData() const;
  1147. /// <summary>
  1148. /// The API returns type information for data contained in a tensor. For sparse
  1149. /// tensors it returns type information for contained non-zero values.
  1150. /// It returns dense shape for sparse tensors.
  1151. /// </summary>
  1152. /// <returns>TypeInfo</returns>
  1153. TypeInfo GetTypeInfo() const;
  1154. /// <summary>
  1155. /// The API returns type information for data contained in a tensor. For sparse
  1156. /// tensors it returns type information for contained non-zero values.
  1157. /// It returns dense shape for sparse tensors.
  1158. /// </summary>
  1159. /// <returns>TensorTypeAndShapeInfo</returns>
  1160. TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
  1161. /// <summary>
  1162. /// This API returns information about the memory allocation used to hold data.
  1163. /// </summary>
  1164. /// <returns>Non owning instance of MemoryInfo</returns>
  1165. ConstMemoryInfo GetTensorMemoryInfo() const;
  1166. /// <summary>
  1167. /// The API copies UTF-8 encoded bytes for the requested string element
  1168. /// contained within a tensor or a sparse tensor into a provided buffer.
  1169. /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
  1170. /// </summary>
  1171. /// <param name="buffer_length"></param>
  1172. /// <param name="element_index"></param>
  1173. /// <param name="buffer"></param>
  1174. void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
  1175. /// <summary>
  1176. /// Returns string tensor UTF-8 encoded string element.
  1177. /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
  1178. /// </summary>
  1179. /// <param name="element_index"></param>
  1180. /// <returns>std::string</returns>
  1181. std::string GetStringTensorElement(size_t element_index) const;
  1182. /// <summary>
  1183. /// The API returns a byte length of UTF-8 encoded string element
  1184. /// contained in either a tensor or a spare tensor values.
  1185. /// </summary>
  1186. /// <param name="element_index"></param>
  1187. /// <returns>byte length for the specified string element</returns>
  1188. size_t GetStringTensorElementLength(size_t element_index) const;
  1189. #if !defined(DISABLE_SPARSE_TENSORS)
  1190. /// <summary>
  1191. /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
  1192. /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
  1193. /// the value returned is ORT_SPARSE_UNDEFINED.
  1194. /// </summary>
  1195. /// <returns>Format enum</returns>
  1196. OrtSparseFormat GetSparseFormat() const;
  1197. /// <summary>
  1198. /// The API returns type and shape information for stored non-zero values of the
  1199. /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
  1200. /// </summary>
  1201. /// <returns>TensorTypeAndShapeInfo values information</returns>
  1202. TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
  1203. /// <summary>
  1204. /// The API returns type and shape information for the specified indices. Each supported
  1205. /// indices have their own enum values even if a give format has more than one kind of indices.
  1206. /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
  1207. /// </summary>
  1208. /// <param name="format">enum requested</param>
  1209. /// <returns>type and shape information</returns>
  1210. TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
  1211. /// <summary>
  1212. /// The API retrieves a pointer to the internal indices buffer. The API merely performs
  1213. /// a convenience data type casting on the return type pointer. Make sure you are requesting
  1214. /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
  1215. /// </summary>
  1216. /// <typeparam name="T">type to cast to</typeparam>
  1217. /// <param name="indices_format">requested indices kind</param>
  1218. /// <param name="num_indices">number of indices entries</param>
  1219. /// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
  1220. template <typename R>
  1221. const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
  1222. /// <summary>
  1223. /// Returns true if the OrtValue contains a sparse tensor
  1224. /// </summary>
  1225. /// <returns></returns>
  1226. bool IsSparseTensor() const;
  1227. /// <summary>
  1228. /// The API returns a pointer to an internal buffer of the sparse tensor
  1229. /// containing non-zero values. The API merely does casting. Make sure you
  1230. /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
  1231. /// first.
  1232. /// </summary>
  1233. /// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
  1234. /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
  1235. template <typename R>
  1236. const R* GetSparseTensorValues() const;
  1237. #endif
  1238. };
  1239. template <typename T>
  1240. struct ValueImpl : ConstValueImpl<T> {
  1241. using B = ConstValueImpl<T>;
  1242. using B::B;
  1243. /// <summary>
  1244. /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
  1245. /// No type checking is performed, the caller must ensure the type matches the tensor type.
  1246. /// </summary>
  1247. /// <returns>non-const pointer to data, no copies made</returns>
  1248. template <typename R>
  1249. R* GetTensorMutableData();
  1250. /// <summary>
  1251. /// Returns a non-typed non-const pointer to a tensor contained data.
  1252. /// </summary>
  1253. /// <returns>pointer to data, no copies made</returns>
  1254. void* GetTensorMutableRawData();
  1255. /// <summary>
  1256. // Obtain a reference to an element of data at the location specified
  1257. /// by the vector of dims.
  1258. /// </summary>
  1259. /// <typeparam name="R"></typeparam>
  1260. /// <param name="location">[in] expressed by a vecotr of dimensions offsets</param>
  1261. /// <returns></returns>
  1262. template <typename R>
  1263. R& At(const std::vector<int64_t>& location);
  1264. /// <summary>
  1265. /// Set all strings at once in a string tensor
  1266. /// </summary>
  1267. /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
  1268. /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
  1269. void FillStringTensor(const char* const* s, size_t s_len);
  1270. /// <summary>
  1271. /// Set a single string in a string tensor
  1272. /// </summary>
  1273. /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
  1274. /// <param name="index">[in] Index of the string in the tensor to set</param>
  1275. void FillStringTensorElement(const char* s, size_t index);
  1276. /// <summary>
  1277. /// Allocate if necessary and obtain a pointer to a UTF-8
  1278. /// encoded string element buffer indexed by the flat element index,
  1279. /// of the specified length.
  1280. ///
  1281. /// This API is for advanced usage. It avoids a need to construct
  1282. /// an auxiliary array of string pointers, and allows to write data directly
  1283. /// (do not zero terminate).
  1284. /// </summary>
  1285. /// <param name="index"></param>
  1286. /// <param name="buffer_length"></param>
  1287. /// <returns>a pointer to a writable buffer</returns>
  1288. char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
  1289. #if !defined(DISABLE_SPARSE_TENSORS)
  1290. /// <summary>
  1291. /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
  1292. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  1293. /// allocated buffers lifespan must eclipse that of the OrtValue.
  1294. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  1295. /// </summary>
  1296. /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
  1297. /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
  1298. void UseCooIndices(int64_t* indices_data, size_t indices_num);
  1299. /// <summary>
  1300. /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
  1301. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  1302. /// allocated buffers lifespan must eclipse that of the OrtValue.
  1303. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  1304. /// </summary>
  1305. /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
  1306. /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
  1307. /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
  1308. /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
  1309. void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
  1310. /// <summary>
  1311. /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
  1312. /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  1313. /// allocated buffers lifespan must eclipse that of the OrtValue.
  1314. /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  1315. /// </summary>
  1316. /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
  1317. /// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param>
  1318. void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
  1319. /// <summary>
  1320. /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  1321. /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
  1322. /// at difference device than the allocator, a X-device copy will be performed if possible.
  1323. /// </summary>
  1324. /// <param name="data_mem_info">specified buffer memory description</param>
  1325. /// <param name="values_param">values buffer information.</param>
  1326. /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
  1327. /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
  1328. void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
  1329. const int64_t* indices_data, size_t indices_num);
  1330. /// <summary>
  1331. /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  1332. /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
  1333. /// at difference device than the allocator, a X-device copy will be performed if possible.
  1334. /// </summary>
  1335. /// <param name="data_mem_info">specified buffer memory description</param>
  1336. /// <param name="values">values buffer information</param>
  1337. /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
  1338. /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
  1339. /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
  1340. /// <param name="outer_indices_num">number of csr outer indices or 0</param>
  1341. void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
  1342. const OrtSparseValuesParam& values,
  1343. const int64_t* inner_indices_data, size_t inner_indices_num,
  1344. const int64_t* outer_indices_data, size_t outer_indices_num);
  1345. /// <summary>
  1346. /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  1347. /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
  1348. /// at difference device than the allocator, a X-device copy will be performed if possible.
  1349. /// </summary>
  1350. /// <param name="data_mem_info">specified buffer memory description</param>
  1351. /// <param name="values">values buffer information</param>
  1352. /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
  1353. /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
  1354. void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
  1355. const OrtSparseValuesParam& values,
  1356. const Shape& indices_shape,
  1357. const int32_t* indices_data);
  1358. #endif
  1359. };
  1360. } // namespace detail
  1361. using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
  1362. using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
  1363. /** \brief Wrapper around ::OrtValue
  1364. *
  1365. */
  1366. struct Value : detail::ValueImpl<OrtValue> {
  1367. using Base = detail::ValueImpl<OrtValue>;
  1368. using OrtSparseValuesParam = detail::OrtSparseValuesParam;
  1369. using Shape = detail::Shape;
  1370. explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
  1371. explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
  1372. Value(Value&&) = default;
  1373. Value& operator=(Value&&) = default;
  1374. ConstValue GetConst() const { return ConstValue{this->p_}; }
  1375. UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
  1376. /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
  1377. * \tparam T The numeric datatype. This API is not suitable for strings.
  1378. * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
  1379. * \param p_data Pointer to the data buffer.
  1380. * \param p_data_element_count The number of elements in the data buffer.
  1381. * \param shape Pointer to the tensor shape dimensions.
  1382. * \param shape_len The number of tensor shape dimensions.
  1383. */
  1384. template <typename T>
  1385. static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
  1386. /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
  1387. *
  1388. * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
  1389. * \param p_data Pointer to the data buffer.
  1390. * \param p_data_byte_count The number of bytes in the data buffer.
  1391. * \param shape Pointer to the tensor shape dimensions.
  1392. * \param shape_len The number of tensor shape dimensions.
  1393. * \param type The data type.
  1394. */
  1395. static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
  1396. ONNXTensorElementDataType type);
  1397. /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
  1398. * This overload will allocate the buffer for the tensor according to the supplied shape and data type.
  1399. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
  1400. * The input data would need to be copied into the allocated buffer.
  1401. * This API is not suitable for strings.
  1402. *
  1403. * \tparam T The numeric datatype. This API is not suitable for strings.
  1404. * \param allocator The allocator to use.
  1405. * \param shape Pointer to the tensor shape dimensions.
  1406. * \param shape_len The number of tensor shape dimensions.
  1407. */
  1408. template <typename T>
  1409. static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
  1410. /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
  1411. * Wraps OrtApi::CreateTensorAsOrtValue.
  1412. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
  1413. * The input data would need to be copied into the allocated buffer.
  1414. * This API is not suitable for strings.
  1415. *
  1416. * \param allocator The allocator to use.
  1417. * \param shape Pointer to the tensor shape dimensions.
  1418. * \param shape_len The number of tensor shape dimensions.
  1419. * \param type The data type.
  1420. */
  1421. static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
  1422. /** \brief Creates an OrtValue with a Map Onnx type representation.
  1423. * The API would ref-count the supplied OrtValues and they will be released
  1424. * when the returned OrtValue is released. The caller may release keys and values after the call
  1425. * returns.
  1426. *
  1427. * \param keys an OrtValue containing a tensor with primitive data type keys.
  1428. * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
  1429. */
  1430. static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue
  1431. /** \brief Creates an OrtValue with a Sequence Onnx type representation.
  1432. * The API would ref-count the supplied OrtValues and they will be released
  1433. * when the returned OrtValue is released. The caller may release the values after the call
  1434. * returns.
  1435. *
  1436. * \param values a vector of OrtValues that must have the same Onnx value type.
  1437. */
  1438. static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
  1439. /** \brief Creates an OrtValue wrapping an Opaque type.
  1440. * This is used for experimental support of non-tensor types.
  1441. *
  1442. * \tparam T - the type of the value.
  1443. * \param domain - zero terminated utf-8 string. Domain of the type.
  1444. * \param type_name - zero terminated utf-8 string. Name of the type.
  1445. * \param value - the value to be wrapped.
  1446. */
  1447. template <typename T>
  1448. static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
  1449. #if !defined(DISABLE_SPARSE_TENSORS)
  1450. /// <summary>
  1451. /// This is a simple forwarding method to the other overload that helps deducing
  1452. /// data type enum value from the type of the buffer.
  1453. /// </summary>
  1454. /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
  1455. /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
  1456. /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
  1457. /// <param name="dense_shape">a would be dense shape of the tensor</param>
  1458. /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
  1459. /// <returns></returns>
  1460. template <typename T>
  1461. static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
  1462. const Shape& values_shape);
  1463. /// <summary>
  1464. /// Creates an OrtValue instance containing SparseTensor. This constructs
  1465. /// a sparse tensor that makes use of user allocated buffers. It does not make copies
  1466. /// of the user provided data and does not modify it. The lifespan of user provided buffers should
  1467. /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
  1468. /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
  1469. /// to supply a sparse format specific indices.
  1470. /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
  1471. /// can be properly copied into the allocated buffer.
  1472. /// </summary>
  1473. /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
  1474. /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
  1475. /// <param name="dense_shape">a would be dense shape of the tensor</param>
  1476. /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
  1477. /// <param name="type">data type</param>
  1478. /// <returns>Ort::Value instance containing SparseTensor</returns>
  1479. static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
  1480. const Shape& values_shape, ONNXTensorElementDataType type);
  1481. /// <summary>
  1482. /// This is a simple forwarding method to the below CreateSparseTensor.
  1483. /// This helps to specify data type enum in terms of C++ data type.
  1484. /// Use CreateSparseTensor<T>
  1485. /// </summary>
  1486. /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
  1487. /// <param name="allocator">allocator to use</param>
  1488. /// <param name="dense_shape">a would be dense shape of the tensor</param>
  1489. /// <returns>Ort::Value</returns>
  1490. template <typename T>
  1491. static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
  1492. /// <summary>
  1493. /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
  1494. /// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values
  1495. /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
  1496. /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
  1497. /// strings.
  1498. /// </summary>
  1499. /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
  1500. /// <param name="dense_shape">a would be dense shape of the tensor</param>
  1501. /// <param name="type">data type</param>
  1502. /// <returns>an instance of Ort::Value</returns>
  1503. static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
  1504. #endif // !defined(DISABLE_SPARSE_TENSORS)
  1505. };
  1506. /// <summary>
  1507. /// Represents native memory allocation coming from one of the
  1508. /// OrtAllocators registered with OnnxRuntime.
  1509. /// Use it to wrap an allocation made by an allocator
  1510. /// so it can be automatically released when no longer needed.
  1511. /// </summary>
  1512. struct MemoryAllocation {
  1513. MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
  1514. ~MemoryAllocation();
  1515. MemoryAllocation(const MemoryAllocation&) = delete;
  1516. MemoryAllocation& operator=(const MemoryAllocation&) = delete;
  1517. MemoryAllocation(MemoryAllocation&&) noexcept;
  1518. MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
  1519. void* get() { return p_; }
  1520. size_t size() const { return size_; }
  1521. private:
  1522. OrtAllocator* allocator_;
  1523. void* p_;
  1524. size_t size_;
  1525. };
  1526. namespace detail {
  1527. template <typename T>
  1528. struct AllocatorImpl : Base<T> {
  1529. using B = Base<T>;
  1530. using B::B;
  1531. void* Alloc(size_t size);
  1532. MemoryAllocation GetAllocation(size_t size);
  1533. void Free(void* p);
  1534. ConstMemoryInfo GetInfo() const;
  1535. };
  1536. } // namespace detail
  1537. /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
  1538. *
  1539. */
  1540. struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
  1541. explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
  1542. AllocatorWithDefaultOptions();
  1543. };
  1544. /** \brief Wrapper around ::OrtAllocator
  1545. *
  1546. */
  1547. struct Allocator : detail::AllocatorImpl<OrtAllocator> {
  1548. explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
  1549. Allocator(const Session& session, const OrtMemoryInfo*);
  1550. };
  1551. using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
  1552. namespace detail {
  1553. namespace binding_utils {
  1554. // Bring these out of template
  1555. std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
  1556. std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
  1557. } // namespace binding_utils
  1558. template <typename T>
  1559. struct ConstIoBindingImpl : Base<T> {
  1560. using B = Base<T>;
  1561. using B::B;
  1562. std::vector<std::string> GetOutputNames() const;
  1563. std::vector<std::string> GetOutputNames(OrtAllocator*) const;
  1564. std::vector<Value> GetOutputValues() const;
  1565. std::vector<Value> GetOutputValues(OrtAllocator*) const;
  1566. };
  1567. template <typename T>
  1568. struct IoBindingImpl : ConstIoBindingImpl<T> {
  1569. using B = ConstIoBindingImpl<T>;
  1570. using B::B;
  1571. void BindInput(const char* name, const Value&);
  1572. void BindOutput(const char* name, const Value&);
  1573. void BindOutput(const char* name, const OrtMemoryInfo*);
  1574. void ClearBoundInputs();
  1575. void ClearBoundOutputs();
  1576. void SynchronizeInputs();
  1577. void SynchronizeOutputs();
  1578. };
  1579. } // namespace detail
  1580. using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
  1581. using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
  1582. /** \brief Wrapper around ::OrtIoBinding
  1583. *
  1584. */
  1585. struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
  1586. explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
  1587. explicit IoBinding(Session& session);
  1588. ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
  1589. UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
  1590. };
  1591. /*! \struct Ort::ArenaCfg
  1592. * \brief it is a structure that represents the configuration of an arena based allocator
  1593. * \details Please see docs/C_API.md for details
  1594. */
  1595. struct ArenaCfg : detail::Base<OrtArenaCfg> {
  1596. explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
  1597. /**
  1598. * Wraps OrtApi::CreateArenaCfg
  1599. * \param max_mem - use 0 to allow ORT to choose the default
  1600. * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
  1601. * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
  1602. * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
  1603. * See docs/C_API.md for details on what the following parameters mean and how to choose these values
  1604. */
  1605. ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
  1606. };
  1607. //
  1608. // Custom OPs (only needed to implement custom OPs)
  1609. //
  1610. /// <summary>
  1611. /// This struct provides life time management for custom op attribute
  1612. /// </summary>
  1613. struct OpAttr : detail::Base<OrtOpAttr> {
  1614. OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
  1615. };
  1616. /**
  1617. * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
  1618. * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
  1619. *
  1620. * \param logger The Ort::Logger instance to use. Must be a value or reference.
  1621. * \param message_severity The logging severity level of the message.
  1622. * \param message A null-terminated UTF-8 message to log.
  1623. */
  1624. #define ORT_CXX_LOG(logger, message_severity, message) \
  1625. do { \
  1626. if (message_severity >= logger.GetLoggingSeverityLevel()) { \
  1627. Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
  1628. static_cast<const char*>(__FUNCTION__), message)); \
  1629. } \
  1630. } while (false)
  1631. /**
  1632. * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
  1633. * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
  1634. *
  1635. * \param logger The Ort::Logger instance to use. Must be a value or reference.
  1636. * \param message_severity The logging severity level of the message.
  1637. * \param message A null-terminated UTF-8 message to log.
  1638. */
  1639. #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
  1640. do { \
  1641. if (message_severity >= logger.GetLoggingSeverityLevel()) { \
  1642. static_cast<void>(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
  1643. static_cast<const char*>(__FUNCTION__), message)); \
  1644. } \
  1645. } while (false)
  /**
   * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
   * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
   * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
   *
   * The message is only forwarded to Ort::Logger::LogFormattedMessage when message_severity is greater than or
   * equal to the value returned by logger.GetLoggingSeverityLevel().
   *
   * Note: `logger` and `message_severity` are each expanded twice, so pass expressions without side effects.
   * Both are parenthesized in the expansion so that arguments such as `*logger_ptr` or `cond ? a : b`
   * keep their intended meaning.
   *
   * \param logger The Ort::Logger instance to use. Must be a value or reference.
   * \param message_severity The logging severity level of the message.
   * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
   *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
   * \param ... Zero or more variadic arguments referenced by the format string.
   */
  #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \
    do { \
      if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
        Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__, \
                                                       static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
      } \
    } while (false)
  /**
   * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
   * are silently ignored.
   * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
   *
   * The message is only forwarded to Ort::Logger::LogFormattedMessage when message_severity is greater than or
   * equal to the value returned by logger.GetLoggingSeverityLevel(). The returned status is discarded.
   *
   * Note: `logger` and `message_severity` are each expanded twice, so pass expressions without side effects.
   * Both are parenthesized in the expansion so that arguments such as `*logger_ptr` or `cond ? a : b`
   * keep their intended meaning.
   *
   * \param logger The Ort::Logger instance to use. Must be a value or reference.
   * \param message_severity The logging severity level of the message.
   * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
   *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
   * \param ... Zero or more variadic arguments referenced by the format string.
   */
  #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \
    do { \
      if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
        static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__, \
                                                       static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
      } \
    } while (false)
  1682. /// <summary>
  1683. /// This class represents an ONNX Runtime logger that can be used to log information with an
  1684. /// associated severity level and source code location (file path, line number, function name).
  1685. ///
  1686. /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger().
  1687. /// Instances of Ort::Logger are the size of two pointers and can be passed by value.
  1688. ///
  1689. /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
  1690. /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
  1691. /// </summary>
  1692. struct Logger {
  1693. /**
  1694. * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
  1695. */
  1696. Logger() = default;
  1697. /**
  1698. * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
  1699. */
  1700. explicit Logger(std::nullptr_t) {}
  1701. /**
  1702. * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
  1703. * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
  1704. *
  1705. * \param logger The ::OrtLogger to wrap.
  1706. */
  1707. explicit Logger(const OrtLogger* logger);
  1708. ~Logger() = default;
  1709. Logger(const Logger&) = default;
  1710. Logger& operator=(const Logger&) = default;
  1711. Logger(Logger&& v) noexcept = default;
  1712. Logger& operator=(Logger&& v) noexcept = default;
  1713. /**
  1714. * Returns the logger's current severity level from the cached member.
  1715. *
  1716. * \return The current ::OrtLoggingLevel.
  1717. */
  1718. OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
  1719. /**
  1720. * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
  1721. * macros to properly set the source code location and to use the cached severity level to potentially bypass
  1722. * calls to the underlying C API.
  1723. *
  1724. * \param log_severity_level The message's logging severity level.
  1725. * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
  1726. * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
  1727. * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
  1728. * \param message The message to log.
  1729. * \return A Ort::Status value to indicate error or success.
  1730. */
  1731. Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
  1732. const char* func_name, const char* message) const noexcept;
  1733. /**
  1734. * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
  1735. * macros to properly set the source code location and to use the cached severity level to potentially bypass
  1736. * calls to the underlying C API. Returns an error status if a formatting error occurs.
  1737. *
  1738. * \param log_severity_level The message's logging severity level.
  1739. * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
  1740. * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
  1741. * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
  1742. * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
  1743. * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
  1744. * \param args Zero or more variadic arguments referenced by the format string.
  1745. * \return A Ort::Status value to indicate error or success.
  1746. */
  1747. template <typename... Args>
  1748. Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
  1749. const char* func_name, const char* format, Args&&... args) const noexcept;
  1750. private:
  1751. const OrtLogger* logger_{};
  1752. OrtLoggingLevel cached_severity_level_{};
  1753. };
  1754. /// <summary>
  1755. /// This class wraps a raw pointer OrtKernelContext* that is being passed
  1756. /// to the custom kernel Compute() method. Use it to safely access context
  1757. /// attributes, input and output parameters with exception safety guarantees.
  1758. /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
  1759. /// </summary>
  1760. struct KernelContext {
  1761. explicit KernelContext(OrtKernelContext* context);
  1762. size_t GetInputCount() const;
  1763. size_t GetOutputCount() const;
  1764. // If input is optional and is not present, the method returns en empty ConstValue
  1765. // which can be compared to nullptr.
  1766. ConstValue GetInput(size_t index) const;
  1767. // If outout is optional and is not present, the method returns en empty UnownedValue
  1768. // which can be compared to nullptr.
  1769. UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
  1770. UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
  1771. void* GetGPUComputeStream() const;
  1772. Logger GetLogger() const;
  1773. OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
  1774. OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
  1775. void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
  1776. private:
  1777. OrtKernelContext* ctx_;
  1778. };
  1779. struct KernelInfo;
  1780. namespace detail {
  1781. namespace attr_utils {
  1782. void GetAttr(const OrtKernelInfo* p, const char* name, float&);
  1783. void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
  1784. void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
  1785. void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
  1786. void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
  1787. } // namespace attr_utils
  1788. template <typename T>
  1789. struct KernelInfoImpl : Base<T> {
  1790. using B = Base<T>;
  1791. using B::B;
  1792. KernelInfo Copy() const;
  1793. template <typename R> // R is only implemented for float, int64_t, and string
  1794. R GetAttribute(const char* name) const {
  1795. R val;
  1796. attr_utils::GetAttr(this->p_, name, val);
  1797. return val;
  1798. }
  1799. template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
  1800. std::vector<R> GetAttributes(const char* name) const {
  1801. std::vector<R> result;
  1802. attr_utils::GetAttrs(this->p_, name, result);
  1803. return result;
  1804. }
  1805. Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
  1806. size_t GetInputCount() const;
  1807. size_t GetOutputCount() const;
  1808. std::string GetInputName(size_t index) const;
  1809. std::string GetOutputName(size_t index) const;
  1810. TypeInfo GetInputTypeInfo(size_t index) const;
  1811. TypeInfo GetOutputTypeInfo(size_t index) const;
  1812. ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
  1813. std::string GetNodeName() const;
  1814. Logger GetLogger() const;
  1815. };
  1816. } // namespace detail
  1817. using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
  1818. /// <summary>
  1819. /// This struct owns the OrtKernInfo* pointer when a copy is made.
  1820. /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
  1821. /// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance
  1822. /// so it does not destroy the pointer the kernel does not own.
  1823. /// </summary>
  1824. struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
  1825. explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
  1826. explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
  1827. ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
  1828. };
  1829. /// <summary>
  1830. /// Create and own custom defined operation.
  1831. /// </summary>
  1832. struct Op : detail::Base<OrtOp> {
  1833. explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
  1834. explicit Op(OrtOp*); ///< Take ownership of the OrtOp
  1835. static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
  1836. int version, const char** type_constraint_names,
  1837. const ONNXTensorElementDataType* type_constraint_values,
  1838. size_t type_constraint_count,
  1839. const OpAttr* attr_values,
  1840. size_t attr_count,
  1841. size_t input_count, size_t output_count);
  1842. void Invoke(const OrtKernelContext* context,
  1843. const Value* input_values,
  1844. size_t input_count,
  1845. Value* output_values,
  1846. size_t output_count);
  1847. // For easier refactoring
  1848. void Invoke(const OrtKernelContext* context,
  1849. const OrtValue* const* input_values,
  1850. size_t input_count,
  1851. OrtValue* const* output_values,
  1852. size_t output_count);
  1853. };
  1854. /// <summary>
  1855. /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
  1856. /// </summary>
  1857. struct ShapeInferContext {
  1858. struct SymbolicInteger {
  1859. SymbolicInteger(int64_t i) : i_(i), is_int_(true){};
  1860. SymbolicInteger(const char* s) : s_(s), is_int_(false){};
  1861. SymbolicInteger(const SymbolicInteger&) = default;
  1862. SymbolicInteger(SymbolicInteger&&) = default;
  1863. SymbolicInteger& operator=(const SymbolicInteger&) = default;
  1864. SymbolicInteger& operator=(SymbolicInteger&&) = default;
  1865. bool operator==(const SymbolicInteger& dim) const {
  1866. if (is_int_ == dim.is_int_) {
  1867. if (is_int_) {
  1868. return i_ == dim.i_;
  1869. } else {
  1870. return std::string{s_} == std::string{dim.s_};
  1871. }
  1872. }
  1873. return false;
  1874. }
  1875. bool IsInt() const { return is_int_; }
  1876. int64_t AsInt() const { return i_; }
  1877. const char* AsSym() const { return s_; }
  1878. static constexpr int INVALID_INT_DIM = -2;
  1879. private:
  1880. union {
  1881. int64_t i_;
  1882. const char* s_;
  1883. };
  1884. bool is_int_;
  1885. };
  1886. using Shape = std::vector<SymbolicInteger>;
  1887. ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
  1888. const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
  1889. size_t GetInputCount() const { return input_shapes_.size(); }
  1890. Status SetOutputShape(size_t indice, const Shape& shape);
  1891. int64_t GetAttrInt(const char* attr_name);
  1892. using Ints = std::vector<int64_t>;
  1893. Ints GetAttrInts(const char* attr_name);
  1894. float GetAttrFloat(const char* attr_name);
  1895. using Floats = std::vector<float>;
  1896. Floats GetAttrFloats(const char* attr_name);
  1897. std::string GetAttrString(const char* attr_name);
  1898. using Strings = std::vector<std::string>;
  1899. Strings GetAttrStrings(const char* attr_name);
  1900. private:
  1901. const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
  1902. const OrtApi* ort_api_;
  1903. OrtShapeInferContext* ctx_;
  1904. std::vector<Shape> input_shapes_;
  1905. };
  1906. using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
  // Largest custom op end version (2^31 - 1). The whole replacement list is parenthesized so the
  // macro behaves as a single value inside larger expressions (e.g. `x - MAX_CUSTOM_OP_END_VER`).
  #define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
  1908. template <typename TOp, typename TKernel, bool WithStatus = false>
  1909. struct CustomOpBase : OrtCustomOp {
  1910. CustomOpBase() {
  1911. OrtCustomOp::version = ORT_API_VERSION;
  1912. OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
  1913. OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
  1914. OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
  1915. OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
  1916. OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
  1917. OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
  1918. OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
  1919. #if defined(_MSC_VER) && !defined(__clang__)
  1920. #pragma warning(push)
  1921. #pragma warning(disable : 26409)
  1922. #endif
  1923. OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
  1924. #if defined(_MSC_VER) && !defined(__clang__)
  1925. #pragma warning(pop)
  1926. #endif
  1927. OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
  1928. OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
  1929. OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
  1930. OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
  1931. OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
  1932. OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
  1933. #ifdef __cpp_if_constexpr
  1934. if constexpr (WithStatus) {
  1935. #else
  1936. if (WithStatus) {
  1937. #endif
  1938. OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
  1939. return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
  1940. };
  1941. OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
  1942. return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
  1943. };
  1944. } else {
  1945. OrtCustomOp::CreateKernelV2 = nullptr;
  1946. OrtCustomOp::KernelComputeV2 = nullptr;
  1947. OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
  1948. OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
  1949. static_cast<TKernel*>(op_kernel)->Compute(context);
  1950. };
  1951. }
  1952. SetShapeInferFn<TOp>(0);
  1953. OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
  1954. return static_cast<const TOp*>(this_)->start_ver_;
  1955. };
  1956. OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
  1957. return static_cast<const TOp*>(this_)->end_ver_;
  1958. };
  1959. OrtCustomOp::GetMayInplace = nullptr;
  1960. OrtCustomOp::ReleaseMayInplace = nullptr;
  1961. OrtCustomOp::GetAliasMap = nullptr;
  1962. OrtCustomOp::ReleaseAliasMap = nullptr;
  1963. }
  1964. // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
  1965. const char* GetExecutionProviderType() const { return nullptr; }
  1966. // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
  1967. // (inputs and outputs are required by default)
  1968. OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
  1969. return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  1970. }
  1971. OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
  1972. return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  1973. }
  1974. // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
  1975. OrtMemType GetInputMemoryType(size_t /*index*/) const {
  1976. return OrtMemTypeDefault;
  1977. }
  1978. // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
  1979. // should expect at least 1 argument.
  1980. int GetVariadicInputMinArity() const {
  1981. return 1;
  1982. }
  1983. // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments
  1984. // to a variadic input should be of the same type.
  1985. bool GetVariadicInputHomogeneity() const {
  1986. return true;
  1987. }
  1988. // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
  1989. // should produce at least 1 output value.
  1990. int GetVariadicOutputMinArity() const {
  1991. return 1;
  1992. }
  1993. // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values
  1994. // produced by a variadic output should be of the same type.
  1995. bool GetVariadicOutputHomogeneity() const {
  1996. return true;
  1997. }
  1998. // Declare list of session config entries used by this Custom Op.
  1999. // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
  2000. // This default implementation returns an empty vector of config entries.
  2001. std::vector<std::string> GetSessionConfigKeys() const {
  2002. return std::vector<std::string>{};
  2003. }
  2004. template <typename C>
  2005. decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
  2006. OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
  2007. ShapeInferContext ctx(&GetApi(), ort_ctx);
  2008. return C::InferOutputShape(ctx);
  2009. };
  2010. return {};
  2011. }
  2012. template <typename C>
  2013. void SetShapeInferFn(...) {
  2014. OrtCustomOp::InferOutputShapeFn = {};
  2015. }
  2016. protected:
  2017. // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
  2018. void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
  2019. int start_ver_ = 1;
  2020. int end_ver_ = MAX_CUSTOM_OP_END_VER;
  2021. };
  2022. } // namespace Ort
  2023. #include "onnxruntime_cxx_inline.h"