12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119 |
- // Copyright (c) Microsoft Corporation. All rights reserved.
- // Licensed under the MIT License.
- // Summary
- // The header has APIs to save custom op authors the trouble of defining schemas,
- // which will be inferred by functions' signature, as long as their argument list has types supported here.
- // Input could be:
- // 1. Tensor of onnx data types.
- // 2. Span of onnx data types.
- // 3. Scalar of onnx data types.
- // A input could be optional if indicated as std::optional<...>.
- // For an output, it must be a tensor of onnx data types.
- // Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op.
- // For concrete examples, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
- // Note - all APIs in this header are ABI.
- #pragma once
- #include "onnxruntime_cxx_api.h"
- #include <optional>
- #include <numeric>
- #include <functional>
- #include <unordered_set>
- namespace Ort {
- namespace Custom {
// Base class for a kernel argument (input or output) bound to an OrtKernelContext.
class ArgBase {
 public:
  // ctx      - raw kernel context; wrapped (not owned) by ctx_.
  // indice   - position of this argument within the context's input or output list.
  // is_input - true for an input argument, false for an output.
  ArgBase(OrtKernelContext* ctx,
          size_t indice,
          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase(){};

 protected:
  struct KernelContext ctx_;  // C++ wrapper constructed over the raw OrtKernelContext
  size_t indice_;             // argument position relative to the kernel context
  bool is_input_;             // distinguishes input args from output args
};
// Owning pointer to a kernel argument, and an ordered list of them
// (used to keep per-invocation argument wrappers alive during Compute).
using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;
- class TensorBase : public ArgBase {
- public:
- TensorBase(OrtKernelContext* ctx,
- size_t indice,
- bool is_input) : ArgBase(ctx, indice, is_input) {}
- operator bool() const {
- return shape_.has_value();
- }
- const std::vector<int64_t>& Shape() const {
- if (!shape_.has_value()) {
- ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
- }
- return shape_.value();
- }
- ONNXTensorElementDataType Type() const {
- return type_;
- }
- int64_t NumberOfElement() const {
- if (shape_.has_value()) {
- return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
- } else {
- return 0;
- }
- }
- std::string Shape2Str() const {
- if (shape_.has_value()) {
- std::string shape_str;
- for (const auto& dim : *shape_) {
- shape_str.append(std::to_string(dim));
- shape_str.append(", ");
- }
- return shape_str;
- } else {
- return "empty";
- }
- }
- bool IsCpuTensor() const {
- return strcmp("Cpu", mem_type_) == 0;
- }
- virtual const void* DataRaw() const = 0;
- virtual size_t SizeInBytes() const = 0;
- protected:
- std::optional<std::vector<int64_t>> shape_;
- ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
- const char* mem_type_ = "Cpu";
- };
// Minimal non-owning view over a contiguous sequence of T;
// rebindable via Assign(), with unchecked element access.
template <typename T>
struct Span {
  const T* data_ = {};  // start of the viewed sequence (may be null when empty)
  size_t size_ = {};    // number of elements viewed

  // Rebind the view to the half-open range [data, data + size).
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }

  size_t size() const { return size_; }

  // Returns the element by value; no bounds checking is performed.
  T operator[](size_t indice) const { return data_[indice]; }

  const T* data() const { return data_; }
};
- template <typename T>
- class Tensor : public TensorBase {
- public:
- using TT = typename std::remove_reference<T>::type;
- Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
- if (is_input_) {
- if (indice >= ctx_.GetInputCount()) {
- ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
- }
- const_value_ = ctx_.GetInput(indice);
- auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
- shape_ = type_shape_info.GetShape();
- }
- }
- const TT* Data() const {
- return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
- }
- TT* Allocate(const std::vector<int64_t>& shape) {
- shape_ = shape;
- if (!data_) {
- shape_ = shape;
- data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
- }
- return data_;
- }
- static TT GetT() { return (TT)0; }
- const Span<T>& AsSpan() {
- if (!shape_.has_value() || shape_->size() != 1) {
- ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
- OrtErrorCode::ORT_RUNTIME_EXCEPTION);
- }
- span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
- return span_;
- }
- const T& AsScalar() {
- if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
- ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
- OrtErrorCode::ORT_RUNTIME_EXCEPTION);
- }
- return *Data();
- }
- const void* DataRaw() const override {
- return reinterpret_cast<const void*>(Data());
- }
- size_t SizeInBytes() const override {
- return sizeof(TT) * static_cast<size_t>(NumberOfElement());
- }
- private:
- ConstValue const_value_; // for input
- TT* data_{}; // for output
- Span<T> span_;
- };
// Specialization for string tensors. Input strings are copied out of the
// context into input_strings_; outputs are written through SetStringOutput().
template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      // note - there will be copy ...
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        // All strings are packed back-to-back in chars; the extra byte keeps
        // the last string NUL-terminated. offsets[i] is the start of string i.
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        // Walk backwards: terminating string i requires writing '\0' at the
        // start of string i+1, which must already have been copied out.
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;  // explicit exit — size_t index cannot go below zero
          }
        }
      }
    }
  }

  // All input strings, in tensor order.
  const strings& Data() const {
    return input_strings_;
  }

  // Raw pointer to the character data; only defined for a scalar string tensor.
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }

  // Byte length of the character data; only defined for a scalar string tensor.
  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }

  // Writes ss into this output slot with the given dims.
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }

  // Spans over string tensors are unsupported; always throws.
  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }

  // Single string value; throws unless the tensor holds exactly one string.
  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
// Specialization producing string_views instead of owned strings. The views
// point into chars_, a member buffer, so they stay valid as long as this
// Tensor instance is alive (one copy out of the context, no per-string copy).
template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        // Sentinel end offset so the loop can compute every string's length
        // as offsets[i + 1] - offsets[i], including the last one.
        offsets.push_back(num_chars);
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }

  // Views over all input strings, in tensor order; backed by chars_.
  const string_views& Data() const {
    return input_string_views_;
  }

  // Raw pointer to the character data; only defined for a scalar string tensor.
  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }

  // Byte length of the character data; only defined for a scalar string tensor.
  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }

  // Writes ss into this output slot with the given dims.
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    // note - there will be copy ...
    output.FillStringTensor(raw.data(), raw.size());
  }

  // Spans over string-view tensors are unsupported; always throws.
  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }

  // Single string view; throws unless the tensor holds exactly one string.
  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;                           // for input
  std::vector<std::string_view> input_string_views_;  // for input
};
// Owning pointer to a type-erased tensor, and an ordered collection of them.
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;
- struct TensorArray : public ArgBase {
- TensorArray(OrtKernelContext* ctx,
- size_t start_indice,
- bool is_input) : ArgBase(ctx,
- start_indice,
- is_input) {
- if (is_input) {
- auto input_count = ctx_.GetInputCount();
- for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
- auto const_value = ctx_.GetInput(start_indice);
- auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
- auto type = type_shape_info.GetElementType();
- TensorPtr tensor;
- switch (type) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
- tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
- tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
- tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
- tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
- tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
- tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
- tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
- tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
- tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
- tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
- tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
- break;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
- tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
- break;
- default:
- ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
- break;
- }
- tensors_.emplace_back(tensor.release());
- } // for
- }
- }
- template <typename T>
- T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
- // ith_output is the indice of output relative to the tensor array
- // indice_ + ith_output is the indice relative to context
- auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
- auto raw_output = tensor.get()->Allocate(shape);
- tensors_.emplace_back(tensor.release());
- return raw_output;
- }
- Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
- // ith_output is the indice of output relative to the tensor array
- // indice_ + ith_output is the indice relative to context
- auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
- Tensor<std::string>& output = *tensor;
- tensors_.emplace_back(tensor.release());
- return output;
- }
- size_t Size() const {
- return tensors_.size();
- }
- const TensorPtr& operator[](size_t ith_input) const {
- // ith_input is the indice of output relative to the tensor array
- return tensors_.at(ith_input);
- }
- private:
- TensorPtrs tensors_;
- };
- using Variadic = TensorArray;
/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
hence memory could still be recycled properly.
Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety.
*/
- struct OrtLiteCustomOp : public OrtCustomOp {
- using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
- using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
- // CreateTuple
- template <size_t ith_input, size_t ith_output, typename... Ts>
- static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
- CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
- return std::make_tuple();
- }
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
- auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
- auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- #ifdef ORT_CUDA_CTX
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- thread_local CudaContext cuda_context;
- cuda_context.Init(*context);
- std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
- auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- #endif
- #ifdef ORT_ROCM_CTX
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- thread_local RocmContext rocm_context;
- rocm_context.Init(*context);
- std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
- auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- #endif
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
- static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
- args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
- return std::tuple_cat(current, next);
- }
- #define CREATE_TUPLE_INPUT(data_type) \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if (ith_input < num_input) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } else { \
- std::tuple<T> current = std::tuple<T>{}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if ("CPUExecutionProvider" != ep) { \
- ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
- } \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if ("CPUExecutionProvider" != ep) { \
- ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
- } \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if (ith_input < num_input) { \
- if ("CPUExecutionProvider" != ep) { \
- ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
- } \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } else { \
- std::tuple<T> current = std::tuple<T>{}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if ("CPUExecutionProvider" != ep) { \
- ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
- } \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if (ith_input < num_input) { \
- if ("CPUExecutionProvider" != ep) { \
- ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
- } \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } else { \
- std::tuple<T> current = std::tuple<T>{}; \
- auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- }
- #define CREATE_TUPLE_OUTPUT(data_type) \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
- static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
- CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
- if (ith_output < num_output) { \
- args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
- std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } else { \
- std::tuple<T> current = std::tuple<T>{}; \
- auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
- return std::tuple_cat(current, next); \
- } \
- }
// CREATE_TUPLE stamps out the full CreateTuple overload set for one element
// type: the input-side forms (const tensor/span/scalar, generated by
// CREATE_TUPLE_INPUT) and the output-side forms (mutable tensor, generated by
// CREATE_TUPLE_OUTPUT). These overloads turn a compute-function signature into
// a std::tuple of arguments at kernel-compute time.
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)
CREATE_TUPLE(bool)
CREATE_TUPLE(float)
CREATE_TUPLE(Ort::Float16_t)
CREATE_TUPLE(Ort::BFloat16_t)
CREATE_TUPLE(double)
CREATE_TUPLE(int8_t)
CREATE_TUPLE(int16_t)
CREATE_TUPLE(int32_t)
CREATE_TUPLE(int64_t)
CREATE_TUPLE(uint8_t)
CREATE_TUPLE(uint16_t)
CREATE_TUPLE(uint32_t)
CREATE_TUPLE(uint64_t)
CREATE_TUPLE(std::string)
// std::string_view is a non-owning view, so it is only generated for the
// input side: an output tensor must own its data.
CREATE_TUPLE_INPUT(std::string_view)
CREATE_TUPLE(Ort::Float8E4M3FN_t)
CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
CREATE_TUPLE(Ort::Float8E5M2_t)
CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
// ParseArgs walks the compute-function parameter pack at compile time and
// records one ONNXTensorElementDataType per tensor input/output into
// input_types / output_types. Each parameter form is handled by the
// enable_if-selected overload below; non-tensor parameters record nothing.

// Base case: parameter pack exhausted — nothing left to record.
template <typename... Ts>
static typename std::enable_if<0 == sizeof...(Ts)>::type
ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
}

// OrtKernelContext* parameters carry context only; they contribute no
// input or output type, so just recurse on the rest of the pack.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}

// Same for OrtKernelContext& parameters.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
// Device-context parameters (opt-in via ORT_CUDA_CTX / ORT_ROCM_CTX) are,
// like OrtKernelContext, bookkeeping-only: they add no input/output type.
#ifdef ORT_CUDA_CTX
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
#endif

#ifdef ORT_ROCM_CTX
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
// A TensorArray parameter stands for a variadic run of inputs (const forms)
// or outputs (non-const forms). It is recorded with the sentinel element type
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, which GetInputCharacteristic /
// GetOutputCharacteristic later translate into INPUT_OUTPUT_VARIADIC.

// const TensorArray& — variadic input.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// const TensorArray* — variadic input.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// TensorArray& — variadic output.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// TensorArray* — variadic output.
template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}
// PARSE_INPUT_BASE generates the ParseArgs overloads for one concrete input
// parameter form: the form itself plus its (const) std::optional wrappers.
// An optional input records the same element type as a required one — the
// optionality is expressed separately via GetInputCharacteristic.
// (Comments are kept outside the macro body: a // comment on a
// backslash-continued line would splice into the next line.)
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

// PARSE_INPUT covers every way an input of data_type may be declared:
// tensor (pointer/reference), span (pointer/reference), or a plain scalar.
#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)
// PARSE_OUTPUT generates the ParseArgs overloads for the three output
// parameter forms of a data_type: mutable tensor pointer, mutable tensor
// reference, and optional tensor pointer. Each records onnx_type into
// output_types and recurses on the remaining pack.
#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

// PARSE_ARGS wires up both sides for one element type.
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)
// Instantiate the ParseArgs overload set for every supported element type.
// This list must stay in sync with the CREATE_TUPLE instantiations above.
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
- OrtLiteCustomOp(const char* op_name,
- const char* execution_provider,
- ShapeInferFn shape_infer_fn,
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
- execution_provider_(execution_provider),
- shape_infer_fn_(shape_infer_fn),
- start_ver_(start_ver),
- end_ver_(end_ver) {
- OrtCustomOp::version = ORT_API_VERSION;
- OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
- OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
- OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
- OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->input_types_.size();
- };
- OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->input_types_[indice];
- };
- OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->output_types_.size();
- };
- OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->output_types_[indice];
- };
- OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
- };
- OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
- };
- OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
- return 1;
- };
- OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
- return 0;
- };
- OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
- return 1;
- };
- OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
- return 0;
- };
- OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
- OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
- OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
- OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
- OrtCustomOp::CreateKernelV2 = {};
- OrtCustomOp::KernelComputeV2 = {};
- OrtCustomOp::KernelCompute = {};
- OrtCustomOp::InferOutputShapeFn = {};
- OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->start_ver_;
- };
- OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
- auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
- return self->end_ver_;
- };
- OrtCustomOp::GetMayInplace = {};
- OrtCustomOp::ReleaseMayInplace = {};
- OrtCustomOp::GetAliasMap = {};
- OrtCustomOp::ReleaseAliasMap = {};
- }
const std::string op_name_;             // operator name returned by GetName
const std::string execution_provider_;  // EP name returned by GetExecutionProviderType

// Element types recorded by ParseArgs, one entry per input/output.
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED marks a variadic slot.
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;

ShapeInferFn shape_infer_fn_ = {};  // optional shape-inference callback

int start_ver_ = 1;                    // first supported opset version
int end_ver_ = MAX_CUSTOM_OP_END_VER;  // last supported opset version

// Type-erased user compute callbacks; the derived class that stored them
// casts them back to their concrete function-pointer type.
void* compute_fn_ = {};
void* compute_fn_return_status_ = {};
};
- //////////////////////////// OrtLiteCustomFunc ////////////////////////////////
- // The struct is to implement function-as-op.
- // E.g. a function might be defined as:
- // void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
- // It could be registered this way:
- // Ort::CustomOpDomain v2_domain{"v2"};
- // std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
- // v2_domain.Add(fil_op_ptr.get());
- // session_options.Add(v2_domain);
- // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
- template <typename... Args>
- struct OrtLiteCustomFunc : public OrtLiteCustomOp {
- using ComputeFn = void (*)(Args...);
- using ComputeFnReturnStatus = Status (*)(Args...);
- using MyType = OrtLiteCustomFunc<Args...>;
- struct Kernel {
- size_t num_input_{};
- size_t num_output_{};
- ComputeFn compute_fn_{};
- ComputeFnReturnStatus compute_fn_return_status_{};
- std::string ep_{};
- };
- OrtLiteCustomFunc(const char* op_name,
- const char* execution_provider,
- ComputeFn compute_fn,
- ShapeInferFn shape_infer_fn = {},
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
- compute_fn_ = reinterpret_cast<void*>(compute_fn);
- ParseArgs<Args...>(input_types_, output_types_);
- OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
- auto kernel = reinterpret_cast<Kernel*>(op_kernel);
- std::vector<ArgPtr> args;
- auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
- std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
- };
- OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
- auto kernel = std::make_unique<Kernel>();
- auto me = static_cast<const MyType*>(this_);
- kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
- Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
- Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
- auto self = static_cast<const OrtLiteCustomFunc*>(this_);
- kernel->ep_ = self->execution_provider_;
- return reinterpret_cast<void*>(kernel.release());
- };
- OrtCustomOp::KernelDestroy = [](void* op_kernel) {
- delete reinterpret_cast<Kernel*>(op_kernel);
- };
- if (shape_infer_fn_) {
- OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
- auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
- ShapeInferContext ctx(&GetApi(), ort_ctx);
- return shape_info_fn(ctx);
- };
- }
- }
- OrtLiteCustomFunc(const char* op_name,
- const char* execution_provider,
- ComputeFnReturnStatus compute_fn_return_status,
- ShapeInferFn shape_infer_fn = {},
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
- compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
- ParseArgs<Args...>(input_types_, output_types_);
- OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
- auto kernel = reinterpret_cast<Kernel*>(op_kernel);
- std::vector<ArgPtr> args;
- auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
- return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
- };
- OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
- auto kernel = std::make_unique<Kernel>();
- auto me = static_cast<const MyType*>(this_);
- kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
- Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
- Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
- auto self = static_cast<const OrtLiteCustomFunc*>(this_);
- kernel->ep_ = self->execution_provider_;
- return reinterpret_cast<void*>(kernel.release());
- };
- OrtCustomOp::KernelDestroy = [](void* op_kernel) {
- delete reinterpret_cast<Kernel*>(op_kernel);
- };
- if (shape_infer_fn_) {
- OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
- auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
- ShapeInferContext ctx(&GetApi(), ort_ctx);
- return shape_info_fn(ctx);
- };
- }
- }
- }; // struct OrtLiteCustomFunc
- /////////////////////////// OrtLiteCustomStruct ///////////////////////////
- // The struct is to implement struct-as-op.
- // E.g. a struct might be defined as:
- // struct Merge {
- // Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
- // void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
- // std::string_view string_in,
- // Ort::Custom::Tensor<std::string>* strings_out) {...}
- // bool reverse_ = false;
- // };
- // It could be registered this way:
- // Ort::CustomOpDomain v2_domain{"v2"};
- // std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
- // v2_domain.Add(mrg_op_ptr.get());
- // session_options.Add(v2_domain);
- // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  // Member-function-pointer forms of the user's Compute method; used only to
  // deduce Args... via overload resolution in SetCompute.
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  // Per-kernel state: owns one CustomOp instance per kernel.
  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };

  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    // Deduces Args... from &CustomOp::Compute and installs the matching
    // KernelCompute / KernelComputeV2 callback below.
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      // CustomOp is expected to be constructible from (const OrtApi*, const OrtKernelInfo*).
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    // Installs CustomOp::InferOutputShape if such a member exists (see the
    // overload pair below); otherwise leaves InferOutputShapeFn empty.
    SetShapeInfer<CustomOp>(0);
  }

  // Selected when Compute returns void: errors must be thrown.
  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;  // keeps the tensor wrappers alive for the duration of the call
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }

  // Selected when Compute returns Status, surfaced via KernelComputeV2.
  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;  // keeps the tensor wrappers alive for the duration of the call
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
    };
  }

  // Overload trick: viable only when C declares InferOutputShape (the literal
  // 0 converts to decltype(&C::InferOutputShape)); preferred over the
  // variadic fallback below when both are candidates.
  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  // Fallback when C has no InferOutputShape member: no shape inference.
  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};  // struct OrtLiteCustomStruct
- /////////////////////////// CreateLiteCustomOp ////////////////////////////
- template <typename... Args>
- OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
- const char* execution_provider,
- void (*custom_compute_fn)(Args...),
- Status (*shape_infer_fn)(ShapeInferContext&) = {},
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) {
- using LiteOp = OrtLiteCustomFunc<Args...>;
- return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
- }
- template <typename... Args>
- OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
- const char* execution_provider,
- Status (*custom_compute_fn_v2)(Args...),
- Status (*shape_infer_fn)(ShapeInferContext&) = {},
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) {
- using LiteOp = OrtLiteCustomFunc<Args...>;
- return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
- }
- template <typename CustomOp>
- OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
- const char* execution_provider,
- int start_ver = 1,
- int end_ver = MAX_CUSTOM_OP_END_VER) {
- using LiteOp = OrtLiteCustomStruct<CustomOp>;
- return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
- }
- } // namespace Custom
- } // namespace Ort
|