onnxruntime_training_c_api.h 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. // This file contains the training c apis.
  4. #pragma once
  5. #include <stdbool.h>
  6. #include "onnxruntime_c_api.h"
  7. /** \page training_c_cpp_api Training C & C++ APIs
  8. *
  9. * Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
  10. *
  11. * In order to train a model with onnxruntime, the following training artifacts must be generated:
  12. * - The training onnx model
  13. * - The checkpoint file
  14. * - The optimizer onnx model
  15. * - The eval onnx model (optional)
  16. *
  17. * These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
  18. *
  19. * After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
  20. *
  21. * If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
  22. *
  23. * <h1>Training C API</h1>
  24. *
  25. * ::OrtTrainingApi - Training C API functions.
  26. *
  27. * This C structure contains functions that enable users to perform training with onnxruntime.
  28. *
  29. * _Sample Code_:
  30. *
  31. * ```c
  32. * #include <onnxruntime_training_c_api.h>
  33. *
  34. * const OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  35. * const OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
  36. *
  37. * OrtEnv* env = NULL;
  38. * g_ort_api->CreateEnv(logging_level, logid, &env);
  39. * OrtSessionOptions* session_options = NULL;
  40. * g_ort_api->CreateSessionOptions(&session_options);
  41. *
  42. * OrtCheckpointState* state = NULL;
  43. * g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
  44. *
  45. * OrtTrainingSession* training_session = NULL;
  46. * g_ort_training_api->CreateTrainingSession(env, session_options, state,
  47. * training_model_path, eval_model_path, optimizer_model_path,
  48. * &training_session);
  49. * // Training loop
  50. * {
  51. * g_ort_training_api->TrainStep(...);
  52. * g_ort_training_api->OptimizerStep(...);
  53. * g_ort_training_api->LazyResetGrad(...);
  54. * }
  55. *
  56. * g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
  57. * g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
  58. *
  59. * g_ort_training_api->ReleaseTrainingSession(training_session);
  60. * g_ort_training_api->ReleaseCheckpointState(state);
  61. * ```
  62. *
  63. * > **Note**
  64. * > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
  65. *
  66. * <h1>Training C++ API</h1>
  67. *
  68. * @ref TrainingCpp - Training C++ API classes and functions.
  69. *
  70. * These C++ classes and functions enable users to perform training with onnxruntime.
  71. *
  72. * _Sample Code_:
  73. *
  74. * ```cc
  75. * #include <onnxruntime_training_cxx_api.h>
  76. *
  77. * Ort::Env env;
  78. * Ort::SessionOptions session_options;
  79. *
  80. * auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
  81. * auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
  82. * eval_model_path, optimizer_model_path);
  83. *
  84. * // Training Loop
  85. * {
  86. * training_session.TrainStep(...);
  87. * training_session.OptimizerStep(...);
  88. * training_session.LazyResetGrad(...);
  89. * }
  90. *
  91. * training_session.ExportModelForInferencing(inference_model_path, ...);
  92. * Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
  93. * ```
  94. * > **Note**
  95. * > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
  96. */
  97. /** @defgroup TrainingC Ort Training C API
  98. * @{
  99. */
  /** \brief Handle to a training session.
   *
   * Created by OrtTrainingApi::CreateTrainingSession or OrtTrainingApi::CreateTrainingSessionFromBuffer
   * and released with OrtTrainingApi::ReleaseTrainingSession.
   */
  ORT_RUNTIME_CLASS(TrainingSession);

  /** \brief Handle to the training state used by an ::OrtTrainingSession.
   *
   * Holds the model parameters, their gradients, the optimizer states and the properties.
   * Obtained from OrtTrainingApi::LoadCheckpoint and released with OrtTrainingApi::ReleaseCheckpointState;
   * it must outlive every training session created from it.
   */
  ORT_RUNTIME_CLASS(CheckpointState);
  /**
   * \brief Kind of value held by a property stored in an ::OrtCheckpointState.
   *
   * Used both when adding a property to the checkpoint state and when reading one back,
   * to say whether the property carries an integer, a floating-point or a string value.
   * The numeric values are assigned explicitly; keep them unchanged, since this enum is
   * part of the public C header.
   */
  typedef enum OrtPropertyType {
    OrtIntProperty = 0,    /**< Integer-valued property. */
    OrtFloatProperty = 1,  /**< Floating-point-valued property. */
    OrtStringProperty = 2  /**< String-valued property. */
  } OrtPropertyType;
  109. /** \brief The Training C API that holds onnxruntime training function pointers
  110. *
  111. * All the Training C API functions are defined inside this structure as pointers to functions.
  112. * Call OrtApi::GetTrainingApi to get a pointer to this struct.
  113. *
  114. * \nosubgrouping
  115. */
  116. struct OrtTrainingApi {
  117. /// \name Accessing The Training Session State
  118. /// @{
  119. /** \brief Load a checkpoint state from a file on disk into checkpoint_state.
  120. *
  121. * This function will parse a checkpoint file, pull relevant data and load the training
  122. * state into the checkpoint_state. This checkpoint state can then be used to create the
  123. * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
  124. * session will resume training from the given checkpoint state.
  125. * \note Note that the training session created with a checkpoint state uses this state to store the entire
  126. * training state (including model parameters, its gradients, the optimizer states and the properties).
  127. * As a result, it is required that the checkpoint state outlive the lifetime of the training session.
  128. * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
  129. *
  130. * \param[in] checkpoint_path Path to the checkpoint file
  131. * \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
  132. *
  133. * \snippet{doc} snippets.dox OrtStatus Return Value
  134. *
  135. */
  136. ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
  137. _Outptr_ OrtCheckpointState** checkpoint_state);
  138. /** \brief Save the given state to a checkpoint file on disk.
  139. *
  140. * This function serializes the provided checkpoint state to a file on disk.
  141. * This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
  142. * training from this snapshot of the state.
  143. *
  144. * \param[in] checkpoint_state The checkpoint state to save.
  145. * \param[in] checkpoint_path Path to the checkpoint file.
  146. * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
  147. *
  148. * \snippet{doc} snippets.dox OrtStatus Return Value
  149. *
  150. */
  151. ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
  152. const bool include_optimizer_state);
  153. /// @}
  154. /// \name Implementing The Training Loop
  155. /// @{
  156. /** \brief Create a training session that can be used to begin or resume training.
  157. *
  158. * This function creates a training session based on the env and session options provided that can
  159. * begin or resume training from a given checkpoint state for the given onnx models.
  160. * The checkpoint state represents the parameters of the training session which will be moved
  161. * to the device specified by the user through the session options (if necessary).
  162. * The training session requires four training artifacts
  163. * - The training onnx model
  164. * - The evaluation onnx model (optional)
  165. * - The optimizer onnx model
  166. * - The checkpoint file
  167. *
  168. * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
  169. *
  170. * \param[in] env Environment to be used for the training session.
  171. * \param[in] options Session options that the user can customize for this training session.
  172. * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
  173. * \param[in] train_model_path Model to be used to perform training.
  174. * \param[in] eval_model_path Model to be used to perform evaluation.
  175. * \param[in] optimizer_model_path Model to be used to perform gradient descent.
  176. * \param[out] out Created training session.
  177. *
  178. * \snippet{doc} snippets.dox OrtStatus Return Value
  179. *
  180. */
  181. ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
  182. _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
  183. _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
  184. _Outptr_result_maybenull_ OrtTrainingSession** out);
  185. /** \brief Create a training session that can be used to begin or resume training.
  186. * This api provides a way to load all the training artifacts from buffers instead of files.
  187. *
  188. * \param[in] env Environment to be used for the training session.
  189. * \param[in] options Session options that the user can customize for this training session.
  190. * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
  191. * \param[in] train_model_data Buffer containing the model data to be used to perform training
  192. * \param[in] train_data_length Length of the buffer containing train_model_data
  193. * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
  194. * \param[in] eval_data_length Length of the buffer containing eval_model_data
  195. * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
  196. * \param[in] optim_data_length Length of the buffer containing optim_model_data
  197. * \param[out] out Created training session.
  198. *
  199. */
  200. ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
  201. _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
  202. _In_ const void* train_model_data, size_t train_data_length,
  203. _In_ const void* eval_model_data, size_t eval_data_length,
  204. _In_ const void* optim_model_data, size_t optim_data_length,
  205. _Outptr_result_maybenull_ OrtTrainingSession** out);
  206. /// @}
  207. /// \name Model IO Information
  208. /// @{
  209. /** \brief Retrieves the number of user outputs in the training model.
  210. *
  211. * This function returns the number of outputs of the training model so that the user can
  212. * allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
  213. *
  214. * \param[in] sess The `this` pointer to the training session.
  215. * \param[out] out Number of user outputs in the training model.
  216. *
  217. * \snippet{doc} snippets.dox OrtStatus Return Value
  218. *
  219. */
  220. ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
  221. /** \brief Retrieves the number of user outputs in the eval model.
  222. *
  223. * This function returns the number of outputs of the eval model so that the user can
  224. * allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
  225. *
  226. * \param[in] sess The `this` pointer to the training session.
  227. * \param[out] out Number of user outputs in the eval model.
  228. *
  229. * \snippet{doc} snippets.dox OrtStatus Return Value
  230. *
  231. */
  232. ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
  233. /** \brief Retrieves the names of user outputs in the training model.
  234. *
  235. * This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
  236. * returned by the OrtTrainingApi::TrainStep function.
  237. *
  238. * \param[in] sess The `this` pointer to the training session.
  239. * \param[in] index Index of the output name requested.
  240. * \param[in] allocator Allocator to use to allocate the memory for the name.
  241. * \param[out] output Name of the training model output at the given index.
  242. *
  243. * \snippet{doc} snippets.dox OrtStatus Return Value
  244. *
  245. */
  246. ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
  247. /** \brief Retrieves the names of user outputs in the eval model.
  248. *
  249. * This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
  250. * by the OrtTrainingApi::EvalStep function.
  251. *
  252. * \param[in] sess The `this` pointer to the training session.
  253. * \param[in] index Index of the output name requested.
  254. * \param[in] allocator Allocator to use to allocate the memory for the name.
  255. * \param[out] output Name of the eval model output at the given index.
  256. *
  257. * \snippet{doc} snippets.dox OrtStatus Return Value
  258. *
  259. */
  260. ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
  261. /// @}
  262. /// \name Implementing The Training Loop
  263. /// @{
  264. /** \brief Reset the gradients of all trainable parameters to zero lazily.
  265. *
  266. * This function sets the internal state of the training session such that the gradients of the trainable
  267. * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
  268. * computed on the next invocation of OrtTrainingApi::TrainStep.
  269. *
  270. * \param[in] session The `this` pointer to the training session.
  271. *
  272. * \snippet{doc} snippets.dox OrtStatus Return Value
  273. *
  274. */
  275. ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
  276. /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
  277. *
  278. * This function performs a training step that computes the outputs of the training model and the gradients
  279. * of the trainable parameters for the given inputs. The train step is performed based on the training model
  280. * that was provided to the training session.
  281. * The OrtTrainingApi::TrainStep is equivalent to running forward propagation and backward propagation in a single
  282. * step.
  283. * The gradients computed are stored inside the training session state so they can be later consumed
  284. * by the OrtTrainingApi::OptimizerStep function.
  285. * The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
  286. *
  287. * \param[in] sess The `this` pointer to the training session.
  288. * \param[in] run_options Run options for this training step.
  289. * \param[in] inputs_len Number of user inputs to the training model.
  290. * \param[in] inputs The user inputs to the training model.
  291. * \param[in] outputs_len Number of user outputs expected from this training step.
  292. * \param[out] outputs User outputs computed by train step.
  293. *
  294. * \snippet{doc} snippets.dox OrtStatus Return Value
  295. *
  296. */
  297. ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
  298. _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
  299. _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
  300. /** \brief Computes the outputs for the eval model for the given inputs
  301. *
  302. * This function performs an eval step that computes the outputs of the eval model for the given inputs.
  303. * The eval step is performed based on the eval model that was provided to the training session.
  304. *
  305. * \param[in] sess The `this` pointer to the training session.
  306. * \param[in] run_options Run options for this eval step.
  307. * \param[in] inputs_len Number of user inputs to the eval model.
  308. * \param[in] inputs The user inputs to the eval model.
  309. * \param[in] outputs_len Number of user outputs expected from this eval step.
  310. * \param[out] outputs User outputs computed by eval step.
  311. *
  312. * \snippet{doc} snippets.dox OrtStatus Return Value
  313. *
  314. */
  315. ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
  316. _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
  317. _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
  318. /** \brief Sets the learning rate for this training session.
  319. *
  320. * This function allows users to set the learning rate for the training session. The current
  321. * learning rate is maintained by the training session and can be overwritten by invoking
  322. * this function with the desired learning rate. This function should not be used when a valid
  323. * learning rate scheduler is registered. It should be used either to set the learning rate
  324. * derived from a custom learning rate scheduler or to set a constant learning rate to be used
  325. * throughout the training session.
  326. * \note Please note that this function does not set the initial learning rate that may be needed
  327. * by the predefined learning rate schedulers. To set the initial learning rate for learning
  328. * rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
  329. *
  330. * \param[in] sess The `this` pointer to the training session.
  331. * \param[in] learning_rate Desired learning rate to be set.
  332. *
  333. * \snippet{doc} snippets.dox OrtStatus Return Value
  334. *
  335. */
  336. ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
  337. /** \brief Gets the current learning rate for this training session.
  338. *
  339. * This function allows users to get the learning rate for the training session. The current
  340. * learning rate is maintained by the training session, and users can query it for the purpose
  341. * of implementing their own learning rate schedulers.
  342. *
  343. * \param[in] sess The `this` pointer to the training session.
  344. * \param[out] learning_rate Learning rate currently in use by the training session.
  345. *
  346. * \snippet{doc} snippets.dox OrtStatus Return Value
  347. *
  348. */
  349. ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
  350. /** \brief Performs the weight updates for the trainable parameters using the optimizer model.
  351. *
  352. * This function performs the weight update step that updates the trainable parameters such that they
  353. * take a step in the direction of their gradients (gradient descent). The optimizer step is performed
  354. * based on the optimizer model that was provided to the training session.
  355. * The updated parameters are stored inside the training state so that they can be used by the next
  356. * OrtTrainingApi::TrainStep function call.
  357. *
  358. * \param[in] sess The `this` pointer to the training session.
  359. * \param[in] run_options Run options for this optimizer step.
  360. *
  361. * \snippet{doc} snippets.dox OrtStatus Return Value
  362. *
  363. */
  364. ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
  365. _In_opt_ const OrtRunOptions* run_options);
  366. /** \brief Registers a linear learning rate scheduler for the training session.
  367. *
  368. * Register a linear learning rate scheduler that decays the learning rate by a linearly updated
  369. * multiplicative factor from the initial learning rate set on the training session to 0. The decay
  370. * is performed after the initial warm up phase where the learning rate is linearly incremented
  371. * from 0 to the initial learning rate provided.
  372. *
  373. * \param[in] sess The `this` pointer to the training session.
  374. * \param[in] warmup_step_count Warmup steps for LR warmup.
  375. * \param[in] total_step_count Total step count.
  376. * \param[in] initial_lr The initial learning rate to be used by the training session.
  377. *
  378. * \snippet{doc} snippets.dox OrtStatus Return Value
  379. *
  380. */
  381. ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
  382. _In_ const int64_t total_step_count, _In_ const float initial_lr);
  383. /** \brief Update the learning rate based on the registered learning rate scheduler.
  384. *
  385. * Takes a scheduler step that updates the learning rate that is being used by the training session.
  386. * This function should typically be called before invoking the optimizer step for each round,
  387. * or as determined necessary to update the learning rate being used by the training session.
  388. * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
  389. * function.
  390. *
  391. * \param[in] sess The `this` pointer to the training session.
  392. *
  393. * \snippet{doc} snippets.dox OrtStatus Return Value
  394. *
  395. */
  396. ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
  397. /// @}
  398. /// \name Accessing The Training Session State
  399. /// @{
  400. /** \brief Retrieves the size of all the parameters.
  401. *
  402. * Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
  403. * training state.
  404. * When trainable_only argument is true, the size is calculated for trainable params only.
  405. *
  406. * \param[in] sess The `this` pointer to the training session.
  407. * \param[out] out Size of all parameter elements.
  408. * \param[in] trainable_only Whether to skip non-trainable parameters
  409. *
  410. * \snippet{doc} snippets.dox OrtStatus Return Value
  411. *
  412. */
  413. ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
  414. /** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
  415. *
  416. * The parameters_buffer has to be of the size given by GetParametersSize api call,
  417. * with matching setting for the argument trainable_only. All the target parameters must be of the same
  418. * datatype. The OrtValue must be pre-allocated onto
  419. * the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
  420. * Parameter ordering is preserved.
  421. * User is responsible for allocating and freeing the resources used by the parameters_buffer.
  422. *
  423. * \param[in] sess The `this` pointer to the training session.
  424. * \param[in] trainable_only Whether to skip non-trainable parameters
  425. * \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
  426. *
  427. * \snippet{doc} snippets.dox OrtStatus Return Value
  428. *
  429. */
  430. ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
  431. _Inout_ OrtValue* parameters_buffer, bool trainable_only);
  432. /** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
  433. *
  434. * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
  435. * with matching setting for trainable_only argument. All the target parameters must be of the same
  436. * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
  437. * and can be used to load updated buffer values onto the training state.
  438. * Parameter ordering is preserved.
  439. * User is responsible for allocating and freeing the resources used by the parameters_buffer.
  440. * In case the training session was created with a nominal checkpoint, invoking this function is required
  441. * to load the updated parameters onto the checkpoint to complete it.
  442. *
  443. * \param[in] sess The `this` pointer to the training session.
  444. * \param[in] trainable_only Whether to skip non-trainable parameters
  445. * \param[in] parameters_buffer The pre-allocated OrtValue buffer to copy from.
  446. *
  447. * \snippet{doc} snippets.dox OrtStatus Return Value
  448. *
  449. */
  450. ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
  451. _Inout_ OrtValue* parameters_buffer, bool trainable_only);
  452. /// @}
  453. /// \name Release Training Resources
  454. /// @{
  455. /** \brief Frees up the memory used up by the training session.
  456. *
  457. * This function frees up any memory that was allocated in the training session. The training
  458. * session can no longer be used after this call.
  459. *
  460. */
  461. ORT_CLASS_RELEASE(TrainingSession);
  462. /** \brief Frees up the memory used up by the checkpoint state.
  463. *
  464. * This function frees up any memory that was allocated in the checkpoint state. The checkpoint
  465. * state can no longer be used after this call.
  466. * \note Note that the checkpoint state must be released only after the training session has been released.
  467. *
  468. */
  469. ORT_CLASS_RELEASE(CheckpointState);
  470. /// @}
  471. /// \name Prepare For Inferencing
  472. /// @{
  473. /** \brief Export a model that can be used for inferencing.
  474. *
  475. * If the training session was provided with an eval model, the training session can generate
  476. * an inference model if it knows the inference graph outputs. The input inference graph outputs
  477. * are used to prune the eval model so that the inference model's outputs align with the provided outputs.
  478. * The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
  479. * \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
  480. * and expects that this path still be valid.
  481. *
  482. * \param[in] sess The `this` pointer to the training session.
  483. * \param[in] inference_model_path Path where the inference model should be serialized to.
  484. * \param[in] graph_outputs_len Size of the graph output names array.
  485. * \param[in] graph_output_names Names of the outputs that are needed in the inference model.
  486. *
  487. * \snippet{doc} snippets.dox OrtStatus Return Value
  488. *
  489. */
  490. ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
  491. _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
  492. _In_reads_(graph_outputs_len) const char* const* graph_output_names);
  493. /// @}
  494. /// \name Training Utilities
  495. /// @{
  496. /** \brief Sets the seed used for random number generation in Onnxruntime.
  497. *
  498. * Use this function to generate reproducible results. It should be noted that completely reproducible
  499. * results are not guaranteed.
  500. *
  501. * \param[in] seed The seed to be set.
  502. *
  503. * \snippet{doc} snippets.dox OrtStatus Return Value
  504. *
  505. */
  506. ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
  507. /// @}
  508. /// \name Model IO Information
  509. /// @{
  510. /** \brief Retrieves the number of user inputs in the training model.
  511. *
  512. * This function returns the number of inputs of the training model so that the user can accordingly
  513. * allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
  514. *
  515. * \param[in] sess The `this` pointer to the training session.
  516. * \param[out] out Number of user inputs in the training model.
  517. *
  518. * \snippet{doc} snippets.dox OrtStatus Return Value
  519. *
  520. */
  521. ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
  522. /** \brief Retrieves the number of user inputs in the eval model.
  523. *
  524. * This function returns the number of inputs of the eval model so that the user can accordingly
  525. * allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
  526. *
  527. * \param[in] sess The `this` pointer to the training session.
  528. * \param[out] out Number of user inputs in the eval model.
  529. *
  530. * \snippet{doc} snippets.dox OrtStatus Return Value
  531. *
  532. */
  533. ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
  534. /** \brief Retrieves the name of the user input at the given index in the training model.
  535. *
  536. * Returns the name of one input of the training model so that it can be associated with the
  537. * corresponding OrtValue provided to the OrtTrainingApi::TrainStep function.
  538. *
  539. * \param[in] sess The `this` pointer to the training session.
  540. * \param[in] index The index of the training model input whose name is requested.
  541. * \param[in] allocator The allocator to use to allocate the memory for the requested name.
  542. * \param[out] output Set to the name of the training model input at the given index, allocated using `allocator`.
  543. *
  544. * \snippet{doc} snippets.dox OrtStatus Return Value
  545. *
  546. */
  547. ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
  548. _In_ OrtAllocator* allocator, _Outptr_ char** output);
  549. /** \brief Retrieves the name of the user input at the given index in the eval model.
  550. *
  551. * Returns the name of one input of the eval model so that it can be associated with the
  552. * corresponding OrtValue provided to the OrtTrainingApi::EvalStep function.
  553. *
  554. * \param[in] sess The `this` pointer to the training session.
  555. * \param[in] index The index of the eval model input whose name is requested.
  556. * \param[in] allocator The allocator to use to allocate the memory for the requested name.
  557. * \param[out] output Set to the name of the eval model input at the given index, allocated using `allocator`.
  558. *
  559. * \snippet{doc} snippets.dox OrtStatus Return Value
  560. *
  561. */
  562. ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
  563. _In_ OrtAllocator* allocator, _Outptr_ char** output);
  564. /// @}
  565. /// \name Accessing The Training Session State
  566. /// @{
  567. /** \brief Adds a property to, or updates an existing property in, the checkpoint state.
  568. *
  569. * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
  570. * state by calling this function with the corresponding property name, type and value.
  571. * Properties are keyed by name: a name not yet present adds a new property, an existing name updates it.
  572. *
  573. * \param[in] checkpoint_state The checkpoint state which should hold the property.
  574. * \param[in] property_name Name of the property being added or updated.
  575. * \param[in] property_type Type of the property associated with the given name.
  576. * \param[in] property_value Pointer to the value associated with the given name, described by property_type.
  577. *
  578. * \snippet{doc} snippets.dox OrtStatus Return Value
  579. *
  580. */
  581. ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
  582. _In_ const char* property_name, _In_ enum OrtPropertyType property_type,
  583. _In_ void* property_value);
  584. /** \brief Gets the property value associated with the given name from the checkpoint state.
  585. *
  586. * Reads the type and value of an existing entry in the checkpoint state. The property must
  587. * exist in the checkpoint state for it to be retrieved successfully.
  588. *
  589. * \param[in] checkpoint_state The checkpoint state that is currently holding the property.
  590. * \param[in] property_name Name of the property being retrieved.
  591. * \param[in] allocator Allocator used to allocate the memory for the property_value.
  592. * \param[out] property_type Set to the type of the property associated with the given name.
  593. * \param[out] property_value Set to the property value (allocated using `allocator`), described by property_type.
  594. *
  595. * \snippet{doc} snippets.dox OrtStatus Return Value
  596. *
  597. */
  598. ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
  599. _In_ const char* property_name, _Inout_ OrtAllocator* allocator,
  600. _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
  601. /// @}
  602. /// \name Loading Checkpoints And Accessing Model Parameters
  603. /// @{
  604. /** \brief Loads a checkpoint state from an in-memory buffer into checkpoint_state.
  605. *
  606. * This function parses the checkpoint bytes in checkpoint_buffer, extracts the relevant data and loads the
  607. * training state into checkpoint_state. This checkpoint state can then be used to create a
  608. * training session by invoking OrtTrainingApi::CreateTrainingSession, in which case the training
  609. * session resumes training from the given checkpoint state.
  610. * \note A training session created with a checkpoint state uses that state to store the entire
  611. * training state (including the model parameters, their gradients, the optimizer states and the properties).
  612. * As a result, the checkpoint state must outlive the training session.
  613. *
  614. * \param[in] checkpoint_buffer Pointer to the in-memory checkpoint bytes (not a file path).
  615. * \param[in] num_bytes Size of checkpoint_buffer in bytes.
  616. * \param[out] checkpoint_state Set to the checkpoint state that contains the states of the training session.
  617. *
  618. * \snippet{doc} snippets.dox OrtStatus Return Value
  619. *
  620. */
  621. ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
  622. _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
  623. /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
  624. *
  625. * Looks up the named parameter in checkpoint_state and reports its tensor type and shape information.
  626. * The parameter must exist in the checkpoint state for its type and shape information to be retrieved successfully.
  627. *
  628. * \param[in] checkpoint_state The checkpoint state that holds the parameter.
  629. * \param[in] parameter_name Name of the parameter whose type and shape are requested.
  630. * \param[out] parameter_type_and_shape Set to the type and shape information of the named parameter.
  631. *
  632. * \snippet{doc} snippets.dox OrtStatus Return Value
  633. *
  634. */
  635. ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
  636. _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
  637. /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
  638. *
  639. * This function replaces the data of a model parameter in the checkpoint state with the given parameter data.
  640. * A training session must already have been created with the checkpoint state that contains the parameter
  641. * being updated. The given parameter data is copied to the device registered with the training session.
  642. * The parameter must exist in the checkpoint state for the update to succeed.
  643. *
  644. * \param[in] checkpoint_state The checkpoint state that holds the parameter.
  645. * \param[in] parameter_name Name of the parameter being updated.
  646. * \param[in] parameter The parameter data that should replace the existing parameter data (copied by this call).
  647. *
  648. * \snippet{doc} snippets.dox OrtStatus Return Value
  649. *
  650. */
  651. ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
  652. _In_ const char* parameter_name, _In_ OrtValue* parameter);
  653. /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
  654. *
  655. * This function retrieves a copy of the model parameter data from the checkpoint state for the given parameter name.
  656. * The copy is returned as an OrtValue whose memory is allocated using `allocator`. A training session must
  657. * already have been created with the checkpoint state that contains the parameter being retrieved.
  658. * The parameter must exist in the checkpoint state for it to be retrieved successfully.
  659. *
  660. * \param[in] checkpoint_state The checkpoint state that holds the parameter.
  661. * \param[in] parameter_name Name of the parameter being retrieved.
  662. * \param[in] allocator Allocator used to allocate the memory for the returned copy of the parameter.
  663. * \param[out] parameter Set to an OrtValue holding a copy of the parameter data from the checkpoint state.
  664. *
  665. * \snippet{doc} snippets.dox OrtStatus Return Value
  666. *
  667. */
  668. ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
  669. _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
  670. _Outptr_ OrtValue** parameter);
  671. /// @}
  672. };
  673. typedef struct OrtTrainingApi OrtTrainingApi; /* Lets C code use OrtTrainingApi without the struct keyword. */
  674. /// @}