cuda_partitioned_block_sparse_crs_view.cc 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Authors: dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
  30. #include "ceres/cuda_partitioned_block_sparse_crs_view.h"
  31. #ifndef CERES_NO_CUDA
  32. #include "ceres/cuda_block_structure.h"
  33. #include "ceres/cuda_kernels_bsm_to_crs.h"
  34. namespace ceres::internal {
  35. CudaPartitionedBlockSparseCRSView::CudaPartitionedBlockSparseCRSView(
  36. const BlockSparseMatrix& bsm,
  37. const int num_col_blocks_e,
  38. ContextImpl* context)
  39. :
  40. context_(context) {
  41. const auto& bs = *bsm.block_structure();
  42. block_structure_ =
  43. std::make_unique<CudaBlockSparseStructure>(bs, num_col_blocks_e, context);
  44. // Determine number of non-zeros in left submatrix
  45. // Row-blocks are at least 1 row high, thus we can use a temporary array of
  46. // num_rows for ComputeNonZerosInColumnBlockSubMatrix; and later reuse it for
  47. // FillCRSStructurePartitioned
  48. const int num_rows = bsm.num_rows();
  49. const int num_nonzeros_e = block_structure_->num_nonzeros_e();
  50. const int num_nonzeros_f = bsm.num_nonzeros() - num_nonzeros_e;
  51. const int num_cols_e = num_col_blocks_e < bs.cols.size()
  52. ? bs.cols[num_col_blocks_e].position
  53. : bsm.num_cols();
  54. const int num_cols_f = bsm.num_cols() - num_cols_e;
  55. CudaBuffer<int32_t> rows_e(context, num_rows + 1);
  56. CudaBuffer<int32_t> cols_e(context, num_nonzeros_e);
  57. CudaBuffer<int32_t> rows_f(context, num_rows + 1);
  58. CudaBuffer<int32_t> cols_f(context, num_nonzeros_f);
  59. num_row_blocks_e_ = block_structure_->num_row_blocks_e();
  60. FillCRSStructurePartitioned(block_structure_->num_row_blocks(),
  61. num_rows,
  62. num_row_blocks_e_,
  63. num_col_blocks_e,
  64. num_nonzeros_e,
  65. block_structure_->first_cell_in_row_block(),
  66. block_structure_->cells(),
  67. block_structure_->row_blocks(),
  68. block_structure_->col_blocks(),
  69. rows_e.data(),
  70. cols_e.data(),
  71. rows_f.data(),
  72. cols_f.data(),
  73. context->DefaultStream(),
  74. context->is_cuda_memory_pools_supported_);
  75. f_is_crs_compatible_ = block_structure_->IsCrsCompatible();
  76. if (f_is_crs_compatible_) {
  77. block_structure_ = nullptr;
  78. } else {
  79. streamed_buffer_ = std::make_unique<CudaStreamedBuffer<double>>(
  80. context, kMaxTemporaryArraySize);
  81. }
  82. matrix_e_ = std::make_unique<CudaSparseMatrix>(
  83. num_cols_e, std::move(rows_e), std::move(cols_e), context);
  84. matrix_f_ = std::make_unique<CudaSparseMatrix>(
  85. num_cols_f, std::move(rows_f), std::move(cols_f), context);
  86. CHECK_EQ(bsm.num_nonzeros(),
  87. matrix_e_->num_nonzeros() + matrix_f_->num_nonzeros());
  88. UpdateValues(bsm);
  89. }
  90. void CudaPartitionedBlockSparseCRSView::UpdateValues(
  91. const BlockSparseMatrix& bsm) {
  92. if (f_is_crs_compatible_) {
  93. CHECK_EQ(cudaSuccess,
  94. cudaMemcpyAsync(matrix_e_->mutable_values(),
  95. bsm.values(),
  96. matrix_e_->num_nonzeros() * sizeof(double),
  97. cudaMemcpyHostToDevice,
  98. context_->DefaultStream()));
  99. CHECK_EQ(cudaSuccess,
  100. cudaMemcpyAsync(matrix_f_->mutable_values(),
  101. bsm.values() + matrix_e_->num_nonzeros(),
  102. matrix_f_->num_nonzeros() * sizeof(double),
  103. cudaMemcpyHostToDevice,
  104. context_->DefaultStream()));
  105. return;
  106. }
  107. streamed_buffer_->CopyToGpu(
  108. bsm.values(),
  109. bsm.num_nonzeros(),
  110. [block_structure = block_structure_.get(),
  111. num_nonzeros_e = matrix_e_->num_nonzeros(),
  112. num_row_blocks_e = num_row_blocks_e_,
  113. values_f = matrix_f_->mutable_values(),
  114. rows_f = matrix_f_->rows()](
  115. const double* values, int num_values, int offset, auto stream) {
  116. PermuteToCRSPartitionedF(num_nonzeros_e + offset,
  117. num_values,
  118. block_structure->num_row_blocks(),
  119. num_row_blocks_e,
  120. block_structure->first_cell_in_row_block(),
  121. block_structure->value_offset_row_block_f(),
  122. block_structure->cells(),
  123. block_structure->row_blocks(),
  124. block_structure->col_blocks(),
  125. rows_f,
  126. values,
  127. values_f,
  128. stream);
  129. });
  130. CHECK_EQ(cudaSuccess,
  131. cudaMemcpyAsync(matrix_e_->mutable_values(),
  132. bsm.values(),
  133. matrix_e_->num_nonzeros() * sizeof(double),
  134. cudaMemcpyHostToDevice,
  135. context_->DefaultStream()));
  136. }
  137. } // namespace ceres::internal
  138. #endif // CERES_NO_CUDA