test_cfunctions.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from sympy.core.numbers import (Rational, pi)
  2. from sympy.core.singleton import S
  3. from sympy.core.symbol import (Symbol, symbols)
  4. from sympy.functions.elementary.exponential import (exp, log)
  5. from sympy.codegen.cfunctions import (
  6. expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot
  7. )
  8. from sympy.core.function import expand_log
  9. def test_expm1():
  10. # Eval
  11. assert expm1(0) == 0
  12. x = Symbol('x', real=True)
  13. # Expand and rewrite
  14. assert expm1(x).expand(func=True) - exp(x) == -1
  15. assert expm1(x).rewrite('tractable') - exp(x) == -1
  16. assert expm1(x).rewrite('exp') - exp(x) == -1
  17. # Precision
  18. assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
  19. assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
  20. # Properties
  21. assert expm1(x).is_real
  22. assert expm1(x).is_finite
  23. # Diff
  24. assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
  25. assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
  26. def test_log1p():
  27. # Eval
  28. assert log1p(0) == 0
  29. d = S(10)
  30. assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
  31. x = Symbol('x', real=True)
  32. # Expand and rewrite
  33. assert log1p(x).expand(func=True) - log(x + 1) == 0
  34. assert log1p(x).rewrite('tractable') - log(x + 1) == 0
  35. assert log1p(x).rewrite('log') - log(x + 1) == 0
  36. # Precision
  37. assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
  38. assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
  39. # Properties
  40. assert log1p(-2**Rational(-1, 2)).is_real
  41. assert not log1p(-1).is_finite
  42. assert log1p(pi).is_finite
  43. assert not log1p(x).is_positive
  44. assert log1p(Symbol('y', positive=True)).is_positive
  45. assert not log1p(x).is_zero
  46. assert log1p(Symbol('z', zero=True)).is_zero
  47. assert not log1p(x).is_nonnegative
  48. assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
  49. # Diff
  50. assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
  51. assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
  52. def test_exp2():
  53. # Eval
  54. assert exp2(2) == 4
  55. x = Symbol('x', real=True)
  56. # Expand
  57. assert exp2(x).expand(func=True) - 2**x == 0
  58. # Diff
  59. assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
  60. assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
  61. def test_log2():
  62. # Eval
  63. assert log2(8) == 3
  64. assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
  65. x = Symbol('x', real=True)
  66. assert log2(x) != log(x)/log(2)
  67. assert log2(2**x) == x
  68. # Expand
  69. assert log2(x).expand(func=True) - log(x)/log(2) == 0
  70. # Diff
  71. assert log2(42*x).diff() - 1/(log(2)*x) == 0
  72. assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
  73. def test_fma():
  74. x, y, z = symbols('x y z')
  75. # Expand
  76. assert fma(x, y, z).expand(func=True) - x*y - z == 0
  77. expr = fma(17*x, 42*y, 101*z)
  78. # Diff
  79. assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
  80. assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
  81. assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
  82. assert expr.diff(x) - 17*42*y == 0
  83. assert expr.diff(y) - 17*42*x == 0
  84. assert expr.diff(z) - 101 == 0
  85. def test_log10():
  86. x = Symbol('x')
  87. # Expand
  88. assert log10(x).expand(func=True) - log(x)/log(10) == 0
  89. # Diff
  90. assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
  91. assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
  92. def test_Cbrt():
  93. x = Symbol('x')
  94. # Expand
  95. assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
  96. # Diff
  97. assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
  98. assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
  99. def test_Sqrt():
  100. x = Symbol('x')
  101. # Expand
  102. assert Sqrt(x).expand(func=True) - x**S.Half == 0
  103. # Diff
  104. assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
  105. assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
  106. def test_hypot():
  107. x, y = symbols('x y')
  108. # Expand
  109. assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
  110. # Diff
  111. assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
  112. assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
  113. assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
  114. assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0