12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387 |
- // Copyright (c) Microsoft Corporation. All rights reserved.
- // Licensed under the MIT License.
- // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
- //
- // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
- // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
- // all the resources follow RAII and do not leak memory.
- //
- // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
- // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
- // until you assign an instance that actually holds an underlying object.
- //
- // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
- // Some objects have explicit 'Clone' methods for this purpose.
- //
- // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
- // by value or by reference. ConstXXXX types are restricted to const only interfaces.
- //
- // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
- //
- // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
- // have to fall back to C types and the API with the usual pitfalls. In general, do not use the C API from your C++ code.
- #pragma once
- #include "onnxruntime_c_api.h"
- #include "onnxruntime_float16.h"
- #include <cstddef>
- #include <cstdio>
- #include <array>
- #include <memory>
- #include <stdexcept>
- #include <string>
- #include <vector>
- #include <unordered_map>
- #include <utility>
- #include <type_traits>
- #ifdef ORT_NO_EXCEPTIONS
- #include <iostream>
- #endif
- /** \brief All C++ Onnxruntime APIs are defined inside this namespace
- *
- */
- namespace Ort {
- /** \brief All C++ methods that can fail will throw an exception of this type
- *
- * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
- */
- struct Exception : std::exception {
- Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
- OrtErrorCode GetOrtErrorCode() const { return code_; }
- const char* what() const noexcept override { return message_.c_str(); }
- private:
- std::string message_;
- OrtErrorCode code_;
- };
#ifdef ORT_NO_EXCEPTIONS
// Exceptions are disabled: print the error text to stderr and terminate via abort().
// The #ifndef guard exists for the very special case where the user of this library pre-defines
// ORT_CXX_API_THROW to install their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
#define ORT_CXX_API_THROW(string, code)                            \
  do {                                                             \
    std::cerr << Ort::Exception(string, code).what() << std::endl; \
    abort();                                                       \
  } while (false)
#endif
#else
// Exceptions are enabled: throw an Ort::Exception built from the message and the error code.
#define ORT_CXX_API_THROW(string, code) throw Ort::Exception(string, code)
#endif
- // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
- // it's in a template so that we can define a global variable in a header and make
- // it transparent to the users of the API.
- template <typename T>
- struct Global {
- static const OrtApi* api_;
- };
- // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
- template <typename T>
- #ifdef ORT_API_MANUAL_INIT
- const OrtApi* Global<T>::api_{};
- inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
- // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
- // required by C++ APIs.
- //
- // Example mycustomop.cc:
- //
- // #define ORT_API_MANUAL_INIT
- // #include <onnxruntime_cxx_api.h>
- // #undef ORT_API_MANUAL_INIT
- //
- // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
- // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
- // // ...
- // }
- //
- inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
- #else
- #if defined(_MSC_VER) && !defined(__clang__)
- #pragma warning(push)
- // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
- // Please define ORT_API_MANUAL_INIT if it conerns you.
- #pragma warning(disable : 26426)
- #endif
- const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
- #if defined(_MSC_VER) && !defined(__clang__)
- #pragma warning(pop)
- #endif
- #endif
- /// This returns a reference to the OrtApi interface in use
- inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
/// <summary>
/// Returns the onnxruntime version string.
/// </summary>
/// <returns>version string in the form major.minor.rev</returns>
std::string GetVersionString();

/// <summary>
/// Returns the onnxruntime build information: git branch, git commit id,
/// build type (Debug/Release/RelWithDebInfo) and cmake cpp flags.
/// </summary>
/// <returns>build information as a string</returns>
std::string GetBuildInfoString();

/// <summary>
/// C++ wrapper for OrtApi::GetAvailableProviders().
/// </summary>
/// <returns>vector of strings representing the available execution providers</returns>
std::vector<std::string> GetAvailableProviders();
- /** \brief IEEE 754 half-precision floating point data type
- *
- * \details This struct is used for converting float to float16 and back
- * so the user could feed inputs and fetch outputs using these type.
- *
- * The size of the structure should align with uint16_t and one can freely cast
- * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
- *
- * \code{.unparsed}
- * // This example demonstrates converion from float to float16
- * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
- * std::vector<Ort::Float16_t> fp16_values;
- * fp16_values.reserve(std::size(values));
- * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
- * [](float value) { return Ort::Float16_t(value); });
- *
- * \endcode
- */
- struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
- private:
- /// <summary>
- /// Constructor from a 16-bit representation of a float16 value
- /// No conversion is done here.
- /// </summary>
- /// <param name="v">16-bit representation</param>
- constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
- public:
- using Base = onnxruntime_float16::Float16Impl<Float16_t>;
- /// <summary>
- /// Default constructor
- /// </summary>
- Float16_t() = default;
- /// <summary>
- /// Explicit conversion to uint16_t representation of float16.
- /// </summary>
- /// <param name="v">uint16_t bit representation of float16</param>
- /// <returns>new instance of Float16_t</returns>
- constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
- /// <summary>
- /// __ctor from float. Float is converted into float16 16-bit representation.
- /// </summary>
- /// <param name="v">float value</param>
- explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
- /// <summary>
- /// Converts float16 to float
- /// </summary>
- /// <returns>float representation of float16 value</returns>
- float ToFloat() const noexcept { return Base::ToFloatImpl(); }
- /// <summary>
- /// Checks if the value is negative
- /// </summary>
- /// <returns>true if negative</returns>
- using Base::IsNegative;
- /// <summary>
- /// Tests if the value is NaN
- /// </summary>
- /// <returns>true if NaN</returns>
- using Base::IsNaN;
- /// <summary>
- /// Tests if the value is finite
- /// </summary>
- /// <returns>true if finite</returns>
- using Base::IsFinite;
- /// <summary>
- /// Tests if the value represents positive infinity.
- /// </summary>
- /// <returns>true if positive infinity</returns>
- using Base::IsPositiveInfinity;
- /// <summary>
- /// Tests if the value represents negative infinity
- /// </summary>
- /// <returns>true if negative infinity</returns>
- using Base::IsNegativeInfinity;
- /// <summary>
- /// Tests if the value is either positive or negative infinity.
- /// </summary>
- /// <returns>True if absolute value is infinity</returns>
- using Base::IsInfinity;
- /// <summary>
- /// Tests if the value is NaN or zero. Useful for comparisons.
- /// </summary>
- /// <returns>True if NaN or zero.</returns>
- using Base::IsNaNOrZero;
- /// <summary>
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
- /// </summary>
- /// <returns>True if so</returns>
- using Base::IsNormal;
- /// <summary>
- /// Tests if the value is subnormal (denormal).
- /// </summary>
- /// <returns>True if so</returns>
- using Base::IsSubnormal;
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- using Base::Abs;
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- using Base::Negate;
- /// <summary>
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
- /// and therefore equivalent, if the resulting value is still zero.
- /// </summary>
- /// <param name="lhs">first value</param>
- /// <param name="rhs">second value</param>
- /// <returns>True if both arguments represent zero</returns>
- using Base::AreZero;
- /// <summary>
- /// User defined conversion operator. Converts Float16_t to float.
- /// </summary>
- explicit operator float() const noexcept { return ToFloat(); }
- using Base::operator==;
- using Base::operator!=;
- using Base::operator<;
- };
- static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
- /** \brief bfloat16 (Brain Floating Point) data type
- *
- * \details This struct is used for converting float to bfloat16 and back
- * so the user could feed inputs and fetch outputs using these type.
- *
- * The size of the structure should align with uint16_t and one can freely cast
- * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
- *
- * \code{.unparsed}
- * // This example demonstrates converion from float to float16
- * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
- * std::vector<Ort::BFloat16_t> bfp16_values;
- * bfp16_values.reserve(std::size(values));
- * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
- * [](float value) { return Ort::BFloat16_t(value); });
- *
- * \endcode
- */
- struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
- private:
- /// <summary>
- /// Constructor from a uint16_t representation of bfloat16
- /// used in FromBits() to escape overload resolution issue with
- /// constructor from float.
- /// No conversion is done.
- /// </summary>
- /// <param name="v">16-bit bfloat16 value</param>
- constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
- public:
- using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
- BFloat16_t() = default;
- /// <summary>
- /// Explicit conversion to uint16_t representation of bfloat16.
- /// </summary>
- /// <param name="v">uint16_t bit representation of bfloat16</param>
- /// <returns>new instance of BFloat16_t</returns>
- static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
- /// <summary>
- /// __ctor from float. Float is converted into bfloat16 16-bit representation.
- /// </summary>
- /// <param name="v">float value</param>
- explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
- /// <summary>
- /// Converts bfloat16 to float
- /// </summary>
- /// <returns>float representation of bfloat16 value</returns>
- float ToFloat() const noexcept { return Base::ToFloatImpl(); }
- /// <summary>
- /// Checks if the value is negative
- /// </summary>
- /// <returns>true if negative</returns>
- using Base::IsNegative;
- /// <summary>
- /// Tests if the value is NaN
- /// </summary>
- /// <returns>true if NaN</returns>
- using Base::IsNaN;
- /// <summary>
- /// Tests if the value is finite
- /// </summary>
- /// <returns>true if finite</returns>
- using Base::IsFinite;
- /// <summary>
- /// Tests if the value represents positive infinity.
- /// </summary>
- /// <returns>true if positive infinity</returns>
- using Base::IsPositiveInfinity;
- /// <summary>
- /// Tests if the value represents negative infinity
- /// </summary>
- /// <returns>true if negative infinity</returns>
- using Base::IsNegativeInfinity;
- /// <summary>
- /// Tests if the value is either positive or negative infinity.
- /// </summary>
- /// <returns>True if absolute value is infinity</returns>
- using Base::IsInfinity;
- /// <summary>
- /// Tests if the value is NaN or zero. Useful for comparisons.
- /// </summary>
- /// <returns>True if NaN or zero.</returns>
- using Base::IsNaNOrZero;
- /// <summary>
- /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
- /// </summary>
- /// <returns>True if so</returns>
- using Base::IsNormal;
- /// <summary>
- /// Tests if the value is subnormal (denormal).
- /// </summary>
- /// <returns>True if so</returns>
- using Base::IsSubnormal;
- /// <summary>
- /// Creates an instance that represents absolute value.
- /// </summary>
- /// <returns>Absolute value</returns>
- using Base::Abs;
- /// <summary>
- /// Creates a new instance with the sign flipped.
- /// </summary>
- /// <returns>Flipped sign instance</returns>
- using Base::Negate;
- /// <summary>
- /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
- /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
- /// and therefore equivalent, if the resulting value is still zero.
- /// </summary>
- /// <param name="lhs">first value</param>
- /// <param name="rhs">second value</param>
- /// <returns>True if both arguments represent zero</returns>
- using Base::AreZero;
- /// <summary>
- /// User defined conversion operator. Converts BFloat16_t to float.
- /// </summary>
- explicit operator float() const noexcept { return ToFloat(); }
- // We do not have an inherited impl for the below operators
- // as the internal class implements them a little differently
- bool operator==(const BFloat16_t& rhs) const noexcept;
- bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
- bool operator<(const BFloat16_t& rhs) const noexcept;
- };
- static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
/** \brief float8e4m3fn (Float8 Floating Point) data type
 * \details Thin wrapper around a single uint8_t. It is necessary for type dispatching to make use of C++ API.
 * The type is implicitly convertible to/from uint8_t.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E4M3FN_t {
  uint8_t value;

  /// Zero-initializes the stored byte.
  constexpr Float8E4M3FN_t() noexcept : value{0} {}

  /// Implicit conversion from the raw byte; the byte is stored unchanged.
  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value{v} {}

  /// Implicit conversion back to the raw byte.
  constexpr operator uint8_t() const noexcept {
    return value;
  }

  // nan values are treated like any other value for operator ==, !=
  // (both operators compare the stored byte only)
  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept {
    return value == rhs.value;
  }
  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept {
    return !(*this == rhs);
  }
};

static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
/** \brief float8e4m3fnuz (Float8 Floating Point) data type
 * \details Thin wrapper around a single uint8_t. It is necessary for type dispatching to make use of C++ API.
 * The type is implicitly convertible to/from uint8_t.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E4M3FNUZ_t {
  uint8_t value;

  /// Zero-initializes the stored byte.
  constexpr Float8E4M3FNUZ_t() noexcept : value{0} {}

  /// Implicit conversion from the raw byte; the byte is stored unchanged.
  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value{v} {}

  /// Implicit conversion back to the raw byte.
  constexpr operator uint8_t() const noexcept {
    return value;
  }

  // nan values are treated like any other value for operator ==, !=
  // (both operators compare the stored byte only)
  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept {
    return value == rhs.value;
  }
  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept {
    return !(*this == rhs);
  }
};

static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
/** \brief float8e5m2 (Float8 Floating Point) data type
 * \details Thin wrapper around a single uint8_t. It is necessary for type dispatching to make use of C++ API.
 * The type is implicitly convertible to/from uint8_t.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2_t {
  uint8_t value;

  /// Zero-initializes the stored byte.
  constexpr Float8E5M2_t() noexcept : value{0} {}

  /// Implicit conversion from the raw byte; the byte is stored unchanged.
  constexpr Float8E5M2_t(uint8_t v) noexcept : value{v} {}

  /// Implicit conversion back to the raw byte.
  constexpr operator uint8_t() const noexcept {
    return value;
  }

  // nan values are treated like any other value for operator ==, !=
  // (both operators compare the stored byte only)
  constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept {
    return value == rhs.value;
  }
  constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept {
    return !(*this == rhs);
  }
};

static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
/** \brief float8e5m2fnuz (Float8 Floating Point) data type
 * \details Thin wrapper around a single uint8_t. It is necessary for type dispatching to make use of C++ API.
 * The type is implicitly convertible to/from uint8_t.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2FNUZ_t {
  uint8_t value;

  /// Zero-initializes the stored byte.
  constexpr Float8E5M2FNUZ_t() noexcept : value{0} {}

  /// Implicit conversion from the raw byte; the byte is stored unchanged.
  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value{v} {}

  /// Implicit conversion back to the raw byte.
  constexpr operator uint8_t() const noexcept {
    return value;
  }

  // nan values are treated like any other value for operator ==, !=
  // (both operators compare the stored byte only)
  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept {
    return value == rhs.value;
  }
  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept {
    return !(*this == rhs);
  }
};

static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
- namespace detail {
- // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
- // This can't be done in the C API since C doesn't have function overloading.
- #define ORT_DEFINE_RELEASE(NAME) \
- inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
- ORT_DEFINE_RELEASE(Allocator);
- ORT_DEFINE_RELEASE(MemoryInfo);
- ORT_DEFINE_RELEASE(CustomOpDomain);
- ORT_DEFINE_RELEASE(ThreadingOptions);
- ORT_DEFINE_RELEASE(Env);
- ORT_DEFINE_RELEASE(RunOptions);
- ORT_DEFINE_RELEASE(Session);
- ORT_DEFINE_RELEASE(SessionOptions);
- ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
- ORT_DEFINE_RELEASE(SequenceTypeInfo);
- ORT_DEFINE_RELEASE(MapTypeInfo);
- ORT_DEFINE_RELEASE(TypeInfo);
- ORT_DEFINE_RELEASE(Value);
- ORT_DEFINE_RELEASE(ModelMetadata);
- ORT_DEFINE_RELEASE(IoBinding);
- ORT_DEFINE_RELEASE(ArenaCfg);
- ORT_DEFINE_RELEASE(Status);
- ORT_DEFINE_RELEASE(OpAttr);
- ORT_DEFINE_RELEASE(Op);
- ORT_DEFINE_RELEASE(KernelInfo);
- #undef ORT_DEFINE_RELEASE
/** \brief Tagging template type. Use it as Base<Unowned<T>> to indicate that the C++ interface object
 * has no ownership of the underlying C object.
 *
 * The wrapped type is exposed through the nested alias `Type`.
 */
template <typename T>
struct Unowned {
  using Type = T;
};
/** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
 * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
 *
 * All of the C++ classes
 * a) serve as containers for pointers to objects that are created by the underlying C API.
 *    Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
 * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
 *    they would release objects owned automatically when going out of scope, they are move-only.
 * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
 *    ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
 *    such as Onnxruntime or instances of XXXX classes.
 * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
 *    in C++ code.
 *
 */

/// <summary>
/// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction
/// by calling the OrtRelease() overload for the contained type.
/// </summary>
template <typename T>
struct Base {
  using contained_type = T;

  /// Creates an empty holder (null pointer).
  constexpr Base() = default;

  /// Takes ownership of `p`.
  constexpr explicit Base(contained_type* p) noexcept : p_{p} {}

  /// Releases the held pointer through OrtRelease().
  ~Base() { OrtRelease(p_); }

  // Owning holder: copying is not allowed.
  Base(const Base&) = delete;
  Base& operator=(const Base&) = delete;

  /// Steals the pointer from `v`, leaving `v` empty.
  Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }

  /// Releases the currently held object and takes over the one held by `v`, leaving `v` empty.
  /// Self-move-assignment is a no-op.
  Base& operator=(Base&& v) noexcept {
    // Without this guard `x = std::move(x)` would release the object and then re-adopt the
    // dangling pointer, which the destructor would release a second time.
    if (this != &v) {
      OrtRelease(p_);
      p_ = v.release();
    }
    return *this;
  }

  /// Implicit access to the raw C pointer; ownership is not transferred.
  constexpr operator contained_type*() const noexcept { return p_; }

  /// \brief Relinquishes ownership of the contained C object pointer
  /// The underlying object is not destroyed
  /// \return the previously held pointer; this holder becomes empty
  contained_type* release() noexcept {
    T* p = p_;
    p_ = nullptr;
    return p;
  }

 protected:
  contained_type* p_{};
};

// Undefined. For const types use Base<Unowned<const T>>
template <typename T>
struct Base<const T>;
- /// <summary>
- /// Covers unowned pointers owned by either the ORT
- /// or some other instance of CPP wrappers.
- /// Used for ConstXXX and UnownedXXXX types that are copyable.
- /// Also convenient to wrap raw OrtXX pointers .
- /// </summary>
- /// <typeparam name="T"></typeparam>
- template <typename T>
- struct Base<Unowned<T>> {
- using contained_type = typename Unowned<T>::Type;
- constexpr Base() = default;
- constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
- ~Base() = default;
- Base(const Base&) = default;
- Base& operator=(const Base&) = default;
- Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
- Base& operator=(Base&& v) noexcept {
- p_ = nullptr;
- std::swap(p_, v.p_);
- return *this;
- }
- constexpr operator contained_type*() const noexcept { return p_; }
- protected:
- contained_type* p_{};
- };
- // Light functor to release memory with OrtAllocator
- struct AllocatedFree {
- OrtAllocator* allocator_;
- explicit AllocatedFree(OrtAllocator* allocator)
- : allocator_(allocator) {}
- void operator()(void* ptr) const {
- if (ptr) allocator_->Free(allocator_, ptr);
- }
- };
- } // namespace detail
- struct AllocatorWithDefaultOptions;
- struct Env;
- struct TypeInfo;
- struct Value;
- struct ModelMetadata;
- /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
- * and release them at the end of the scope. The lifespan of the given allocator
- * must eclipse the lifespan of AllocatedStringPtr instance
- */
- using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
- /** \brief The Status that holds ownership of an OrtStatus received from the C API.
- * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
- * constructors to construct an instance of a Status object from exceptions.
- */
- struct Status : detail::Base<OrtStatus> {
- explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
- explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
- explicit Status(const Exception&) noexcept; ///< Creates status instance out of an Exception
- explicit Status(const std::exception&) noexcept; ///< Creates status instance out of a std::exception
- Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
- std::string GetErrorMessage() const; ///< Returns the error message as a std::string (presumably wraps OrtApi::GetErrorMessage — confirm)
- OrtErrorCode GetErrorCode() const; ///< Returns the error code (presumably wraps OrtApi::GetErrorCode — confirm)
- bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status.
- };
- /** \brief The ThreadingOptions
- *
- * The ThreadingOptions are used to set the options of the global thread pools of the Env.
- * Every setter returns a reference to this object so that calls can be chained.
- */
- struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
- /// \brief Wraps OrtApi::CreateThreadingOptions
- ThreadingOptions();
- /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
- ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
- /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
- ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
- /// \brief Wraps OrtApi::SetGlobalSpinControl
- ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
- /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
- ThreadingOptions& SetGlobalDenormalAsZero();
- /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
- ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
- /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
- ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
- /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
- ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
- };
- /** \brief The Env (Environment)
- *
- * The Env holds the logging state used by all other objects.
- * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality.
- * The member functions that return Env& return a reference to this object so that calls can be chained.
- */
- struct Env : detail::Base<OrtEnv> {
- explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
- /// \brief Wraps OrtApi::CreateEnv
- Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
- /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
- Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
- /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
- Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
- /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
- Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
- OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
- /// \brief C Interop Helper: wraps an existing OrtEnv pointer by forwarding it to the Base<OrtEnv> constructor
- explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
- Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
- Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
- Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
- Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
- Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
- };
- /** \brief Custom Op Domain
- *
- * Wrapper around ::OrtCustomOpDomain to which custom operators are added via Add().
- */
- struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
- explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
- /// \brief Wraps OrtApi::CreateCustomOpDomain
- explicit CustomOpDomain(const char* domain);
- // This does not take ownership of the op; it simply registers it.
- void Add(const OrtCustomOp* op); ///< Wraps OrtApi::CustomOpDomain_Add
- };
- /** \brief RunOptions
- *
- * Wrapper around ::OrtRunOptions. Setters return a reference to this object so that calls can be chained.
- */
- struct RunOptions : detail::Base<OrtRunOptions> {
- explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
- RunOptions(); ///< Wraps OrtApi::CreateRunOptions
- RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
- int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
- RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
- int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
- RunOptions& SetRunTag(const char* run_tag); ///< Wraps OrtApi::RunOptionsSetRunTag
- const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
- RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
- /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
- *
- * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error.
- * Wraps OrtApi::RunOptionsSetTerminate
- */
- RunOptions& SetTerminate();
- /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
- *
- * Wraps OrtApi::RunOptionsUnsetTerminate
- */
- RunOptions& UnsetTerminate();
- };
- namespace detail {
- // Utility function that returns a SessionOptions config entry key for a specific custom operator.
- // Key format example: custom_op.[custom_op_name].[config]
- std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); // both arguments are C strings; the key is returned by value
- } // namespace detail
- /// <summary>
- /// Class that represents session configuration entries for one or more custom operators.
- ///
- /// Example:
- /// Ort::CustomOpConfigs op_configs;
- /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
- ///
- /// An instance is passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
- /// </summary>
- struct CustomOpConfigs {
- CustomOpConfigs() = default;
- ~CustomOpConfigs() = default;
- CustomOpConfigs(const CustomOpConfigs&) = default;
- CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
- CustomOpConfigs(CustomOpConfigs&& o) = default;
- CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
- /** \brief Adds a session configuration entry/value for a specific custom operator.
- *
- * \param custom_op_name The name of the custom operator for which to add a configuration entry.
- * Must match the name returned by the CustomOp's GetName() method.
- * \param config_key The name of the configuration entry.
- * \param config_value The value of the configuration entry.
- * \return A reference to this object to enable call chaining.
- */
- CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
- /** \brief Returns a flattened map of custom operator configuration entries and their values.
- *
- * The keys have been flattened to include both the custom operator name and the configuration entry key name.
- * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
- * {"my_op.key", "value"}.
- * NOTE(review): detail::MakeCustomOpConfigEntryKey documents its key format as
- * custom_op.[custom_op_name].[config]; confirm whether the flattened keys carry that "custom_op." prefix.
- *
- * \return An unordered map of flattened configurations.
- */
- const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
- private:
- std::unordered_map<std::string, std::string> flat_configs_; // presumably the map exposed by GetFlattenedConfigs() — definition not shown here
- };
- /** \brief Options object used when creating a new Session object
- *
- * Wraps ::OrtSessionOptions object and methods
- */
- struct SessionOptions;
- namespace detail {
- // We separate const-only methods because passing a const ptr to non-const methods
- // is only discovered when inline methods are compiled, which is counter-intuitive.
- template <typename T>
- struct ConstSessionOptionsImpl : Base<T> {
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
- std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
- bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
- std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); ///< Presumably returns the entry for config_key, or def when absent — confirm. NOTE(review): unlike its siblings this member is not const-qualified, so it cannot be called on a const object; confirm whether that is intentional
- };
- template <typename T>
- struct SessionOptionsImpl : ConstSessionOptionsImpl<T> { // mutating session-option methods; each returns *this' type by reference for call chaining
- using B = ConstSessionOptionsImpl<T>;
- using B::B; // inherit the base constructors
- SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
- SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
- SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
- SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
- SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
- SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
- SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
- SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
- SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
- SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
- SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
- SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
- SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
- SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
- SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
- SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
- SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
- SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
- SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
- SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
- SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
- const std::vector<char*>& external_initializer_file_buffer_array,
- const std::vector<size_t>& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory
- SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
- SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
- SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
- SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
- SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
- SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
- SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
- SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
- SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
- SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
- SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
- const std::unordered_map<std::string, std::string>& provider_options = {});
- SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
- SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
- SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
- /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
- /// The custom operator configurations are optional. If provided, custom operator configs are set via
- /// OrtApi::AddSessionConfigEntry.
- SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
- SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI
- SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
- };
- } // namespace detail
- using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>; ///< Non-owning, mutable view over an OrtSessionOptions
- using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>; ///< Non-owning view exposing only the const methods
- /** \brief Wrapper around ::OrtSessionOptions
- *
- * GetUnowned() and GetConst() hand out views that wrap the same underlying pointer as this object.
- */
- struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
- explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
- SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
- explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
- UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ///< Returns an UnownedSessionOptions built from this object's pointer
- ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } ///< Returns a ConstSessionOptions built from this object's pointer
- };
- /** \brief Wrapper around ::OrtModelMetadata
- *
- * The *Allocated getters return strings as AllocatedStringPtr, allocated with the caller-supplied allocator.
- */
- struct ModelMetadata : detail::Base<OrtModelMetadata> {
- explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
- explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
- /** \brief Returns a copy of the producer name.
- *
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
- /** \brief Returns a copy of the graph name.
- *
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
- /** \brief Returns a copy of the domain name.
- *
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
- /** \brief Returns a copy of the description.
- *
- * \param allocator to allocate memory for the copy of the string returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
- /** \brief Returns a copy of the graph description.
- *
- * \param allocator to allocate memory for the copy of the string returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
- /** \brief Returns a vector of copies of the custom metadata keys.
- *
- * \param allocator to allocate memory for the copies of the strings returned
- * \return a std::vector of smart pointers that deallocate the buffers when they go out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
- /** \brief Looks up a value by a key in the Custom Metadata map
- *
- * \param key zero terminated string key to lookup
- * \param allocator to allocate memory for the copy of the string returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * May be nullptr if the key is not found.
- *
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
- int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
- };
- struct IoBinding;
- namespace detail {
- // We separate const-only methods because passing a const ptr to non-const methods
- // is only discovered when inline methods are compiled, which is counter-intuitive.
- template <typename T>
- struct ConstSessionImpl : Base<T> {
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- size_t GetInputCount() const; ///< Returns the number of model inputs
- size_t GetOutputCount() const; ///< Returns the number of model outputs
- size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
- /** \brief Returns a copy of the input name at the specified index.
- *
- * \param index must be less than the value returned by GetInputCount()
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
- /** \brief Returns a copy of the output name at the specified index.
- *
- * \param index must be less than the value returned by GetOutputCount()
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
- /** \brief Returns a copy of the overridable initializer name at the specified index.
- *
- * \param index must be less than the value returned by GetOverridableInitializerCount()
- * \param allocator to allocate memory for the copy of the name returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
- uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
- ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
- TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
- TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
- TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
- };
- template <typename T>
- struct SessionImpl : ConstSessionImpl<T> { // adds the non-const session operations (Run, RunAsync, EndProfilingAllocated)
- using B = ConstSessionImpl<T>;
- using B::B; // inherit the base constructors
- /** \brief Run the model returning results in an Ort allocated vector.
- *
- * Wraps OrtApi::Run
- *
- * The caller provides a list of inputs and a list of the desired outputs to return.
- *
- * See the output logs for more information on warnings/errors that occur while processing the model.
- * Common errors are.. (TODO)
- *
- * \param[in] run_options
- * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
- * \param[in] input_values Array of Value objects of length input_count that is the list of input values
- * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
- * \param[in] output_names Array of C style strings of length output_count that is the list of output names
- * \param[in] output_count Number of outputs (the size of the output_names array)
- * \return A std::vector of Value objects that directly maps to the output_names array (e.g. output_names[0] is the first entry of the returned vector)
- */
- std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
- const char* const* output_names, size_t output_count);
- /** \brief Run the model returning results in user provided outputs
- * Same as Run(const RunOptions&, const char* const*, const Value*, size_t, const char* const*, size_t),
- * except that the outputs are written to the caller-provided output_values array of length output_count.
- */
- void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
- const char* const* output_names, Value* output_values, size_t output_count);
- void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
- /** \brief Run the model asynchronously in a thread owned by intra op thread pool
- *
- * Wraps OrtApi::RunAsync
- *
- * \param[in] run_options
- * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
- * \param[in] input_values Array of Value objects of length input_count
- * \param[in] input_count Number of elements in the input_names and inputs arrays
- * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
- * \param[out] output_values Array of provided Values to be filled with outputs.
- * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
- * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime.
- * Then, an OrtValue** pointer will be cast from output_values and passed to the callback.
- * NOTE: it is the customer's duty to finally release output_values and each of its members,
- * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
- * \param[in] output_count Number of elements in the output_names and outputs array
- * \param[in] callback Callback function on model run completion
- * \param[in] user_data User data that is passed back to the callback
- */
- void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
- const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
- /** \brief End profiling and return a copy of the profiling file name.
- *
- * \param allocator to allocate memory for the copy of the string returned
- * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
- * The OrtAllocator instance must be valid at the point of memory release.
- */
- AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
- };
- } // namespace detail
- using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>; ///< Non-owning view exposing only the const session methods
- using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>; ///< Non-owning view exposing the full session interface
- /** \brief Wrapper around ::OrtSession
- *
- * GetConst() and GetUnowned() hand out views that wrap the same underlying pointer as this object.
- */
- struct Session : detail::SessionImpl<OrtSession> {
- explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
- Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
- Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
- OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
- Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
- Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
- OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
- ConstSession GetConst() const { return ConstSession{this->p_}; } ///< Returns a ConstSession built from this object's pointer
- UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } ///< Returns an UnownedSession built from this object's pointer
- };
- namespace detail {
- template <typename T>
- struct MemoryInfoImpl : Base<T> { // read-only accessors shared by the owning and non-owning memory-info wrappers
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- std::string GetAllocatorName() const; ///< Returns the allocator name as a std::string (presumably wraps OrtApi::MemoryInfoGetName — confirm)
- OrtAllocatorType GetAllocatorType() const; ///< Presumably wraps OrtApi::MemoryInfoGetType — confirm
- int GetDeviceId() const; ///< Presumably wraps OrtApi::MemoryInfoGetId — confirm
- OrtMemoryInfoDeviceType GetDeviceType() const; ///< Presumably wraps OrtApi::MemoryInfoGetDeviceType — confirm
- OrtMemType GetMemoryType() const; ///< Presumably wraps OrtApi::MemoryInfoGetMemType — confirm
- template <typename U>
- bool operator==(const MemoryInfoImpl<U>& o) const; ///< Compares against any MemoryInfoImpl instantiation; the comparison itself is defined out of line
- };
- } // namespace detail
- // Const object holder that does not own the underlying object.
- using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
- /** \brief Wrapper around ::OrtMemoryInfo
- *
- * GetConst() hands out a ConstMemoryInfo that wraps the same underlying pointer as this object.
- */
- struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
- static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); ///< Factory for a CPU MemoryInfo (presumably wraps OrtApi::CreateCpuMemoryInfo — confirm)
- explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
- explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
- MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ///< Presumably wraps OrtApi::CreateMemoryInfo — confirm
- ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Returns a ConstMemoryInfo built from this object's pointer
- };
- namespace detail {
- template <typename T>
- struct TensorTypeAndShapeInfoImpl : Base<T> { // const accessors for a tensor's element type and shape
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
- size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
- size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
- /** \deprecated Use GetShape(), which returns a std::vector, instead.
- * [[deprecated]]
- * This interface is unsafe to use
- */
- [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
- void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
- std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
- };
- } // namespace detail
- using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>; ///< Non-owning const view type
- /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
- *
- * GetConst() hands out a view that wraps the same underlying pointer as this object.
- */
- struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
- explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
- explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
- ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } ///< Returns a ConstTensorTypeAndShapeInfo built from this object's pointer
- };
- namespace detail {
- template <typename T>
- struct SequenceTypeInfoImpl : Base<T> { // const accessors for sequence type information
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
- };
- } // namespace detail
- using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>; ///< Non-owning const view type
- /** \brief Wrapper around ::OrtSequenceTypeInfo
- *
- * GetConst() hands out a view that wraps the same underlying pointer as this object.
- */
- struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
- explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
- explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
- ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } ///< Returns a ConstSequenceTypeInfo built from this object's pointer
- };
- namespace detail {
- template <typename T>
- struct OptionalTypeInfoImpl : Base<T> { // const accessors for optional type information
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
- };
- } // namespace detail
- // The optional type info is always owned by the TypeInfo and can only be obtained from it, hence only a non-owning const view type exists.
- using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl<detail::Unowned<const OrtOptionalTypeInfo>>;
- namespace detail {
- template <typename T>
- struct MapTypeInfoImpl : detail::Base<T> { // const accessors for map key/value type information
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
- TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
- };
- } // namespace detail
- using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>; ///< Non-owning const view type
- /** \brief Wrapper around ::OrtMapTypeInfo
- *
- * GetConst() hands out a view that wraps the same underlying pointer as this object.
- */
- struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
- explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
- explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
- ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } ///< Returns a ConstMapTypeInfo built from this object's pointer
- };
- namespace detail {
- template <typename T>
- struct TypeInfoImpl : detail::Base<T> { // const accessors that cast a type info to its tensor/sequence/map/optional view
- using B = Base<T>;
- using B::B; // inherit the Base<T> constructors
- ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
- ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
- ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
- ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToOptionalTypeInfo
- ONNXType GetONNXType() const; ///< Returns the ONNXType of this type info (presumably wraps OrtApi::GetOnnxTypeFromTypeInfo — confirm)
- };
- } // namespace detail
- /// <summary>
- /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
- /// Provides access to the const OrtTypeInfo APIs.
- /// </summary>
- using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
- /// <summary>
- /// Type information that may contain either TensorTypeAndShapeInfo or
- /// the information about a contained sequence, map or optional, depending on the ONNXType.
- /// </summary>
- struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
- explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
- explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
- ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } ///< Returns a ConstTypeInfo built from this object's pointer
- };
- namespace detail {
- // This structure is used to feed sparse tensor values
- // information for use with the FillSparseTensor<Format>() API.
- // If the data type for the sparse tensor values is numeric,
- // use data.p_data; otherwise, use the data.str pointer to feed
- // values. data.str is an array of zero-terminated const char* strings.
- // The number of strings in the array must match the shape size.
- // For fully sparse tensors use shape {0} and set p_data/str
- // to nullptr.
- struct OrtSparseValuesParam {
- const int64_t* values_shape;
- size_t values_shape_len;
- union {
- const void* p_data;
- const char** str;
- } data;
- };
- // Provides a way to pass a shape (pointer plus length) in a single
- // argument.
- struct Shape {
- const int64_t* shape;
- size_t shape_len;
- };
- template <typename T>
- struct ConstValueImpl : Base<T> {
- using B = Base<T>;
- using B::B;  // inherit the constructors of Base<T>
- /// <summary>
- /// Obtains a pointer to user defined data for experimental purposes
- /// </summary>
- template <typename R>
- void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
- bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
- bool HasValue() const; ///< Returns true if OrtValue contains data and returns false if the OrtValue is a None
- size_t GetCount() const; ///< If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
- Value GetValue(int index, OrtAllocator* allocator) const;
- /// <summary>
- /// This API returns the full length of string data contained within either a tensor or a sparse Tensor.
- /// For a sparse tensor it returns the full length of stored non-empty strings (values). The API is useful
- /// for allocating necessary memory and calling GetStringTensorContent().
- /// </summary>
- /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
- size_t GetStringTensorDataLength() const;
- /// <summary>
- /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
- /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
- /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
- /// strings.
- ///
- /// Strings are always assumed to be on CPU, no X-device copy.
- /// </summary>
- /// <param name="buffer">user allocated buffer</param>
- /// <param name="buffer_length">length in bytes of the allocated buffer</param>
- /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
- /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained,
- /// which can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
- /// for sparse tensors</param>
- void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
- /// <summary>
- /// Returns a const typed pointer to the tensor contained data.
- /// No type checking is performed, the caller must ensure the type matches the tensor type.
- /// </summary>
- /// <typeparam name="R">type to cast the data pointer to</typeparam>
- /// <returns>const pointer to data, no copies made</returns>
- template <typename R>
- const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
- /// <summary>
- /// Returns a non-typed pointer to a tensor contained data.
- /// </summary>
- /// <returns>const pointer to data, no copies made</returns>
- const void* GetTensorRawData() const;
- /// <summary>
- /// The API returns type information for data contained in a tensor. For sparse
- /// tensors it returns type information for contained non-zero values.
- /// It returns dense shape for sparse tensors.
- /// </summary>
- /// <returns>TypeInfo</returns>
- TypeInfo GetTypeInfo() const;
- /// <summary>
- /// The API returns type information for data contained in a tensor. For sparse
- /// tensors it returns type information for contained non-zero values.
- /// It returns dense shape for sparse tensors.
- /// </summary>
- /// <returns>TensorTypeAndShapeInfo</returns>
- TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
- /// <summary>
- /// This API returns information about the memory allocation used to hold data.
- /// </summary>
- /// <returns>Non owning instance of MemoryInfo</returns>
- ConstMemoryInfo GetTensorMemoryInfo() const;
- /// <summary>
- /// The API copies UTF-8 encoded bytes for the requested string element
- /// contained within a tensor or a sparse tensor into a provided buffer.
- /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
- /// </summary>
- /// <param name="buffer_length">length of the provided buffer</param>
- /// <param name="element_index">index of the requested string element</param>
- /// <param name="buffer">the provided buffer that receives the bytes</param>
- void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
- /// <summary>
- /// Returns string tensor UTF-8 encoded string element.
- /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
- /// </summary>
- /// <param name="element_index">index of the requested string element</param>
- /// <returns>std::string</returns>
- std::string GetStringTensorElement(size_t element_index) const;
- /// <summary>
- /// The API returns a byte length of UTF-8 encoded string element
- /// contained in either a tensor or a sparse tensor values.
- /// </summary>
- /// <param name="element_index">index of the string element</param>
- /// <returns>byte length for the specified string element</returns>
- size_t GetStringTensorElementLength(size_t element_index) const;
- #if !defined(DISABLE_SPARSE_TENSORS)
- /// <summary>
- /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
- /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
- /// the value returned is ORT_SPARSE_UNDEFINED.
- /// </summary>
- /// <returns>Format enum</returns>
- OrtSparseFormat GetSparseFormat() const;
- /// <summary>
- /// The API returns type and shape information for stored non-zero values of the
- /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
- /// </summary>
- /// <returns>TensorTypeAndShapeInfo values information</returns>
- TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
- /// <summary>
- /// The API returns type and shape information for the specified indices. Each supported
- /// indices have their own enum values even if a given format has more than one kind of indices.
- /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
- /// </summary>
- /// <param name="format">enum requested</param>
- /// <returns>type and shape information</returns>
- TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
- /// <summary>
- /// The API retrieves a pointer to the internal indices buffer. The API merely performs
- /// a convenience data type casting on the return type pointer. Make sure you are requesting
- /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
- /// </summary>
- /// <typeparam name="R">type to cast to</typeparam>
- /// <param name="indices_format">requested indices kind</param>
- /// <param name="num_indices">number of indices entries</param>
- /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
- template <typename R>
- const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
- /// <summary>
- /// Returns true if the OrtValue contains a sparse tensor
- /// </summary>
- /// <returns>true if the OrtValue contains a sparse tensor</returns>
- bool IsSparseTensor() const;
- /// <summary>
- /// The API returns a pointer to an internal buffer of the sparse tensor
- /// containing non-zero values. The API merely does casting. Make sure you
- /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
- /// first.
- /// </summary>
- /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
- /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
- template <typename R>
- const R* GetSparseTensorValues() const;
- #endif
- };
- template <typename T>
- struct ValueImpl : ConstValueImpl<T> {
- using B = ConstValueImpl<T>;
- using B::B;  // inherit the constructors of ConstValueImpl<T>
- /// <summary>
- /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
- /// No type checking is performed, the caller must ensure the type matches the tensor type.
- /// </summary>
- /// <returns>non-const pointer to data, no copies made</returns>
- template <typename R>
- R* GetTensorMutableData();
- /// <summary>
- /// Returns a non-typed non-const pointer to a tensor contained data.
- /// </summary>
- /// <returns>pointer to data, no copies made</returns>
- void* GetTensorMutableRawData();
- /// <summary>
- /// Obtain a reference to an element of data at the location specified
- /// by the vector of dims.
- /// </summary>
- /// <typeparam name="R">type of the referenced element</typeparam>
- /// <param name="location">[in] expressed by a vector of dimension offsets</param>
- /// <returns>reference to the element at the specified location</returns>
- template <typename R>
- R& At(const std::vector<int64_t>& location);
- /// <summary>
- /// Set all strings at once in a string tensor
- /// </summary>
- /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
- /// <param name="s_len">[in] Count of strings in s (Must match the size of this value's tensor shape)</param>
- void FillStringTensor(const char* const* s, size_t s_len);
- /// <summary>
- /// Set a single string in a string tensor
- /// </summary>
- /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
- /// <param name="index">[in] Index of the string in the tensor to set</param>
- void FillStringTensorElement(const char* s, size_t index);
- /// <summary>
- /// Allocate if necessary and obtain a pointer to a UTF-8
- /// encoded string element buffer indexed by the flat element index,
- /// of the specified length.
- ///
- /// This API is for advanced usage. It avoids a need to construct
- /// an auxiliary array of string pointers, and allows to write data directly
- /// (do not zero terminate).
- /// </summary>
- /// <param name="index">flat element index of the string element</param>
- /// <param name="buffer_length">requested length of the element buffer</param>
- /// <returns>a pointer to a writable buffer</returns>
- char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
- #if !defined(DISABLE_SPARSE_TENSORS)
- /// <summary>
- /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
- /// allocated buffers lifespan must eclipse that of the OrtValue.
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
- /// </summary>
- /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
- /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
- void UseCooIndices(int64_t* indices_data, size_t indices_num);
- /// <summary>
- /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
- /// allocated buffers lifespan must eclipse that of the OrtValue.
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
- /// </summary>
- /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
- /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
- /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
- /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
- void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
- /// <summary>
- /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
- /// allocated buffers lifespan must eclipse that of the OrtValue.
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
- /// </summary>
- /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
- /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
- void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
- /// <summary>
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
- /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
- /// on a different device than the allocator, an X-device copy will be performed if possible.
- /// </summary>
- /// <param name="data_mem_info">specified buffer memory description</param>
- /// <param name="values_param">values buffer information.</param>
- /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
- /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
- void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
- const int64_t* indices_data, size_t indices_num);
- /// <summary>
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
- /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
- /// on a different device than the allocator, an X-device copy will be performed if possible.
- /// </summary>
- /// <param name="data_mem_info">specified buffer memory description</param>
- /// <param name="values">values buffer information</param>
- /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
- /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
- /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
- /// <param name="outer_indices_num">number of csr outer indices or 0</param>
- void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
- const OrtSparseValuesParam& values,
- const int64_t* inner_indices_data, size_t inner_indices_num,
- const int64_t* outer_indices_data, size_t outer_indices_num);
- /// <summary>
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
- /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
- /// on a different device than the allocator, an X-device copy will be performed if possible.
- /// </summary>
- /// <param name="data_mem_info">specified buffer memory description</param>
- /// <param name="values">values buffer information</param>
- /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
- /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
- void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
- const OrtSparseValuesParam& values,
- const Shape& indices_shape,
- const int32_t* indices_data);
- #endif
- };
- } // namespace detail
- using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;  ///< Read-only API (ConstValueImpl) over a detail::Unowned const OrtValue
- using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;  ///< Read/write API (ValueImpl) over a detail::Unowned OrtValue
- /** \brief Wrapper around ::OrtValue
- * Provides static factory functions for tensor, map, sequence, opaque and sparse tensor values.
- */
- struct Value : detail::ValueImpl<OrtValue> {
- using Base = detail::ValueImpl<OrtValue>;
- using OrtSparseValuesParam = detail::OrtSparseValuesParam;
- using Shape = detail::Shape;
- explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
- explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
- Value(Value&&) = default;
- Value& operator=(Value&&) = default;
- ConstValue GetConst() const { return ConstValue{this->p_}; }  ///< Returns a ConstValue constructed from this object's pointer
- UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }  ///< Returns an UnownedValue constructed from this object's pointer
- /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
- * \tparam T The numeric datatype. This API is not suitable for strings.
- * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
- * \param p_data Pointer to the data buffer.
- * \param p_data_element_count The number of elements in the data buffer.
- * \param shape Pointer to the tensor shape dimensions.
- * \param shape_len The number of tensor shape dimensions.
- */
- template <typename T>
- static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
- /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
- *
- * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
- * \param p_data Pointer to the data buffer.
- * \param p_data_byte_count The number of bytes in the data buffer.
- * \param shape Pointer to the tensor shape dimensions.
- * \param shape_len The number of tensor shape dimensions.
- * \param type The data type.
- */
- static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
- ONNXTensorElementDataType type);
- /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
- * This overload will allocate the buffer for the tensor according to the supplied shape and data type.
- * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
- * The input data would need to be copied into the allocated buffer.
- * This API is not suitable for strings.
- *
- * \tparam T The numeric datatype. This API is not suitable for strings.
- * \param allocator The allocator to use.
- * \param shape Pointer to the tensor shape dimensions.
- * \param shape_len The number of tensor shape dimensions.
- */
- template <typename T>
- static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
- /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
- * Wraps OrtApi::CreateTensorAsOrtValue.
- * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
- * The input data would need to be copied into the allocated buffer.
- * This API is not suitable for strings.
- *
- * \param allocator The allocator to use.
- * \param shape Pointer to the tensor shape dimensions.
- * \param shape_len The number of tensor shape dimensions.
- * \param type The data type.
- */
- static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
- /** \brief Creates an OrtValue with a Map Onnx type representation.
- * The API would ref-count the supplied OrtValues and they will be released
- * when the returned OrtValue is released. The caller may release keys and values after the call
- * returns.
- *
- * \param keys an OrtValue containing a tensor with primitive data type keys.
- * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
- */
- static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue
- /** \brief Creates an OrtValue with a Sequence Onnx type representation.
- * The API would ref-count the supplied OrtValues and they will be released
- * when the returned OrtValue is released. The caller may release the values after the call
- * returns.
- *
- * \param values a vector of OrtValues that must have the same Onnx value type.
- */
- static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
- /** \brief Creates an OrtValue wrapping an Opaque type.
- * This is used for experimental support of non-tensor types.
- *
- * \tparam T - the type of the value.
- * \param domain - zero terminated utf-8 string. Domain of the type.
- * \param type_name - zero terminated utf-8 string. Name of the type.
- * \param value - the value to be wrapped.
- */
- template <typename T>
- static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
- #if !defined(DISABLE_SPARSE_TENSORS)
- /// <summary>
- /// This is a simple forwarding method to the other overload that helps deducing
- /// data type enum value from the type of the buffer.
- /// </summary>
- /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
- /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
- /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
- /// <param name="dense_shape">the would-be dense shape of the tensor</param>
- /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
- /// <returns>Ort::Value instance containing SparseTensor</returns>
- template <typename T>
- static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
- const Shape& values_shape);
- /// <summary>
- /// Creates an OrtValue instance containing SparseTensor. This constructs
- /// a sparse tensor that makes use of user allocated buffers. It does not make copies
- /// of the user provided data and does not modify it. The lifespan of user provided buffers should
- /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
- /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
- /// to supply a sparse format specific indices.
- /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
- /// can be properly copied into the allocated buffer.
- /// </summary>
- /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
- /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
- /// <param name="dense_shape">the would-be dense shape of the tensor</param>
- /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
- /// <param name="type">data type</param>
- /// <returns>Ort::Value instance containing SparseTensor</returns>
- static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
- const Shape& values_shape, ONNXTensorElementDataType type);
- /// <summary>
- /// This is a simple forwarding method to the below CreateSparseTensor.
- /// This helps to specify data type enum in terms of C++ data type.
- /// Call it as CreateSparseTensor<T>(allocator, dense_shape).
- /// </summary>
- /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
- /// <param name="allocator">allocator to use</param>
- /// <param name="dense_shape">the would-be dense shape of the tensor</param>
- /// <returns>Ort::Value</returns>
- template <typename T>
- static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
- /// <summary>
- /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
- /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
- /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
- /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
- /// strings.
- /// </summary>
- /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
- /// <param name="dense_shape">the would-be dense shape of the tensor</param>
- /// <param name="type">data type</param>
- /// <returns>an instance of Ort::Value</returns>
- static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
- #endif // !defined(DISABLE_SPARSE_TENSORS)
- };
- /// <summary>
- /// Represents a native memory allocation coming from one of the
- /// OrtAllocators registered with OnnxRuntime.
- /// Use it to wrap an allocation made by an allocator
- /// so it can be automatically released when no longer needed.
- /// </summary>
- struct MemoryAllocation {
- MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
- ~MemoryAllocation();
- MemoryAllocation(const MemoryAllocation&) = delete;  // not copyable
- MemoryAllocation& operator=(const MemoryAllocation&) = delete;  // not copy-assignable
- MemoryAllocation(MemoryAllocation&&) noexcept;
- MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
- void* get() { return p_; }  ///< Returns the stored pointer
- size_t size() const { return size_; }  ///< Returns the stored size
- private:
- OrtAllocator* allocator_;
- void* p_;
- size_t size_;
- };
- namespace detail {
- template <typename T>
- struct AllocatorImpl : Base<T> {
- using B = Base<T>;
- using B::B;  // inherit the constructors of Base<T>
- void* Alloc(size_t size);
- MemoryAllocation GetAllocation(size_t size);  ///< Returns the result as a MemoryAllocation wrapper object
- void Free(void* p);
- ConstMemoryInfo GetInfo() const;
- };
- } // namespace detail
- /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
- * The allocator API comes from detail::AllocatorImpl over a detail::Unowned OrtAllocator.
- */
- struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
- explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
- AllocatorWithDefaultOptions();
- };
- /** \brief Wrapper around ::OrtAllocator
- * Constructed from a Session and an OrtMemoryInfo; the allocator API comes from detail::AllocatorImpl.
- */
- struct Allocator : detail::AllocatorImpl<OrtAllocator> {
- explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
- Allocator(const Session& session, const OrtMemoryInfo*);
- };
- using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;  ///< Allocator API over a detail::Unowned OrtAllocator
- namespace detail {
- namespace binding_utils {
- // Non-template helpers, declared outside the class templates below
- std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
- std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
- } // namespace binding_utils
- template <typename T>
- struct ConstIoBindingImpl : Base<T> {
- using B = Base<T>;
- using B::B;  // inherit the constructors of Base<T>
- std::vector<std::string> GetOutputNames() const;
- std::vector<std::string> GetOutputNames(OrtAllocator*) const;
- std::vector<Value> GetOutputValues() const;
- std::vector<Value> GetOutputValues(OrtAllocator*) const;
- };
- template <typename T>
- struct IoBindingImpl : ConstIoBindingImpl<T> {
- using B = ConstIoBindingImpl<T>;
- using B::B;  // inherit the constructors of ConstIoBindingImpl<T>
- void BindInput(const char* name, const Value&);
- void BindOutput(const char* name, const Value&);
- void BindOutput(const char* name, const OrtMemoryInfo*);
- void ClearBoundInputs();
- void ClearBoundOutputs();
- void SynchronizeInputs();
- void SynchronizeOutputs();
- };
- } // namespace detail
- using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;  ///< Read-only API (ConstIoBindingImpl) over a detail::Unowned const OrtIoBinding
- using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;  ///< Full API (IoBindingImpl) over a detail::Unowned OrtIoBinding
- /** \brief Wrapper around ::OrtIoBinding
- * Constructed from a Session; the binding API comes from detail::IoBindingImpl.
- */
- struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
- explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
- explicit IoBinding(Session& session);
- ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }  ///< Returns a ConstIoBinding constructed from this object's pointer
- UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }  ///< Returns an UnownedIoBinding constructed from this object's pointer
- };
- /*! \struct Ort::ArenaCfg
- * \brief A structure that represents the configuration of an arena based allocator
- * \details Please see docs/C_API.md for details
- */
- struct ArenaCfg : detail::Base<OrtArenaCfg> {
- explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
- /**
- * Wraps OrtApi::CreateArenaCfg
- * \param max_mem - use 0 to allow ORT to choose the default
- * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
- * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
- * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
- * See docs/C_API.md for details on what these parameters mean and how to choose their values
- */
- ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
- };
- //
- // Custom OPs (only needed to implement custom OPs)
- //
- /// <summary>
- /// This struct provides lifetime management for a custom op attribute
- /// </summary>
- struct OpAttr : detail::Base<OrtOpAttr> {
- OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
- };
- /**
- * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
- * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
- *
- * \param logger The Ort::Logger instance to use. Must be a value or reference. Expanded twice and unparenthesized in the macro body.
- * \param message_severity The logging severity level of the message. Expanded twice and unparenthesized; pass a simple, side-effect-free expression.
- * \param message A null-terminated UTF-8 message to log.
- */
- #define ORT_CXX_LOG(logger, message_severity, message) \
- do { \
- if (message_severity >= logger.GetLoggingSeverityLevel()) { \
- Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
- static_cast<const char*>(__FUNCTION__), message)); \
- } \
- } while (false)
- /**
- * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
- * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
- *
- * \param logger The Ort::Logger instance to use. Must be a value or reference. Expanded twice and unparenthesized in the macro body.
- * \param message_severity The logging severity level of the message. Expanded twice and unparenthesized; pass a simple, side-effect-free expression.
- * \param message A null-terminated UTF-8 message to log.
- */
- #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
- do { \
- if (message_severity >= logger.GetLoggingSeverityLevel()) { \
- static_cast<void>(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \
- static_cast<const char*>(__FUNCTION__), message)); \
- } \
- } while (false)
- /**
- * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
- * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
- * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
- *
- * \param logger The Ort::Logger instance to use. Must be a value or reference. Expanded twice and unparenthesized in the macro body.
- * \param message_severity The logging severity level of the message. Expanded twice and unparenthesized; pass a simple, side-effect-free expression.
- * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. Passed as the first variadic argument.
- * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
- * \param ... Zero or more variadic arguments referenced by the format string.
- */
- #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \
- do { \
- if (message_severity >= logger.GetLoggingSeverityLevel()) { \
- Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
- static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
- } \
- } while (false)
/**
 * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
 * are silently ignored.
 * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
 *
 * The message is only forwarded to the logger when message_severity is greater than or equal to the value
 * returned by logger.GetLoggingSeverityLevel(). The logger and message_severity parameters are parenthesized
 * in the expansion, so arbitrary expressions may be passed; message_severity is expanded twice and should
 * have no side effects.
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference.
 * \param message_severity The logging severity level of the message.
 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
 *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
 * \param ... Zero or more variadic arguments referenced by the format string.
 */
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \
  do { \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
      /* The returned status is deliberately discarded. */ \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__, \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    } \
  } while (false)
- /// <summary>
- /// This class represents an ONNX Runtime logger that can be used to log information with an
- /// associated severity level and source code location (file path, line number, function name).
- ///
- /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger().
- /// Instances of Ort::Logger are the size of two pointers and can be passed by value.
- ///
- /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
- /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
- /// </summary>
- struct Logger {
- /**
- * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
- */
- Logger() = default;
- /**
- * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
- */
- explicit Logger(std::nullptr_t) {}
- /**
- * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
- * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
- *
- * \param logger The ::OrtLogger to wrap.
- */
- explicit Logger(const OrtLogger* logger);
- ~Logger() = default;
- Logger(const Logger&) = default;
- Logger& operator=(const Logger&) = default;
- Logger(Logger&& v) noexcept = default;
- Logger& operator=(Logger&& v) noexcept = default;
- /**
- * Returns the logger's current severity level from the cached member.
- *
- * \return The current ::OrtLoggingLevel.
- */
- OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
- /**
- * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
- * macros to properly set the source code location and to use the cached severity level to potentially bypass
- * calls to the underlying C API.
- *
- * \param log_severity_level The message's logging severity level.
- * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
- * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
- * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
- * \param message The message to log.
- * \return A Ort::Status value to indicate error or success.
- */
- Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
- const char* func_name, const char* message) const noexcept;
- /**
- * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
- * macros to properly set the source code location and to use the cached severity level to potentially bypass
- * calls to the underlying C API. Returns an error status if a formatting error occurs.
- *
- * \param log_severity_level The message's logging severity level.
- * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
- * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
- * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
- * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
- * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
- * \param args Zero or more variadic arguments referenced by the format string.
- * \return A Ort::Status value to indicate error or success.
- */
- template <typename... Args>
- Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
- const char* func_name, const char* format, Args&&... args) const noexcept;
- private:
- const OrtLogger* logger_{};
- OrtLoggingLevel cached_severity_level_{};
- };
- /// <summary>
- /// This class wraps a raw pointer OrtKernelContext* that is being passed
- /// to the custom kernel Compute() method. Use it to safely access context
- /// attributes, input and output parameters with exception safety guarantees.
- /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
- /// </summary>
- struct KernelContext {
- explicit KernelContext(OrtKernelContext* context);
- size_t GetInputCount() const;
- size_t GetOutputCount() const;
- // If input is optional and is not present, the method returns en empty ConstValue
- // which can be compared to nullptr.
- ConstValue GetInput(size_t index) const;
- // If outout is optional and is not present, the method returns en empty UnownedValue
- // which can be compared to nullptr.
- UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
- UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
- void* GetGPUComputeStream() const;
- Logger GetLogger() const;
- OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
- OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
- void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
- private:
- OrtKernelContext* ctx_;
- };
- struct KernelInfo;
- namespace detail {
- namespace attr_utils {
- void GetAttr(const OrtKernelInfo* p, const char* name, float&);
- void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
- void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
- void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
- void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
- } // namespace attr_utils
- template <typename T>
- struct KernelInfoImpl : Base<T> {
- using B = Base<T>;
- using B::B;
- KernelInfo Copy() const;
- template <typename R> // R is only implemented for float, int64_t, and string
- R GetAttribute(const char* name) const {
- R val;
- attr_utils::GetAttr(this->p_, name, val);
- return val;
- }
- template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
- std::vector<R> GetAttributes(const char* name) const {
- std::vector<R> result;
- attr_utils::GetAttrs(this->p_, name, result);
- return result;
- }
- Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
- size_t GetInputCount() const;
- size_t GetOutputCount() const;
- std::string GetInputName(size_t index) const;
- std::string GetOutputName(size_t index) const;
- TypeInfo GetInputTypeInfo(size_t index) const;
- TypeInfo GetOutputTypeInfo(size_t index) const;
- ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;
- std::string GetNodeName() const;
- Logger GetLogger() const;
- };
- } // namespace detail
- using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
- /// <summary>
- /// This struct owns the OrtKernInfo* pointer when a copy is made.
- /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
- /// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance
- /// so it does not destroy the pointer the kernel does not own.
- /// </summary>
- struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
- explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
- explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
- ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
- };
- /// <summary>
- /// Create and own custom defined operation.
- /// </summary>
- struct Op : detail::Base<OrtOp> {
- explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
- explicit Op(OrtOp*); ///< Take ownership of the OrtOp
- static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
- int version, const char** type_constraint_names,
- const ONNXTensorElementDataType* type_constraint_values,
- size_t type_constraint_count,
- const OpAttr* attr_values,
- size_t attr_count,
- size_t input_count, size_t output_count);
- void Invoke(const OrtKernelContext* context,
- const Value* input_values,
- size_t input_count,
- Value* output_values,
- size_t output_count);
- // For easier refactoring
- void Invoke(const OrtKernelContext* context,
- const OrtValue* const* input_values,
- size_t input_count,
- OrtValue* const* output_values,
- size_t output_count);
- };
- /// <summary>
- /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
- /// </summary>
- struct ShapeInferContext {
- struct SymbolicInteger {
- SymbolicInteger(int64_t i) : i_(i), is_int_(true){};
- SymbolicInteger(const char* s) : s_(s), is_int_(false){};
- SymbolicInteger(const SymbolicInteger&) = default;
- SymbolicInteger(SymbolicInteger&&) = default;
- SymbolicInteger& operator=(const SymbolicInteger&) = default;
- SymbolicInteger& operator=(SymbolicInteger&&) = default;
- bool operator==(const SymbolicInteger& dim) const {
- if (is_int_ == dim.is_int_) {
- if (is_int_) {
- return i_ == dim.i_;
- } else {
- return std::string{s_} == std::string{dim.s_};
- }
- }
- return false;
- }
- bool IsInt() const { return is_int_; }
- int64_t AsInt() const { return i_; }
- const char* AsSym() const { return s_; }
- static constexpr int INVALID_INT_DIM = -2;
- private:
- union {
- int64_t i_;
- const char* s_;
- };
- bool is_int_;
- };
- using Shape = std::vector<SymbolicInteger>;
- ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
- const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
- size_t GetInputCount() const { return input_shapes_.size(); }
- Status SetOutputShape(size_t indice, const Shape& shape);
- int64_t GetAttrInt(const char* attr_name);
- using Ints = std::vector<int64_t>;
- Ints GetAttrInts(const char* attr_name);
- float GetAttrFloat(const char* attr_name);
- using Floats = std::vector<float>;
- Floats GetAttrFloats(const char* attr_name);
- std::string GetAttrString(const char* attr_name);
- using Strings = std::vector<std::string>;
- Strings GetAttrStrings(const char* attr_name);
- private:
- const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
- const OrtApi* ort_api_;
- OrtShapeInferContext* ctx_;
- std::vector<Shape> input_shapes_;
- };
- using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
// Value 2^31 - 1, used as the default for CustomOpBase::end_ver_. The whole replacement list is
// parenthesized so the macro behaves as a single operand when used inside a larger expression.
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
- template <typename TOp, typename TKernel, bool WithStatus = false>
- struct CustomOpBase : OrtCustomOp {
- CustomOpBase() {
- OrtCustomOp::version = ORT_API_VERSION;
- OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
- OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
- OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
- OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
- OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
- OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
- OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
- #if defined(_MSC_VER) && !defined(__clang__)
- #pragma warning(push)
- #pragma warning(disable : 26409)
- #endif
- OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
- #if defined(_MSC_VER) && !defined(__clang__)
- #pragma warning(pop)
- #endif
- OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
- OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
- OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
- OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
- OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
- OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
- #ifdef __cpp_if_constexpr
- if constexpr (WithStatus) {
- #else
- if (WithStatus) {
- #endif
- OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
- return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
- };
- OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
- return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
- };
- } else {
- OrtCustomOp::CreateKernelV2 = nullptr;
- OrtCustomOp::KernelComputeV2 = nullptr;
- OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
- OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
- static_cast<TKernel*>(op_kernel)->Compute(context);
- };
- }
- SetShapeInferFn<TOp>(0);
- OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
- return static_cast<const TOp*>(this_)->start_ver_;
- };
- OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
- return static_cast<const TOp*>(this_)->end_ver_;
- };
- OrtCustomOp::GetMayInplace = nullptr;
- OrtCustomOp::ReleaseMayInplace = nullptr;
- OrtCustomOp::GetAliasMap = nullptr;
- OrtCustomOp::ReleaseAliasMap = nullptr;
- }
- // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
- const char* GetExecutionProviderType() const { return nullptr; }
- // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
- // (inputs and outputs are required by default)
- OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
- return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
- }
- OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
- return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
- }
- // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
- OrtMemType GetInputMemoryType(size_t /*index*/) const {
- return OrtMemTypeDefault;
- }
- // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
- // should expect at least 1 argument.
- int GetVariadicInputMinArity() const {
- return 1;
- }
- // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments
- // to a variadic input should be of the same type.
- bool GetVariadicInputHomogeneity() const {
- return true;
- }
- // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
- // should produce at least 1 output value.
- int GetVariadicOutputMinArity() const {
- return 1;
- }
- // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values
- // produced by a variadic output should be of the same type.
- bool GetVariadicOutputHomogeneity() const {
- return true;
- }
- // Declare list of session config entries used by this Custom Op.
- // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
- // This default implementation returns an empty vector of config entries.
- std::vector<std::string> GetSessionConfigKeys() const {
- return std::vector<std::string>{};
- }
- template <typename C>
- decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
- OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
- ShapeInferContext ctx(&GetApi(), ort_ctx);
- return C::InferOutputShape(ctx);
- };
- return {};
- }
- template <typename C>
- void SetShapeInferFn(...) {
- OrtCustomOp::InferOutputShapeFn = {};
- }
- protected:
- // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
- void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
- int start_ver_ = 1;
- int end_ver_ = MAX_CUSTOM_OP_END_VER;
- };
- } // namespace Ort
- #include "onnxruntime_cxx_inline.h"
|