onnxruntime_training_cxx_inline.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. #pragma once
#include <memory>
#include <variant>

#include "onnxruntime_training_c_api.h"
#include "onnxruntime_cxx_api.h"
  6. namespace Ort {
  7. inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
  8. CheckpointState& checkpoint_state,
  9. const std::basic_string<ORTCHAR_T>& train_model_path,
  10. const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
  11. const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
  12. ThrowOnError(GetTrainingApi().CreateTrainingSession(
  13. env, session_options, checkpoint_state,
  14. train_model_path.c_str(),
  15. eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
  16. optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
  17. &p_));
  18. ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
  19. ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
  20. }
  21. inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
  22. CheckpointState& checkpoint_state,
  23. const std::vector<uint8_t>& train_model_data,
  24. const std::vector<uint8_t>& eval_model_data,
  25. const std::vector<uint8_t>& optim_model_data) {
  26. ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
  27. env, session_options, checkpoint_state,
  28. train_model_data.data(), train_model_data.size(),
  29. eval_model_data.data(), eval_model_data.size(),
  30. optim_model_data.data(), optim_model_data.size(),
  31. &p_));
  32. ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
  33. ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
  34. }
  35. inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
  36. std::vector<Value> output_values;
  37. output_values.reserve(training_model_output_count_);
  38. for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
  39. auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
  40. auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
  41. RunOptions run_options;
  42. ThrowOnError(GetTrainingApi().TrainStep(
  43. p_, run_options, input_values.size(), ort_input_values,
  44. training_model_output_count_, ort_output_values));
  45. return output_values;
  46. }
  47. inline void TrainingSession::LazyResetGrad() {
  48. ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
  49. }
  50. inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
  51. std::vector<Value> output_values;
  52. output_values.reserve(eval_model_output_count_);
  53. for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
  54. auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
  55. auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
  56. RunOptions run_options;
  57. ThrowOnError(GetTrainingApi().EvalStep(
  58. p_, run_options, input_values.size(), ort_input_values,
  59. eval_model_output_count_, ort_output_values));
  60. return output_values;
  61. }
  62. inline void TrainingSession::SetLearningRate(float learning_rate) {
  63. ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
  64. }
  65. inline float TrainingSession::GetLearningRate() const {
  66. float learning_rate = 0;
  67. ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
  68. return learning_rate;
  69. }
  70. inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
  71. float initial_lr) {
  72. ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
  73. initial_lr));
  74. }
  75. inline void TrainingSession::SchedulerStep() {
  76. ThrowOnError(GetTrainingApi().SchedulerStep(p_));
  77. }
  78. inline void TrainingSession::OptimizerStep() {
  79. RunOptions run_options;
  80. ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
  81. }
  82. inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
  83. auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
  84. : GetTrainingApi().TrainingSessionGetEvalModelInputCount;
  85. auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
  86. : GetTrainingApi().TrainingSessionGetEvalModelInputName;
  87. size_t input_count = 0;
  88. ThrowOnError(input_count_function(p_, &input_count));
  89. std::vector<std::string> input_names(input_count);
  90. AllocatorWithDefaultOptions allocator;
  91. for (size_t index = 0; index < input_count; ++index) {
  92. char* input_name;
  93. ThrowOnError(input_name_function(p_, index, allocator, &input_name));
  94. input_names[index] = std::string(input_name);
  95. allocator.Free(input_name);
  96. }
  97. return input_names;
  98. }
  99. inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
  100. auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
  101. : GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
  102. auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
  103. : GetTrainingApi().TrainingSessionGetEvalModelOutputName;
  104. size_t output_count = 0;
  105. ThrowOnError(output_count_function(p_, &output_count));
  106. std::vector<std::string> output_names(output_count);
  107. AllocatorWithDefaultOptions allocator;
  108. for (size_t index = 0; index < output_count; ++index) {
  109. char* output_name;
  110. ThrowOnError(output_name_function(p_, index, allocator, &output_name));
  111. output_names[index] = std::string(output_name);
  112. allocator.Free(output_name);
  113. }
  114. return output_names;
  115. }
  116. inline Value TrainingSession::ToBuffer(const bool only_trainable) {
  117. size_t buffer_size = 0U;
  118. ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));
  119. std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};
  120. AllocatorWithDefaultOptions allocator;
  121. Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
  122. ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
  123. ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));
  124. return buffer;
  125. }
  126. inline void TrainingSession::FromBuffer(Value& buffer) {
  127. if (!buffer.IsTensor()) {
  128. ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  129. }
  130. auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
  131. auto buffer_shape = tensor_info.GetShape();
  132. if (buffer_shape.size() != 1U) {
  133. ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
  134. OrtErrorCode::ORT_INVALID_ARGUMENT));
  135. }
  136. auto buffer_size = buffer_shape.front();
  137. size_t session_buffer_size = 0U;
  138. ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
  139. if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
  140. ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
  141. return;
  142. }
  143. size_t session_buffer_size_trainable_only = 0U;
  144. ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
  145. if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
  146. ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
  147. return;
  148. } else {
  149. ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  150. }
  151. }
  152. inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
  153. OrtCheckpointState* checkpoint_state;
  154. ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
  155. return CheckpointState(checkpoint_state);
  156. }
  157. inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
  158. OrtCheckpointState* checkpoint_state;
  159. ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
  160. return CheckpointState(checkpoint_state);
  161. }
  162. inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
  163. const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
  164. const bool include_optimizer_state) {
  165. ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
  166. include_optimizer_state));
  167. }
  168. inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
  169. const std::vector<std::string>& graph_output_names) {
  170. std::vector<const char*> output_names;
  171. output_names.reserve(graph_output_names.size());
  172. for (const auto& output_name : graph_output_names) {
  173. output_names.push_back(output_name.c_str());
  174. }
  175. ThrowOnError(GetTrainingApi().ExportModelForInferencing(
  176. p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
  177. }
  178. inline void SetSeed(const int64_t seed) {
  179. ThrowOnError(GetTrainingApi().SetSeed(seed));
  180. }
  181. inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
  182. if (std::holds_alternative<int64_t>(property_value)) {
  183. int64_t value = std::get<int64_t>(property_value);
  184. void* value_p = &value;
  185. ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
  186. } else if (std::holds_alternative<float>(property_value)) {
  187. float value = std::get<float>(property_value);
  188. void* value_p = &value;
  189. ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
  190. } else if (std::holds_alternative<std::string>(property_value)) {
  191. std::string value = std::get<std::string>(property_value);
  192. auto buffer = std::make_unique<char[]>(value.length() + 1);
  193. memcpy(buffer.get(), value.c_str(), value.length());
  194. // AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be
  195. // copied at that point so buffer can free the local allocation once the call is made.
  196. ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty,
  197. buffer.get()));
  198. } else {
  199. ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  200. }
  201. }
  202. inline Property CheckpointState::GetProperty(const std::string& property_name) {
  203. void* property_value = nullptr;
  204. OrtPropertyType property_type;
  205. AllocatorWithDefaultOptions allocator;
  206. ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));
  207. Property property;
  208. switch (property_type) {
  209. case OrtPropertyType::OrtIntProperty: {
  210. auto value_p = reinterpret_cast<int64_t*>(property_value);
  211. property = *value_p;
  212. allocator.Free(property_value);
  213. break;
  214. }
  215. case OrtPropertyType::OrtFloatProperty: {
  216. auto value_p = reinterpret_cast<float*>(property_value);
  217. property = *value_p;
  218. allocator.Free(property_value);
  219. break;
  220. }
  221. case OrtPropertyType::OrtStringProperty: {
  222. auto value_p = reinterpret_cast<char*>(property_value);
  223. property = std::string(value_p);
  224. allocator.Free(property_value);
  225. break;
  226. }
  227. default: {
  228. ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
  229. break;
  230. }
  231. }
  232. return property;
  233. }
  234. inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
  235. ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
  236. }
  237. inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
  238. AllocatorWithDefaultOptions allocator;
  239. OrtValue* parameter;
  240. ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));
  241. return Value{parameter};
  242. }
  243. } // namespace Ort