// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary
// The header has APIs that save custom op authors the trouble of defining schemas,
// which are inferred from the function signature, as long as the argument list uses types supported here.
// An input could be:
// 1. A tensor of ONNX data types.
// 2. A span of ONNX data types.
// 3. A scalar of ONNX data types.
// An input is optional if declared as std::optional<...>.
// An output must be a tensor of ONNX data types.
// Further, the header has utilities for registering a simple custom struct, which may hold resources, as a custom op.
// For concrete examples, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
// Note - all APIs in this header are part of the ABI.
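//
// For instance, all of the following kernel signatures can be inferred automatically.
// This is a minimal sketch; the function names are illustrative, not part of this header:
//
//   void KernelOne(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out);    // tensor in/out
//   void KernelTwo(const Ort::Custom::Span<int64_t>& in, Ort::Custom::Tensor<int64_t>& out);  // span input (CPU EP only)
//   void KernelThree(float scalar_in, Ort::Custom::Tensor<float>& out);                       // scalar input (CPU EP only)
//   void KernelFour(std::optional<const Ort::Custom::Tensor<float>*> maybe_in,
//                   Ort::Custom::Tensor<float>& out);                                         // optional input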
#pragma once
#include "onnxruntime_cxx_api.h"

#include <optional>
#include <numeric>
#include <functional>
#include <unordered_set>

namespace Ort {
namespace Custom {

class ArgBase {
 public:
  ArgBase(OrtKernelContext* ctx,
          size_t indice,
          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase(){};

 protected:
  struct KernelContext ctx_;
  size_t indice_;
  bool is_input_;
};

using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;
class TensorBase : public ArgBase {
 public:
  TensorBase(OrtKernelContext* ctx,
             size_t indice,
             bool is_input) : ArgBase(ctx, indice, is_input) {}

  operator bool() const {
    return shape_.has_value();
  }

  const std::vector<int64_t>& Shape() const {
    if (!shape_.has_value()) {
      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return shape_.value();
  }

  ONNXTensorElementDataType Type() const {
    return type_;
  }

  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    } else {
      return 0;
    }
  }

  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

  bool IsCpuTensor() const {
    return strcmp("Cpu", mem_type_) == 0;
  }

  virtual const void* DataRaw() const = 0;
  virtual size_t SizeInBytes() const = 0;

 protected:
  std::optional<std::vector<int64_t>> shape_;
  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  const char* mem_type_ = "Cpu";
};
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
};
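// A kernel may take a 1-D input as a Span (CPU execution provider only). A minimal
// sketch; the function name is illustrative, not part of this header:
//
//   void Sum(const Ort::Custom::Span<float>& values, Ort::Custom::Tensor<float>& out) {
//     float* dst = out.Allocate({1});
//     dst[0] = 0.f;
//     for (size_t i = 0; i < values.size(); ++i) dst[0] += values[i];
//   }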
template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;
  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      const_value_ = ctx_.GetInput(indice);
      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
    }
  }
  const TT* Data() const {
    return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
  }
  TT* Allocate(const std::vector<int64_t>& shape) {
    shape_ = shape;
    if (!data_) {
      data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
    }
    return data_;
  }
  static TT GetT() { return (TT)0; }
  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
    return span_;
  }
  const T& AsScalar() {
    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }
  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }
  size_t SizeInBytes() const override {
    return sizeof(TT) * static_cast<size_t>(NumberOfElement());
  }

 private:
  ConstValue const_value_;  // for input
  TT* data_{};              // for output
  Span<T> span_;
};
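// Typical element-wise usage of Tensor<T>, a minimal sketch; the function name is illustrative:
//
//   void Negate(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//     const float* src = in.Data();
//     float* dst = out.Allocate(in.Shape());  // output allocated with the input shape
//     for (int64_t i = 0; i < in.NumberOfElement(); ++i) dst[i] = -src[i];
//   }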
template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      // note - there will be a copy ...
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;
          }
        }
      }
    }
  }
  const strings& Data() const {
    return input_strings_;
  }
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }
  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for Tensor of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
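// String outputs are written via SetStringOutput rather than Allocate. A minimal sketch;
// the function name is illustrative:
//
//   void Upper(const Ort::Custom::Tensor<std::string>& in, Ort::Custom::Tensor<std::string>& out) {
//     std::vector<std::string> result;
//     for (const auto& s : in.Data()) {
//       std::string u = s;
//       for (auto& c : u) c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
//       result.push_back(u);
//     }
//     out.SetStringOutput(result, in.Shape());
//   }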
template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }
  const string_views& Data() const {
    return input_string_views_;
  }
  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }
  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be a copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for Tensor of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                           // for input
  std::vector<std::string_view> input_string_views_;  // for input
};
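// Tensor<std::string_view> is primarily used for inputs and avoids per-element std::string
// copies: the raw character buffer is fetched once and the views index into it. A minimal
// sketch; the function name is illustrative:
//
//   void CountNonEmpty(const Ort::Custom::Tensor<std::string_view>& in, Ort::Custom::Tensor<int64_t>& out) {
//     int64_t* dst = out.Allocate({1});
//     dst[0] = 0;
//     for (auto v : in.Data()) dst[0] += v.empty() ? 0 : 1;
//   }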
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;

struct TensorArray : public ArgBase {
  TensorArray(OrtKernelContext* ctx,
              size_t start_indice,
              bool is_input) : ArgBase(ctx,
                                       start_indice,
                                       is_input) {
    if (is_input) {
      auto input_count = ctx_.GetInputCount();
      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
        auto const_value = ctx_.GetInput(ith_input);
        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
        auto type = type_shape_info.GetElementType();
        TensorPtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
            break;
          default:
            ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
            break;
        }
        tensors_.emplace_back(tensor.release());
      }  // for
    }
  }

  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    // ith_output is the index of the output relative to the tensor array
    // indice_ + ith_output is the index relative to the context
    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    auto raw_output = tensor.get()->Allocate(shape);
    tensors_.emplace_back(tensor.release());
    return raw_output;
  }

  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    // ith_output is the index of the output relative to the tensor array
    // indice_ + ith_output is the index relative to the context
    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(tensor.release());
    return output;
  }

  size_t Size() const {
    return tensors_.size();
  }

  const TensorPtr& operator[](size_t ith_input) const {
    // ith_input is the index of the tensor relative to the tensor array
    return tensors_.at(ith_input);
  }

 private:
  TensorPtrs tensors_;
};

using Variadic = TensorArray;
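// A variadic kernel consumes or produces a TensorArray. A minimal sketch; the function
// name is illustrative:
//
//   void Concat(const Ort::Custom::Variadic& inputs, Ort::Custom::Tensor<int64_t>& out) {
//     std::vector<int64_t> values;
//     for (size_t i = 0; i < inputs.Size(); ++i) {
//       const auto& tensor = inputs[i];
//       const auto* data = reinterpret_cast<const int64_t*>(tensor->DataRaw());
//       for (int64_t j = 0; j < tensor->NumberOfElement(); ++j) values.push_back(data[j]);
//     }
//     int64_t* dst = out.Allocate({static_cast<int64_t>(values.size())});
//     for (size_t i = 0; i < values.size(); ++i) dst[i] = values[i];
//   }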
/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast an OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, the two sub-structs, can be released as OrtLiteCustomOp, since all members are kept in OrtLiteCustomOp;
   hence memory is still recycled properly.
Further, OrtCustomOp is a C struct bearing no v-table, so derived structs must by design introduce no virtual functions, to maintain cast safety.
*/
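// In practice this means owning the op through an OrtLiteCustomOp-typed smart pointer,
// e.g. a sketch consistent with the registration examples further below:
//
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op_ptr{
//       Ort::Custom::CreateLiteCustomOp("MyOp", "CPUExecutionProvider", MyComputeFn)};
//   // op_ptr must outlive every session that uses the op.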
struct OrtLiteCustomOp : public OrtCustomOp {
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;

  // CreateTuple
  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#ifdef ORT_CUDA_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }

#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)
  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(Ort::Float16_t)
  CREATE_TUPLE(Ort::BFloat16_t)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)
  CREATE_TUPLE(Ort::Float8E4M3FN_t)
  CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
  CREATE_TUPLE(Ort::Float8E5M2_t)
  CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
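  // Optional inputs map to std::optional of the corresponding pointer type; absent
  // arguments arrive default-constructed (i.e. std::nullopt). A minimal sketch; the
  // function name is illustrative:
  //
  //   void AddOrCopy(const Ort::Custom::Tensor<float>& a,
  //                  std::optional<const Ort::Custom::Tensor<float>*> b,
  //                  Ort::Custom::Tensor<float>& out) {
  //     const float* lhs = a.Data();
  //     float* dst = out.Allocate(a.Shape());
  //     for (int64_t i = 0; i < a.NumberOfElement(); ++i) {
  //       dst[i] = lhs[i] + (b.has_value() ? (*b)->Data()[i] : 0.f);
  //     }
  //   }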
  // ParseArgs ...
  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#ifdef ORT_CUDA_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)

#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)
  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
  PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
  PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
  PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
  PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
  PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
  OrtLiteCustomOp(const char* op_name,
                  const char* execution_provider,
                  ShapeInferFn shape_infer_fn,
                  int start_ver = 1,
                  int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
                                                         execution_provider_(execution_provider),
                                                         shape_infer_fn_(shape_infer_fn),
                                                         start_ver_(start_ver),
                                                         end_ver_(end_ver) {
    OrtCustomOp::version = ORT_API_VERSION;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->execution_provider_.c_str(); };
    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };

    OrtCustomOp::CreateKernelV2 = {};
    OrtCustomOp::KernelComputeV2 = {};
    OrtCustomOp::KernelCompute = {};
    OrtCustomOp::InferOutputShapeFn = {};

    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->start_ver_;
    };

    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->end_ver_;
    };

    OrtCustomOp::GetMayInplace = {};
    OrtCustomOp::ReleaseMayInplace = {};
    OrtCustomOp::GetAliasMap = {};
    OrtCustomOp::ReleaseAliasMap = {};
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  ShapeInferFn shape_infer_fn_ = {};

  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;

  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
};
//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
// This struct implements function-as-op.
// E.g. a function might be defined as:
//   void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
//   v2_domain.Add(fil_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using ComputeFnReturnStatus = Status (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    ComputeFn compute_fn_{};
    ComputeFnReturnStatus compute_fn_return_status_{};
    std::string ep_{};
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_ = reinterpret_cast<void*>(compute_fn);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFnReturnStatus compute_fn_return_status,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->compute_fn_return_status_(t_args...);
        return status.release();
      }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }
};  // struct OrtLiteCustomFunc
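// A compute function may also report failures by returning Ort::Status instead of void,
// which routes through KernelComputeV2 above. A minimal sketch; the function name is
// illustrative:
//
//   Ort::Status SafeDivide(const Ort::Custom::Tensor<float>& a,
//                          const Ort::Custom::Tensor<float>& b,
//                          Ort::Custom::Tensor<float>& out) {
//     if (a.NumberOfElement() != b.NumberOfElement()) {
//       return Ort::Status("element count mismatch", OrtErrorCode::ORT_INVALID_ARGUMENT);
//     }
//     float* dst = out.Allocate(a.Shape());
//     for (int64_t i = 0; i < a.NumberOfElement(); ++i) dst[i] = a.Data()[i] / b.Data()[i];
//     return Ort::Status{nullptr};  // success
//   }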
/////////////////////////// OrtLiteCustomStruct ///////////////////////////
// This struct implements struct-as-op.
// E.g. a struct might be defined as:
//   struct Merge {
//     Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
//     void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
//                  std::string_view string_in,
//                  Ort::Custom::Tensor<std::string>* strings_out) {...}
//     bool reverse_ = false;
//   };
// It could be registered this way:
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
//   v2_domain.Add(mrg_op_ptr.get());
//   session_options.Add(v2_domain);
// For the complete example, please search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };

  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    SetShapeInfer<CustomOp>(0);
  }

  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }

  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) {
        Status status = kernel->custom_op_->Compute(t_args...);
        return status.release();
      }, t);
    };
  }

  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};  // struct OrtLiteCustomStruct
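// If the custom struct also defines a static InferOutputShape, SetShapeInfer picks it up
// via SFINAE, otherwise InferOutputShapeFn stays empty. A minimal sketch under that
// assumption; the struct name is illustrative:
//
//   struct Identity {
//     Identity(const OrtApi*, const OrtKernelInfo*) {}
//     void Compute(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//       float* dst = out.Allocate(in.Shape());
//       for (int64_t i = 0; i < in.NumberOfElement(); ++i) dst[i] = in.Data()[i];
//     }
//     static Ort::Status InferOutputShape(Ort::ShapeInferContext& ctx) {
//       return ctx.SetOutputShape(0, ctx.GetInputShape(0));  // output 0 matches input 0
//     }
//   };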
/////////////////////////// CreateLiteCustomOp ////////////////////////////

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    Status (*custom_compute_fn_v2)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}
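// End-to-end registration, a minimal sketch combining the pieces above; the op and
// function names are illustrative:
//
//   static void Double(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//     float* dst = out.Allocate(in.Shape());
//     for (int64_t i = 0; i < in.NumberOfElement(); ++i) dst[i] = 2.f * in.Data()[i];
//   }
//
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> double_op{
//       Ort::Custom::CreateLiteCustomOp("Double", "CPUExecutionProvider", Double)};
//   Ort::CustomOpDomain domain{"v2"};
//   domain.Add(double_op.get());
//   Ort::SessionOptions session_options;
//   session_options.Add(domain);
//   // keep double_op and domain alive for the lifetime of any session that uses them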
}  // namespace Custom
}  // namespace Ort