cuda_block_sparse_crs_view.cc 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Authors: dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
  30. #include "ceres/cuda_block_sparse_crs_view.h"
  31. #ifndef CERES_NO_CUDA
  32. #include "ceres/cuda_kernels_bsm_to_crs.h"
  33. namespace ceres::internal {
  34. CudaBlockSparseCRSView::CudaBlockSparseCRSView(const BlockSparseMatrix& bsm,
  35. ContextImpl* context)
  36. : context_(context) {
  37. block_structure_ = std::make_unique<CudaBlockSparseStructure>(
  38. *bsm.block_structure(), context);
  39. CudaBuffer<int32_t> rows(context, bsm.num_rows() + 1);
  40. CudaBuffer<int32_t> cols(context, bsm.num_nonzeros());
  41. FillCRSStructure(block_structure_->num_row_blocks(),
  42. bsm.num_rows(),
  43. block_structure_->first_cell_in_row_block(),
  44. block_structure_->cells(),
  45. block_structure_->row_blocks(),
  46. block_structure_->col_blocks(),
  47. rows.data(),
  48. cols.data(),
  49. context->DefaultStream(),
  50. context->is_cuda_memory_pools_supported_);
  51. is_crs_compatible_ = block_structure_->IsCrsCompatible();
  52. // if matrix is crs-compatible - we can drop block-structure and don't need
  53. // streamed_buffer_
  54. if (is_crs_compatible_) {
  55. VLOG(3) << "Block-sparse matrix is compatible with CRS, discarding "
  56. "block-structure";
  57. block_structure_ = nullptr;
  58. } else {
  59. streamed_buffer_ = std::make_unique<CudaStreamedBuffer<double>>(
  60. context_, kMaxTemporaryArraySize);
  61. }
  62. crs_matrix_ = std::make_unique<CudaSparseMatrix>(
  63. bsm.num_cols(), std::move(rows), std::move(cols), context);
  64. UpdateValues(bsm);
  65. }
  66. void CudaBlockSparseCRSView::UpdateValues(const BlockSparseMatrix& bsm) {
  67. if (is_crs_compatible_) {
  68. // Values of CRS-compatible matrices can be copied as-is
  69. CHECK_EQ(cudaSuccess,
  70. cudaMemcpyAsync(crs_matrix_->mutable_values(),
  71. bsm.values(),
  72. bsm.num_nonzeros() * sizeof(double),
  73. cudaMemcpyHostToDevice,
  74. context_->DefaultStream()));
  75. return;
  76. }
  77. streamed_buffer_->CopyToGpu(
  78. bsm.values(),
  79. bsm.num_nonzeros(),
  80. [bs = block_structure_.get(), crs = crs_matrix_.get()](
  81. const double* values, int num_values, int offset, auto stream) {
  82. PermuteToCRS(offset,
  83. num_values,
  84. bs->num_row_blocks(),
  85. bs->first_cell_in_row_block(),
  86. bs->cells(),
  87. bs->row_blocks(),
  88. bs->col_blocks(),
  89. crs->rows(),
  90. values,
  91. crs->mutable_values(),
  92. stream);
  93. });
  94. }
  95. } // namespace ceres::internal
  96. #endif // CERES_NO_CUDA