// cuda_vector.h
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: joydeepb@cs.utexas.edu (Joydeep Biswas)
  30. //
  31. // A simple CUDA vector class.
  32. #ifndef CERES_INTERNAL_CUDA_VECTOR_H_
  33. #define CERES_INTERNAL_CUDA_VECTOR_H_
  34. // This include must come before any #ifndef check on Ceres compile options.
  35. // clang-format off
  36. #include "ceres/internal/config.h"
  37. // clang-format on
  38. #include <math.h>
  39. #include <memory>
  40. #include <string>
  41. #include "ceres/context_impl.h"
  42. #include "ceres/internal/export.h"
  43. #include "ceres/types.h"
  44. #ifndef CERES_NO_CUDA
  45. #include "ceres/cuda_buffer.h"
  46. #include "ceres/cuda_kernels_vector_ops.h"
  47. #include "ceres/internal/eigen.h"
  48. #include "cublas_v2.h"
  49. #include "cusparse.h"
  50. namespace ceres::internal {
  51. // An Nx1 vector, denoted y hosted on the GPU, with CUDA-accelerated operations.
// An Nx1 dense vector, denoted y, hosted on the GPU, with CUDA-accelerated
// BLAS-1 style operations. Storage is owned via CudaBuffer<double>; a
// cuSPARSE dense-vector descriptor is kept in sync so the vector can be used
// directly in cuSPARSE SpMV calls.
class CERES_NO_EXPORT CudaVector {
 public:
  // Create a vector of size `size` on the GPU. The caller must ensure that
  // InitCuda() has already been successfully called on `context` before
  // calling this constructor.
  CudaVector(ContextImpl* context, int size);

  // Move constructor: takes ownership of `other`'s GPU buffer and descriptor.
  CudaVector(CudaVector&& other);

  ~CudaVector();

  // Resize the vector to `size` elements. NOTE(review): whether existing
  // contents are preserved is determined by the .cc implementation — confirm
  // before relying on it.
  void Resize(int size);

  // Perform a deep (device-to-device) copy of the vector.
  CudaVector& operator=(const CudaVector&);

  // Return the inner product x' * y.
  double Dot(const CudaVector& x) const;

  // Return the L2 norm of the vector (||y||_2).
  double Norm() const;

  // Set all elements to zero.
  void SetZero();

  // Copy from Eigen vector.
  void CopyFromCpu(const Vector& x);

  // Copy from CPU memory array. It is the caller's responsibility to ensure
  // that the array holds at least num_rows() doubles.
  void CopyFromCpu(const double* x);

  // Copy to Eigen vector.
  void CopyTo(Vector* x) const;

  // Copy to CPU memory array. It is the caller's responsibility to ensure
  // that the array is large enough.
  void CopyTo(double* x) const;

  // y = a * x + b * y.
  void Axpby(double a, const CudaVector& x, double b);

  // y = diag(d)' * diag(d) * x + y.
  void DtDxpy(const CudaVector& D, const CudaVector& x);

  // y = s * y.
  void Scale(double s);

  int num_rows() const { return num_rows_; }
  // Always 1: this class models a column vector.
  int num_cols() const { return 1; }

  // Raw device pointers to the underlying GPU buffer.
  const double* data() const { return data_.data(); }
  double* mutable_data() { return data_.data(); }

  // cuSPARSE dense-vector descriptor for use in cuSPARSE APIs.
  const cusparseDnVecDescr_t& descr() const { return descr_; }

 private:
  // Copying must go through the explicit deep-copy operator= above.
  CudaVector(const CudaVector&) = delete;

  // Release descr_ (if any) back to cuSPARSE.
  void DestroyDescriptor();

  // Number of elements in the vector.
  int num_rows_ = 0;
  // Not owned; supplies CUDA handles/streams for all operations.
  ContextImpl* context_ = nullptr;
  // Owning GPU buffer of num_rows_ doubles.
  CudaBuffer<double> data_;
  // CuSparse object that describes this dense vector.
  cusparseDnVecDescr_t descr_ = nullptr;
};
  98. // Blas1 operations on Cuda vectors. These functions are needed as an
  99. // abstraction layer so that we can use different versions of a vector style
  100. // object in the conjugate gradients linear solver.
  101. // Context and num_threads arguments are not used by CUDA implementation,
  102. // context embedded into CudaVector is used instead.
  103. inline double Norm(const CudaVector& x,
  104. ContextImpl* context = nullptr,
  105. int num_threads = 1) {
  106. (void)context;
  107. (void)num_threads;
  108. return x.Norm();
  109. }
  110. inline void SetZero(CudaVector& x,
  111. ContextImpl* context = nullptr,
  112. int num_threads = 1) {
  113. (void)context;
  114. (void)num_threads;
  115. x.SetZero();
  116. }
  117. inline void Axpby(double a,
  118. const CudaVector& x,
  119. double b,
  120. const CudaVector& y,
  121. CudaVector& z,
  122. ContextImpl* context = nullptr,
  123. int num_threads = 1) {
  124. (void)context;
  125. (void)num_threads;
  126. if (&x == &y && &y == &z) {
  127. // z = (a + b) * z;
  128. z.Scale(a + b);
  129. } else if (&x == &z) {
  130. // x is aliased to z.
  131. // z = x
  132. // = b * y + a * x;
  133. z.Axpby(b, y, a);
  134. } else if (&y == &z) {
  135. // y is aliased to z.
  136. // z = y = a * x + b * y;
  137. z.Axpby(a, x, b);
  138. } else {
  139. // General case: all inputs and outputs are distinct.
  140. z = y;
  141. z.Axpby(a, x, b);
  142. }
  143. }
  144. inline double Dot(const CudaVector& x,
  145. const CudaVector& y,
  146. ContextImpl* context = nullptr,
  147. int num_threads = 1) {
  148. (void)context;
  149. (void)num_threads;
  150. return x.Dot(y);
  151. }
  152. inline void Copy(const CudaVector& from,
  153. CudaVector& to,
  154. ContextImpl* context = nullptr,
  155. int num_threads = 1) {
  156. (void)context;
  157. (void)num_threads;
  158. to = from;
  159. }
  160. } // namespace ceres::internal
  161. #endif // CERES_NO_CUDA
#endif  // CERES_INTERNAL_CUDA_VECTOR_H_