autodiff_first_order_function.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: sameeragarwal@google.com (Sameer Agarwal)
  30. #ifndef CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_
  31. #define CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_
  #include <memory>
  #include <utility>

  #include "ceres/first_order_function.h"
  #include "ceres/internal/eigen.h"
  #include "ceres/internal/fixed_array.h"
  #include "ceres/jet.h"
  #include "ceres/types.h"
  38. namespace ceres {
  39. // Create FirstOrderFunctions as needed by the GradientProblem
  40. // framework, with gradients computed via automatic
  41. // differentiation. For more information on automatic differentiation,
  42. // see the Wikipedia article at
  43. // http://en.wikipedia.org/wiki/Automatic_differentiation
  44. //
  45. // To get an auto differentiated function, you must define a class
  46. // with a templated operator() (a functor) that computes the cost
  47. // function in terms of the template parameter T. The autodiff
  48. // framework substitutes appropriate "jet" objects for T in order to
  49. // compute the derivative when necessary, but this is hidden, and you
  50. // should write the function as if T were a scalar type (e.g. a
  51. // double-precision floating point number).
  52. //
  53. // The function must write the computed value to the last argument
  54. // (the only non-const one) and return true to indicate success;
  55. // returning false signals that the evaluation failed.
  56. //
  57. // For example, consider a scalar error e = x'y - a, where both x and y are
  58. // two-dimensional column vectors packed into one parameter xy = [x; y], the
  59. // prime sign indicates transposition, and a is a constant.
  60. //
  61. // To write an auto-differentiable FirstOrderFunction for the above model, first
  62. // define the object
  63. //
  64. // class QuadraticCostFunctor {
  65. // public:
  66. // explicit QuadraticCostFunctor(double a) : a_(a) {}
  67. // template <typename T>
  68. // bool operator()(const T* const xy, T* cost) const {
  69. // const T* const x = xy;
  70. // const T* const y = xy + 2;
  71. // *cost = x[0] * y[0] + x[1] * y[1] - T(a_);
  72. // return true;
  73. // }
  74. //
  75. // private:
  76. // double a_;
  77. // };
  78. //
  79. // Note that in the declaration of operator() the input parameter xy comes
  80. // first and is passed as a const pointer to an array of T. The
  81. // output is the last parameter.
  82. //
  83. // Then given this class definition, the auto differentiated FirstOrderFunction
  84. // for it can be constructed as follows.
  85. //
  86. // FirstOrderFunction* function =
  87. // new AutoDiffFirstOrderFunction<QuadraticCostFunctor, 4>(
  88. // new QuadraticCostFunctor(1.0));
  89. //
  90. // In the instantiation above, the template parameter following
  91. // "QuadraticCostFunctor", "4", describes the functor as computing a
  92. // scalar (1-dimensional) output from a four-dimensional input vector.
  93. //
  94. // WARNING: Since the functor will get instantiated with different types for
  95. // T, you must convert from other numeric types to T before mixing
  96. // computations with other variables of type T. In the example above, this is
  97. // seen where instead of using a_ directly, a_ is wrapped with T(a_).
  98. template <typename FirstOrderFunctor, int kNumParameters>
  99. class AutoDiffFirstOrderFunction final : public FirstOrderFunction {
  100. public:
  101. // Takes ownership of functor.
  102. explicit AutoDiffFirstOrderFunction(FirstOrderFunctor* functor)
  103. : functor_(functor) {
  104. static_assert(kNumParameters > 0, "kNumParameters must be positive");
  105. }
  106. bool Evaluate(const double* const parameters,
  107. double* cost,
  108. double* gradient) const override {
  109. if (gradient == nullptr) {
  110. return (*functor_)(parameters, cost);
  111. }
  112. using JetT = Jet<double, kNumParameters>;
  113. internal::FixedArray<JetT, (256 * 7) / sizeof(JetT)> x(kNumParameters);
  114. for (int i = 0; i < kNumParameters; ++i) {
  115. x[i].a = parameters[i];
  116. x[i].v.setZero();
  117. x[i].v[i] = 1.0;
  118. }
  119. JetT output;
  120. output.a = kImpossibleValue;
  121. output.v.setConstant(kImpossibleValue);
  122. if (!(*functor_)(x.data(), &output)) {
  123. return false;
  124. }
  125. *cost = output.a;
  126. VectorRef(gradient, kNumParameters) = output.v;
  127. return true;
  128. }
  129. int NumParameters() const override { return kNumParameters; }
  130. const FirstOrderFunctor& functor() const { return *functor_; }
  131. private:
  132. std::unique_ptr<FirstOrderFunctor> functor_;
  133. };
  134. } // namespace ceres
  135. #endif // CERES_PUBLIC_AUTODIFF_FIRST_ORDER_FUNCTION_H_