daubechies_wavelet.hpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. /*
  2. * Copyright Nick Thompson, 2020
  3. * Use, modification and distribution are subject to the
  4. * Boost Software License, Version 1.0. (See accompanying file
  5. * LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. */
  7. #ifndef BOOST_MATH_SPECIAL_DAUBECHIES_WAVELET_HPP
  8. #define BOOST_MATH_SPECIAL_DAUBECHIES_WAVELET_HPP
  9. #include <vector>
  10. #include <array>
  11. #include <cmath>
  12. #include <thread>
  13. #include <future>
  14. #include <iostream>
  15. #include <boost/math/constants/constants.hpp>
  16. #include <boost/math/special_functions/detail/daubechies_scaling_integer_grid.hpp>
  17. #include <boost/math/special_functions/daubechies_scaling.hpp>
  18. #include <boost/math/filters/daubechies.hpp>
  19. #include <boost/math/interpolators/detail/cubic_hermite_detail.hpp>
  20. #include <boost/math/interpolators/detail/quintic_hermite_detail.hpp>
  21. #include <boost/math/interpolators/detail/septic_hermite_detail.hpp>
  22. namespace boost::math {
  23. template<class Real, int p, int order>
  24. std::vector<Real> daubechies_wavelet_dyadic_grid(int64_t j_max)
  25. {
  26. if (j_max == 0)
  27. {
  28. throw std::domain_error("The wavelet dyadic grid is refined from the scaling integer grid, so its minimum amount of data is half integer widths.");
  29. }
  30. auto phijk = daubechies_scaling_dyadic_grid<Real, p, order>(j_max - 1);
  31. //psi_j[l] = psi(-p+1 + l/2^j) = \sum_{k=0}^{2p-1} (-1)^k c_k \phi(1-2p+k + l/2^{j-1})
  32. //For derivatives just map c_k -> 2^order c_k.
  33. auto d = boost::math::filters::daubechies_scaling_filter<Real, p>();
  34. Real scale = boost::math::constants::root_two<Real>() * (1 << order);
  35. for (size_t i = 0; i < d.size(); ++i)
  36. {
  37. d[i] *= scale;
  38. if (!(i & 1))
  39. {
  40. d[i] = -d[i];
  41. }
  42. }
  43. std::vector<Real> v(2 * p + (2 * p - 1) * ((int64_t(1) << j_max) - 1), std::numeric_limits<Real>::quiet_NaN());
  44. v[0] = 0;
  45. v[v.size() - 1] = 0;
  46. for (int64_t l = 1; l < static_cast<int64_t>(v.size() - 1); ++l)
  47. {
  48. Real term = 0;
  49. for (int64_t k = 0; k < static_cast<int64_t>(d.size()); ++k)
  50. {
  51. int64_t idx = (int64_t(1) << (j_max - 1)) * (1 - 2 * p + k) + l;
  52. if (idx < 0 || idx >= static_cast<int64_t>(phijk.size()))
  53. {
  54. continue;
  55. }
  56. term += d[k] * phijk[idx];
  57. }
  58. v[l] = term;
  59. }
  60. return v;
  61. }
  62. template<class Real, int p>
  63. class daubechies_wavelet {
  64. //
  65. // Some type manipulation so we know the type of the interpolator, and the vector type it requires:
  66. //
  67. typedef std::vector < std::array < Real, p < 6 ? 2 : p < 10 ? 3 : 4>> vector_type;
  68. //
  69. // List our interpolators:
  70. //
  71. typedef std::tuple<
  72. detail::null_interpolator, detail::matched_holder_aos<vector_type>, detail::linear_interpolation_aos<vector_type>,
  73. interpolators::detail::cardinal_cubic_hermite_detail_aos<vector_type>, interpolators::detail::cardinal_quintic_hermite_detail_aos<vector_type>,
  74. interpolators::detail::cardinal_septic_hermite_detail_aos<vector_type> > interpolator_list;
  75. //
  76. // Select the one we need:
  77. //
  78. typedef std::tuple_element_t<
  79. p == 1 ? 0 :
  80. p == 2 ? 1 :
  81. p == 3 ? 2 :
  82. p <= 5 ? 3 :
  83. p <= 9 ? 4 : 5, interpolator_list> interpolator_type;
  84. public:
  85. daubechies_wavelet(int grid_refinements = -1)
  86. {
  87. static_assert(p < 20, "Daubechies wavelets are only implemented for p < 20.");
  88. static_assert(p > 0, "Daubechies wavelets must have at least 1 vanishing moment.");
  89. if (grid_refinements == 0)
  90. {
  91. throw std::domain_error("The wavelet requires at least 1 grid refinement.");
  92. }
  93. if constexpr (p == 1)
  94. {
  95. return;
  96. }
  97. else
  98. {
  99. if (grid_refinements < 0)
  100. {
  101. if (std::is_same_v<Real, float>)
  102. {
  103. if (grid_refinements == -2)
  104. {
  105. // Control absolute error:
  106. // p= 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
  107. std::array<int, 20> r{ -1, -1, 18, 19, 16, 11, 8, 7, 7, 7, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3 };
  108. grid_refinements = r[p];
  109. }
  110. else
  111. {
  112. // Control relative error:
  113. // p= 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
  114. std::array<int, 20> r{ -1, -1, 21, 21, 21, 17, 16, 15, 14, 13, 12, 11, 11, 11, 11, 11, 11, 11, 11, 11 };
  115. grid_refinements = r[p];
  116. }
  117. }
  118. else if (std::is_same_v<Real, double>)
  119. {
  120. // p= 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19
  121. std::array<int, 20> r{ -1, -1, 21, 21, 21, 21, 21, 21, 21, 21, 20, 20, 19, 18, 18, 18, 18, 18, 18, 18 };
  122. grid_refinements = r[p];
  123. }
  124. else
  125. {
  126. grid_refinements = 21;
  127. }
  128. }
  129. // Compute the refined grid:
  130. // In fact for float precision I know the grid must be computed in double precision and then cast back down, or else parts of the support are systematically inaccurate.
  131. std::future<std::vector<Real>> t0 = std::async(std::launch::async, [&grid_refinements]() {
  132. // Computing in higher precision and downcasting is essential for 1ULP evaluation in float precision:
  133. auto v = daubechies_wavelet_dyadic_grid<typename detail::daubechies_eval_type<Real>::type, p, 0>(grid_refinements);
  134. return detail::daubechies_eval_type<Real>::vector_cast(v);
  135. });
  136. // Compute the derivative of the refined grid:
  137. std::future<std::vector<Real>> t1 = std::async(std::launch::async, [&grid_refinements]() {
  138. auto v = daubechies_wavelet_dyadic_grid<typename detail::daubechies_eval_type<Real>::type, p, 1>(grid_refinements);
  139. return detail::daubechies_eval_type<Real>::vector_cast(v);
  140. });
  141. // if necessary, compute the second and third derivative:
  142. std::vector<Real> d2ydx2;
  143. std::vector<Real> d3ydx3;
  144. if constexpr (p >= 6) {
  145. std::future<std::vector<Real>> t3 = std::async(std::launch::async, [&grid_refinements]() {
  146. auto v = daubechies_wavelet_dyadic_grid<typename detail::daubechies_eval_type<Real>::type, p, 2>(grid_refinements);
  147. return detail::daubechies_eval_type<Real>::vector_cast(v);
  148. });
  149. if constexpr (p >= 10) {
  150. std::future<std::vector<Real>> t4 = std::async(std::launch::async, [&grid_refinements]() {
  151. auto v = daubechies_wavelet_dyadic_grid<typename detail::daubechies_eval_type<Real>::type, p, 3>(grid_refinements);
  152. return detail::daubechies_eval_type<Real>::vector_cast(v);
  153. });
  154. d3ydx3 = t4.get();
  155. }
  156. d2ydx2 = t3.get();
  157. }
  158. auto y = t0.get();
  159. auto dydx = t1.get();
  160. if constexpr (p >= 2)
  161. {
  162. vector_type data(y.size());
  163. for (size_t i = 0; i < y.size(); ++i)
  164. {
  165. data[i][0] = y[i];
  166. data[i][1] = dydx[i];
  167. if constexpr (p >= 6)
  168. data[i][2] = d2ydx2[i];
  169. if constexpr (p >= 10)
  170. data[i][3] = d3ydx3[i];
  171. }
  172. if constexpr (p <= 3)
  173. m_interpolator = std::make_shared<interpolator_type>(std::move(data), grid_refinements, Real(-p + 1));
  174. else
  175. m_interpolator = std::make_shared<interpolator_type>(std::move(data), Real(-p + 1), Real(1) / (1 << grid_refinements));
  176. }
  177. else
  178. m_interpolator = std::make_shared<detail::null_interpolator>();
  179. }
  180. }
  181. inline Real operator()(Real x) const
  182. {
  183. if (x <= -p + 1 || x >= p)
  184. {
  185. return 0;
  186. }
  187. if constexpr (p == 1)
  188. {
  189. if (x < Real(1) / Real(2))
  190. {
  191. return 1;
  192. }
  193. else if (x == Real(1) / Real(2))
  194. {
  195. return 0;
  196. }
  197. return -1;
  198. }
  199. return (*m_interpolator)(x);
  200. }
  201. inline Real prime(Real x) const
  202. {
  203. static_assert(p > 2, "The 3-vanishing moment Daubechies wavelet is the first which is continuously differentiable.");
  204. if (x <= -p + 1 || x >= p)
  205. {
  206. return 0;
  207. }
  208. return m_interpolator->prime(x);
  209. }
  210. inline Real double_prime(Real x) const
  211. {
  212. static_assert(p >= 6, "Second derivatives of Daubechies wavelets require at least 6 vanishing moments.");
  213. if (x <= -p + 1 || x >= p)
  214. {
  215. return Real(0);
  216. }
  217. return m_interpolator->double_prime(x);
  218. }
  219. std::pair<Real, Real> support() const
  220. {
  221. return { Real(-p + 1), Real(p) };
  222. }
  223. int64_t bytes() const
  224. {
  225. return m_interpolator->bytes() + sizeof(*this);
  226. }
  227. private:
  228. std::shared_ptr<interpolator_type> m_interpolator;
  229. };
  230. }
  231. #endif