test_scipy_nodes.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from itertools import product
  2. from sympy.core.power import Pow
  3. from sympy.core.symbol import symbols
  4. from sympy.functions.elementary.exponential import exp, log
  5. from sympy.functions.elementary.trigonometric import cos
  6. from sympy.core.numbers import pi
  7. from sympy.codegen.scipy_nodes import cosm1, powm1
  8. x, y, z = symbols('x y z')
  9. def test_cosm1():
  10. cm1_xy = cosm1(x*y)
  11. ref_xy = cos(x*y) - 1
  12. for wrt, deriv_order in product([x, y, z], range(3)):
  13. assert (
  14. cm1_xy.diff(wrt, deriv_order) -
  15. ref_xy.diff(wrt, deriv_order)
  16. ).rewrite(cos).simplify() == 0
  17. expr_minus2 = cosm1(pi)
  18. assert expr_minus2.rewrite(cos) == -2
  19. assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
  20. assert cosm1(pi/2).simplify() == -1
  21. assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
  22. def test_powm1():
  23. cases = {
  24. powm1(x, y): x**y - 1,
  25. powm1(x*y, z): (x*y)**z - 1,
  26. powm1(x, y*z): x**(y*z)-1,
  27. powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
  28. }
  29. for pm1_e, ref_e in cases.items():
  30. for wrt, deriv_order in product([x, y, z], range(3)):
  31. der = pm1_e.diff(wrt, deriv_order)
  32. ref = ref_e.diff(wrt, deriv_order)
  33. delta = (der - ref).rewrite(Pow)
  34. assert delta.simplify() == 0
  35. eulers_constant_m1 = powm1(x, 1/log(x))
  36. assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
  37. assert eulers_constant_m1.simplify() == exp(1) - 1