onnxruntime_training_cxx_api.h 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. // Copyright (c) Microsoft Corporation. All rights reserved.
  2. // Licensed under the MIT License.
  3. #pragma once
  4. #include "onnxruntime_training_c_api.h"
  5. #include <optional>
  6. #include <variant>
  7. namespace Ort::detail {
  8. #define ORT_DECLARE_TRAINING_RELEASE(NAME) \
  9. void OrtRelease(Ort##NAME* ptr);
  10. // These release methods must be forward declared before including onnxruntime_cxx_api.h
  11. // otherwise class Base won't be aware of them
  12. ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
  13. ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
  14. } // namespace Ort::detail
  15. #include "onnxruntime_cxx_api.h"
  16. namespace Ort {
  17. /// <summary>
  18. /// This function returns the C training api struct with the pointers to the ort training C functions.
  19. /// If using C++, please use the class instances instead of invoking the C functions directly.
  20. /// </summary>
  21. /// <returns>OrtTrainingApi struct with ort training C function pointers.</returns>
  22. inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); }
  23. namespace detail {
  24. #define ORT_DEFINE_TRAINING_RELEASE(NAME) \
  25. inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
  26. ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
  27. ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
  28. #undef ORT_DECLARE_TRAINING_RELEASE
  29. #undef ORT_DEFINE_TRAINING_RELEASE
  30. } // namespace detail
  31. using Property = std::variant<int64_t, float, std::string>;
  32. /**
  33. * \defgroup TrainingCpp Ort Training C++ API
  34. * @{
  35. */
  36. /** \brief Holds the state of the training session.
  37. *
  38. * This class holds the entire training session state that includes model parameters, their gradients,
  39. * optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState
  40. * by accessing and updating the contained training state.
  41. * \note Note that the training session created with a checkpoint state uses this state to store the entire
  42. * training state (including model parameters, its gradients, the optimizer states and the properties).
  43. * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
  44. * that the checkpoint state outlive the lifetime of the training session.
  45. * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint
  46. * state depending on the version provided while loading the checkpoint.
  47. *
  48. */
  49. class CheckpointState : public detail::Base<OrtCheckpointState> {
  50. private:
  51. CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
  52. public:
  53. // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
  54. CheckpointState() = delete;
  55. /// \name Accessing The Training Session State
  56. /// @{
  57. /** \brief Load a checkpoint state from a file on disk into checkpoint_state.
  58. *
  59. * This function will parse a checkpoint file, pull relevant data and load the training
  60. * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
  61. * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
  62. * training from the given checkpoint state.
  63. *
  64. * \param[in] path_to_checkpoint Path to the checkpoint file
  65. * \return Ort::CheckpointState object which holds the state of the training session parameters.
  66. *
  67. */
  68. static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
  69. /** \brief Load a checkpoint state from a buffer.
  70. *
  71. * This function will parse a checkpoint buffer, pull relevant data and load the training
  72. * state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
  73. * training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
  74. * training from the given checkpoint state.
  75. *
  76. * \param[in] buffer Buffer containing the checkpoint data.
  77. * \return Ort::CheckpointState object which holds the state of the training session parameters.
  78. *
  79. */
  80. static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
  81. /** \brief Save the given state to a checkpoint file on disk.
  82. *
  83. * This function serializes the provided checkpoint state to a file on disk.
  84. * This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
  85. * training from this snapshot of the state.
  86. *
  87. * \param[in] checkpoint_state The checkpoint state to save.
  88. * \param[in] path_to_checkpoint Path to the checkpoint file.
  89. * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
  90. *
  91. */
  92. static void SaveCheckpoint(const CheckpointState& checkpoint_state,
  93. const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
  94. const bool include_optimizer_state = false);
  95. /** \brief Adds or updates the given property to/in the checkpoint state.
  96. *
  97. * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
  98. * state by the user by calling this function with the corresponding property name and value.
  99. * The given property name must be unique to be able to successfully add the property.
  100. *
  101. * \param[in] property_name Name of the property being added or updated.
  102. * \param[in] property_value Property value associated with the given name.
  103. *
  104. */
  105. void AddProperty(const std::string& property_name, const Property& property_value);
  106. /** \brief Gets the property value associated with the given name from the checkpoint state.
  107. *
  108. * Gets the property value from an existing entry in the checkpoint state. The property must
  109. * exist in the checkpoint state to be able to retrieve it successfully.
  110. *
  111. * \param[in] property_name Name of the property being retrieved.
  112. * \return Property value associated with the given property name.
  113. *
  114. */
  115. Property GetProperty(const std::string& property_name);
  116. /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
  117. *
  118. * This function updates a model parameter in the checkpoint state with the given parameter data.
  119. * The training session must be already created with the checkpoint state that contains the parameter
  120. * being updated. The given parameter is copied over to the registered device for the training session.
  121. * The parameter must exist in the checkpoint state to be able to update it successfully.
  122. *
  123. * \param[in] parameter_name Name of the parameter being updated.
  124. * \param[in] parameter The parameter data that should replace the existing parameter data.
  125. *
  126. */
  127. void UpdateParameter(const std::string& parameter_name, const Value& parameter);
  128. /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
  129. *
  130. * This function retrieves the model parameter data from the checkpoint state for the given parameter name.
  131. * The parameter is copied over to the provided OrtValue. The training session must be already created
  132. * with the checkpoint state that contains the parameter being retrieved.
  133. * The parameter must exist in the checkpoint state to be able to retrieve it successfully.
  134. *
  135. * \param[in] parameter_name Name of the parameter being retrieved.
  136. * \return The parameter data that is retrieved from the checkpoint state.
  137. *
  138. */
  139. Value GetParameter(const std::string& parameter_name);
  140. /// @}
  141. };
  142. /** \brief Trainer class that provides training, evaluation and optimizer methods for training an ONNX models.
  143. *
  144. * The training session requires four training artifacts
  145. * - The training onnx model
  146. * - The evaluation onnx model (optional)
  147. * - The optimizer onnx model
  148. * - The checkpoint file
  149. *
  150. * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
  151. *
  152. */
  153. class TrainingSession : public detail::Base<OrtTrainingSession> {
  154. private:
  155. size_t training_model_output_count_, eval_model_output_count_;
  156. public:
  157. /// \name Constructing the Training Session
  158. /// @{
  159. /** \brief Create a training session that can be used to begin or resume training.
  160. *
  161. * This constructor instantiates the training session based on the env and session options provided that can
  162. * begin or resume training from a given checkpoint state for the given onnx models.
  163. * The checkpoint state represents the parameters of the training session which will be moved
  164. * to the device specified by the user through the session options (if necessary).
  165. *
  166. * \param[in] env Env to be used for the training session.
  167. * \param[in] session_options SessionOptions that the user can customize for this training session.
  168. * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
  169. * \param[in] train_model_path Model to be used to perform training.
  170. * \param[in] eval_model_path Model to be used to perform evaluation.
  171. * \param[in] optimizer_model_path Model to be used to perform gradient descent.
  172. *
  173. */
  174. TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
  175. const std::basic_string<ORTCHAR_T>& train_model_path,
  176. const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
  177. const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
  178. /** \brief Create a training session that can be used to begin or resume training.
  179. * This constructor allows the users to load the models from buffers instead of files.
  180. *
  181. * \param[in] env Env to be used for the training session.
  182. * \param[in] session_options SessionOptions that the user can customize for this training session.
  183. * \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
  184. * \param[in] train_model_data Buffer containing training model data.
  185. * \param[in] eval_model_data Buffer containing evaluation model data.
  186. * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update).
  187. *
  188. */
  189. TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
  190. const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
  191. const std::vector<uint8_t>& optim_model_data = {});
  192. /// @}
  193. /// \name Implementing The Training Loop
  194. /// @{
  195. /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
  196. *
  197. * This function performs a training step that computes the outputs of the training model and the gradients
  198. * of the trainable parameters for the given inputs. The train step is performed based on the training model
  199. * that was provided to the training session.
  200. * The Ort::TrainingSession::TrainStep is equivalent of running forward propagation and backward propagation in a single
  201. * step.
  202. * The gradients computed are stored inside the training session state so they can be later consumed
  203. * by the Ort::TrainingSession::OptimizerStep function.
  204. * The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
  205. *
  206. * \param[in] input_values The user inputs to the training model.
  207. * \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
  208. *
  209. *
  210. */
  211. std::vector<Value> TrainStep(const std::vector<Value>& input_values);
  212. /** \brief Reset the gradients of all trainable parameters to zero lazily.
  213. *
  214. * This function sets the internal state of the training session such that the gradients of the trainable
  215. * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
  216. * computed on the next invocation of the next Ort::TrainingSession::TrainStep.
  217. *
  218. */
  219. void LazyResetGrad();
  220. /** \brief Computes the outputs for the eval model for the given inputs
  221. *
  222. * This function performs an eval step that computes the outputs of the eval model for the given inputs.
  223. * The eval step is performed based on the eval model that was provided to the training session.
  224. *
  225. * \param[in] input_values The user inputs to the eval model.
  226. * \return A std::vector of Ort::Value objects that represents the output of the eval pass.
  227. *
  228. */
  229. std::vector<Value> EvalStep(const std::vector<Value>& input_values);
  230. /** \brief Sets the learning rate for this training session.
  231. *
  232. * This function allows users to set the learning rate for the training session. The current
  233. * learning rate is maintained by the training session and can be overwritten by invoking
  234. * this function with the desired learning rate. This function should not be used when a valid
  235. * learning rate scheduler is registered. It should be used either to set the learning rate
  236. * derived from a custom learning rate scheduler or to set a constant learning rate to be used
  237. * throughout the training session.
  238. * \note Please note that this function does not set the initial learning rate that may be needed
  239. * by the predefined learning rate schedulers. To set the initial learning rate for learning
  240. * rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler.
  241. *
  242. * \param[in] learning_rate Desired learning rate to be set.
  243. *
  244. */
  245. void SetLearningRate(float learning_rate);
  246. /** \brief Gets the current learning rate for this training session.
  247. *
  248. * This function allows users to get the learning rate for the training session. The current
  249. * learning rate is maintained by the training session, and users can query it for the purpose
  250. * of implementing their own learning rate schedulers.
  251. *
  252. * \return float representing the current learning rate.
  253. *
  254. */
  255. float GetLearningRate() const;
  256. /** \brief Registers a linear learning rate scheduler for the training session.
  257. *
  258. * Register a linear learning rate scheduler that decays the learning rate by linearly updated
  259. * multiplicative factor from the initial learning rate set on the training session to 0. The decay
  260. * is performed after the initial warm up phase where the learning rate is linearly incremented
  261. * from 0 to the initial learning rate provided.
  262. *
  263. * \param[in] warmup_step_count Warmup steps for LR warmup.
  264. * \param[in] total_step_count Total step count.
  265. * \param[in] initial_lr The initial learning rate to be used by the training session.
  266. *
  267. */
  268. void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
  269. float initial_lr);
  270. /** \brief Update the learning rate based on the registered learing rate scheduler.
  271. *
  272. * Takes a scheduler step that updates the learning rate that is being used by the training session.
  273. * This function should typically be called before invoking the optimizer step for each round,
  274. * or as determined necessary to update the learning rate being used by the training session.
  275. * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
  276. * function.
  277. *
  278. */
  279. void SchedulerStep();
  280. /** \brief Performs the weight updates for the trainable parameters using the optimizer model.
  281. *
  282. * This function performs the weight update step that updates the trainable parameters such that they
  283. * take a step in the direction of their gradients (gradient descent). The optimizer step is performed
  284. * based on the optimizer model that was provided to the training session.
  285. * The updated parameters are stored inside the training state so that they can be used by the next
  286. * Ort::TrainingSession::TrainStep function call.
  287. *
  288. */
  289. void OptimizerStep();
  290. /// @}
  291. /// \name Prepare For Inferencing
  292. /// @{
  293. /** \brief Export a model that can be used for inferencing.
  294. *
  295. * If the training session was provided with an eval model, the training session can generate
  296. * an inference model if it knows the inference graph outputs. The input inference graph outputs
  297. * are used to prune the eval model so that the inference model's outputs align with the provided outputs.
  298. * The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
  299. * \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
  300. * and expects that this path still be valid.
  301. *
  302. * \param[in] inference_model_path Path where the inference model should be serialized to.
  303. * \param[in] graph_output_names Names of the outputs that are needed in the inference model.
  304. *
  305. */
  306. void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
  307. const std::vector<std::string>& graph_output_names);
  308. /// @}
  309. /// \name Model IO Information
  310. /// @{
  311. /** \brief Retrieves the names of the user inputs for the training and eval models.
  312. *
  313. * This function returns the names of inputs of the training or eval model that can be associated
  314. * with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
  315. * function.
  316. *
  317. * \param[in] training Whether the training model input names are requested or eval model input names.
  318. * \return Graph input names for either the training model or the eval model.
  319. *
  320. */
  321. std::vector<std::string> InputNames(const bool training);
  322. /** \brief Retrieves the names of the user outputs for the training and eval models.
  323. *
  324. * This function returns the names of outputs of the training or eval model that can be associated
  325. * with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
  326. * function.
  327. *
  328. * \param[in] training Whether the training model output names are requested or eval model output names.
  329. * \return Graph output names for either the training model or the eval model.
  330. *
  331. */
  332. std::vector<std::string> OutputNames(const bool training);
  333. /// @}
  334. /// \name Accessing The Training Session State
  335. /// @{
  336. /** \brief Returns a contiguous buffer that holds a copy of all training state parameters
  337. *
  338. * \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters.
  339. * \return Contiguous buffer to the model parameters.
  340. *
  341. */
  342. Value ToBuffer(const bool only_trainable);
  343. /** \brief Loads the training session model parameters from a contiguous buffer
  344. *
  345. * In case the training session was created with a nominal checkpoint, invoking this function is required
  346. * to load the updated parameters onto the checkpoint to complete it.
  347. *
  348. * \param[in] buffer Contiguous buffer to load the parameters from.
  349. */
  350. void FromBuffer(Value& buffer);
  351. /// @}
  352. };
  353. /// \name Training Utilities
  354. /// @{
  355. /** \brief This function sets the seed for generating random numbers.
  356. *
  357. * Use this function to generate reproducible results. It should be noted that completely
  358. * reproducible results are not guaranteed.
  359. *
  360. * \param[in] seed Manual seed to use for random number generation.
  361. */
  362. void SetSeed(const int64_t seed);
  363. /// @}
  364. /// @}
  365. } // namespace Ort
  366. #include "onnxruntime_training_cxx_inline.h"