123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- // Copyright (c) Microsoft Corporation. All rights reserved.
- // Licensed under the MIT License.
- #pragma once
- #include "onnxruntime_training_c_api.h"
- #include "onnxruntime_cxx_api.h"
- namespace Ort {
- inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
- CheckpointState& checkpoint_state,
- const std::basic_string<ORTCHAR_T>& train_model_path,
- const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
- const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
- ThrowOnError(GetTrainingApi().CreateTrainingSession(
- env, session_options, checkpoint_state,
- train_model_path.c_str(),
- eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
- optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
- &p_));
- ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
- ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
- }
- inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
- CheckpointState& checkpoint_state,
- const std::vector<uint8_t>& train_model_data,
- const std::vector<uint8_t>& eval_model_data,
- const std::vector<uint8_t>& optim_model_data) {
- ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
- env, session_options, checkpoint_state,
- train_model_data.data(), train_model_data.size(),
- eval_model_data.data(), eval_model_data.size(),
- optim_model_data.data(), optim_model_data.size(),
- &p_));
- ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
- ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
- }
- inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
- std::vector<Value> output_values;
- output_values.reserve(training_model_output_count_);
- for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
- RunOptions run_options;
- ThrowOnError(GetTrainingApi().TrainStep(
- p_, run_options, input_values.size(), ort_input_values,
- training_model_output_count_, ort_output_values));
- return output_values;
- }
- inline void TrainingSession::LazyResetGrad() {
- ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
- }
- inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
- std::vector<Value> output_values;
- output_values.reserve(eval_model_output_count_);
- for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
- RunOptions run_options;
- ThrowOnError(GetTrainingApi().EvalStep(
- p_, run_options, input_values.size(), ort_input_values,
- eval_model_output_count_, ort_output_values));
- return output_values;
- }
- inline void TrainingSession::SetLearningRate(float learning_rate) {
- ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
- }
- inline float TrainingSession::GetLearningRate() const {
- float learning_rate = 0;
- ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
- return learning_rate;
- }
- inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
- float initial_lr) {
- ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
- initial_lr));
- }
- inline void TrainingSession::SchedulerStep() {
- ThrowOnError(GetTrainingApi().SchedulerStep(p_));
- }
- inline void TrainingSession::OptimizerStep() {
- RunOptions run_options;
- ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
- }
- inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
- auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
- : GetTrainingApi().TrainingSessionGetEvalModelInputCount;
- auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
- : GetTrainingApi().TrainingSessionGetEvalModelInputName;
- size_t input_count = 0;
- ThrowOnError(input_count_function(p_, &input_count));
- std::vector<std::string> input_names(input_count);
- AllocatorWithDefaultOptions allocator;
- for (size_t index = 0; index < input_count; ++index) {
- char* input_name;
- ThrowOnError(input_name_function(p_, index, allocator, &input_name));
- input_names[index] = std::string(input_name);
- allocator.Free(input_name);
- }
- return input_names;
- }
- inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
- auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
- : GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
- auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
- : GetTrainingApi().TrainingSessionGetEvalModelOutputName;
- size_t output_count = 0;
- ThrowOnError(output_count_function(p_, &output_count));
- std::vector<std::string> output_names(output_count);
- AllocatorWithDefaultOptions allocator;
- for (size_t index = 0; index < output_count; ++index) {
- char* output_name;
- ThrowOnError(output_name_function(p_, index, allocator, &output_name));
- output_names[index] = std::string(output_name);
- allocator.Free(output_name);
- }
- return output_names;
- }
- inline Value TrainingSession::ToBuffer(const bool only_trainable) {
- size_t buffer_size = 0U;
- ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));
- std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};
- AllocatorWithDefaultOptions allocator;
- Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
- ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
- ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));
- return buffer;
- }
- inline void TrainingSession::FromBuffer(Value& buffer) {
- if (!buffer.IsTensor()) {
- ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
- }
- auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
- auto buffer_shape = tensor_info.GetShape();
- if (buffer_shape.size() != 1U) {
- ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
- OrtErrorCode::ORT_INVALID_ARGUMENT));
- }
- auto buffer_size = buffer_shape.front();
- size_t session_buffer_size = 0U;
- ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
- if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
- ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
- return;
- }
- size_t session_buffer_size_trainable_only = 0U;
- ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
- if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
- ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
- return;
- } else {
- ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
- }
- }
- inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
- OrtCheckpointState* checkpoint_state;
- ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
- return CheckpointState(checkpoint_state);
- }
- inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
- OrtCheckpointState* checkpoint_state;
- ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
- return CheckpointState(checkpoint_state);
- }
- inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
- const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
- const bool include_optimizer_state) {
- ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
- include_optimizer_state));
- }
- inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
- const std::vector<std::string>& graph_output_names) {
- std::vector<const char*> output_names;
- output_names.reserve(graph_output_names.size());
- for (const auto& output_name : graph_output_names) {
- output_names.push_back(output_name.c_str());
- }
- ThrowOnError(GetTrainingApi().ExportModelForInferencing(
- p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
- }
- inline void SetSeed(const int64_t seed) {
- ThrowOnError(GetTrainingApi().SetSeed(seed));
- }
- inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
- if (std::holds_alternative<int64_t>(property_value)) {
- int64_t value = std::get<int64_t>(property_value);
- void* value_p = &value;
- ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
- } else if (std::holds_alternative<float>(property_value)) {
- float value = std::get<float>(property_value);
- void* value_p = &value;
- ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
- } else if (std::holds_alternative<std::string>(property_value)) {
- std::string value = std::get<std::string>(property_value);
- auto buffer = std::make_unique<char[]>(value.length() + 1);
- memcpy(buffer.get(), value.c_str(), value.length());
- // AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be
- // copied at that point so buffer can free the local allocation once the call is made.
- ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty,
- buffer.get()));
- } else {
- ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
- }
- }
- inline Property CheckpointState::GetProperty(const std::string& property_name) {
- void* property_value = nullptr;
- OrtPropertyType property_type;
- AllocatorWithDefaultOptions allocator;
- ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));
- Property property;
- switch (property_type) {
- case OrtPropertyType::OrtIntProperty: {
- auto value_p = reinterpret_cast<int64_t*>(property_value);
- property = *value_p;
- allocator.Free(property_value);
- break;
- }
- case OrtPropertyType::OrtFloatProperty: {
- auto value_p = reinterpret_cast<float*>(property_value);
- property = *value_p;
- allocator.Free(property_value);
- break;
- }
- case OrtPropertyType::OrtStringProperty: {
- auto value_p = reinterpret_cast<char*>(property_value);
- property = std::string(value_p);
- allocator.Free(property_value);
- break;
- }
- default: {
- ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
- break;
- }
- }
- return property;
- }
- inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
- ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
- }
- inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
- AllocatorWithDefaultOptions allocator;
- OrtValue* parameter;
- ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter));
- return Value{parameter};
- }
- } // namespace Ort
|