test_lambertw.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
#
# Tests for the lambertw function.
# Adapted from the mpmath tests [1] by Yosef Meller, mellerf@netvision.net.il
# Distributed under the same license as SciPy itself.
#
# [1] mpmath source code, Subversion revision 992
# http://code.google.com/p/mpmath/source/browse/trunk/mpmath/tests/test_functions2.py?spec=svn994&r=992
  8. import pytest
  9. import numpy as np
  10. from numpy.testing import assert_, assert_equal, assert_array_almost_equal
  11. from scipy.special import lambertw
  12. from numpy import nan, inf, pi, e, isnan, log, r_, array, complex_
  13. from scipy.special._testutils import FuncData
  14. def test_values():
  15. assert_(isnan(lambertw(nan)))
  16. assert_equal(lambertw(inf,1).real, inf)
  17. assert_equal(lambertw(inf,1).imag, 2*pi)
  18. assert_equal(lambertw(-inf,1).real, inf)
  19. assert_equal(lambertw(-inf,1).imag, 3*pi)
  20. assert_equal(lambertw(1.), lambertw(1., 0))
  21. data = [
  22. (0,0, 0),
  23. (0+0j,0, 0),
  24. (inf,0, inf),
  25. (0,-1, -inf),
  26. (0,1, -inf),
  27. (0,3, -inf),
  28. (e,0, 1),
  29. (1,0, 0.567143290409783873),
  30. (-pi/2,0, 1j*pi/2),
  31. (-log(2)/2,0, -log(2)),
  32. (0.25,0, 0.203888354702240164),
  33. (-0.25,0, -0.357402956181388903),
  34. (-1./10000,0, -0.000100010001500266719),
  35. (-0.25,-1, -2.15329236411034965),
  36. (0.25,-1, -3.00899800997004620-4.07652978899159763j),
  37. (-0.25,-1, -2.15329236411034965),
  38. (0.25,1, -3.00899800997004620+4.07652978899159763j),
  39. (-0.25,1, -3.48973228422959210+7.41405453009603664j),
  40. (-4,0, 0.67881197132094523+1.91195078174339937j),
  41. (-4,1, -0.66743107129800988+7.76827456802783084j),
  42. (-4,-1, 0.67881197132094523-1.91195078174339937j),
  43. (1000,0, 5.24960285240159623),
  44. (1000,1, 4.91492239981054535+5.44652615979447070j),
  45. (1000,-1, 4.91492239981054535-5.44652615979447070j),
  46. (1000,5, 3.5010625305312892+29.9614548941181328j),
  47. (3+4j,0, 1.281561806123775878+0.533095222020971071j),
  48. (-0.4+0.4j,0, -0.10396515323290657+0.61899273315171632j),
  49. (3+4j,1, -0.11691092896595324+5.61888039871282334j),
  50. (3+4j,-1, 0.25856740686699742-3.85211668616143559j),
  51. (-0.5,-1, -0.794023632344689368-0.770111750510379110j),
  52. (-1./10000,1, -11.82350837248724344+6.80546081842002101j),
  53. (-1./10000,-1, -11.6671145325663544),
  54. (-1./10000,-2, -11.82350837248724344-6.80546081842002101j),
  55. (-1./100000,4, -14.9186890769540539+26.1856750178782046j),
  56. (-1./100000,5, -15.0931437726379218666+32.5525721210262290086j),
  57. ((2+1j)/10,0, 0.173704503762911669+0.071781336752835511j),
  58. ((2+1j)/10,1, -3.21746028349820063+4.56175438896292539j),
  59. ((2+1j)/10,-1, -3.03781405002993088-3.53946629633505737j),
  60. ((2+1j)/10,4, -4.6878509692773249+23.8313630697683291j),
  61. (-(2+1j)/10,0, -0.226933772515757933-0.164986470020154580j),
  62. (-(2+1j)/10,1, -2.43569517046110001+0.76974067544756289j),
  63. (-(2+1j)/10,-1, -3.54858738151989450-6.91627921869943589j),
  64. (-(2+1j)/10,4, -4.5500846928118151+20.6672982215434637j),
  65. (pi,0, 1.073658194796149172092178407024821347547745350410314531),
  66. # Former bug in generated branch,
  67. (-0.5+0.002j,0, -0.78917138132659918344 + 0.76743539379990327749j),
  68. (-0.5-0.002j,0, -0.78917138132659918344 - 0.76743539379990327749j),
  69. (-0.448+0.4j,0, -0.11855133765652382241 + 0.66570534313583423116j),
  70. (-0.448-0.4j,0, -0.11855133765652382241 - 0.66570534313583423116j),
  71. ]
  72. data = array(data, dtype=complex_)
  73. def w(x, y):
  74. return lambertw(x, y.real.astype(int))
  75. with np.errstate(all='ignore'):
  76. FuncData(w, data, (0,1), 2, rtol=1e-10, atol=1e-13).check()
  77. def test_ufunc():
  78. assert_array_almost_equal(
  79. lambertw(r_[0., e, 1.]), r_[0., 1., 0.567143290409783873])
  80. def test_lambertw_ufunc_loop_selection():
  81. # see https://github.com/scipy/scipy/issues/4895
  82. dt = np.dtype(np.complex128)
  83. assert_equal(lambertw(0, 0, 0).dtype, dt)
  84. assert_equal(lambertw([0], 0, 0).dtype, dt)
  85. assert_equal(lambertw(0, [0], 0).dtype, dt)
  86. assert_equal(lambertw(0, 0, [0]).dtype, dt)
  87. assert_equal(lambertw([0], [0], [0]).dtype, dt)
  88. @pytest.mark.parametrize('z', [1e-316, -2e-320j, -5e-318+1e-320j])
  89. def test_lambertw_subnormal_k0(z):
  90. # Verify that subnormal inputs are handled correctly on
  91. # the branch k=0 (regression test for gh-16291).
  92. w = lambertw(z)
  93. # For values this small, we can be sure that numerically,
  94. # lambertw(z) is z.
  95. assert w == z