123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- from sympy.codegen import Assignment
- from sympy.codegen.ast import none
- from sympy.codegen.cfunctions import expm1, log1p
- from sympy.codegen.scipy_nodes import cosm1
- from sympy.codegen.matrix_nodes import MatrixSolve
- from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
- from sympy.core.numbers import pi
- from sympy.core.singleton import S
- from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth
- from sympy.logic import And, Or
- from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
- from sympy.printing.pycode import (
- MpmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
- )
- from sympy.printing.tensorflow import TensorflowPrinter
- from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
- from sympy.testing.pytest import raises, skip
- from sympy.tensor import IndexedBase, Idx
- from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray
- from sympy.external import import_module
- from sympy.functions.special.gamma_functions import loggamma
# Symbols and the indexed base shared by the tests in this module.
p = IndexedBase("p")
x, y, z = symbols('x y z')
def test_PythonCodePrinter():
    """Check PythonCodePrinter output strings and the imports it records."""
    printer = PythonCodePrinter()
    assert not printer.module_imports

    # Printing these expressions must leave ``module_imports`` empty.
    import_free_cases = [
        (x**y, 'x**y'),
        (Mod(x, 2), 'x % 2'),
        (-Mod(x, y), '-(x % y)'),
        (Mod(-x, y), '(-x) % y'),
        (And(x, y), 'x and y'),
        (Or(x, y), 'x or y'),
        (1/(x+y), '1/(x + y)'),
    ]
    for expression, expected in import_free_cases:
        assert printer.doprint(expression) == expected
    assert not printer.module_imports

    # ``math`` names are registered one by one as they are printed.
    assert printer.doprint(pi) == 'math.pi'
    assert printer.module_imports == {'math': {'pi'}}

    for expression in (x**Rational(1, 2), sqrt(x)):
        assert printer.doprint(expression) == 'math.sqrt(x)'
    assert printer.module_imports == {'math': {'pi', 'sqrt'}}

    remaining_cases = [
        (acos(x), 'math.acos(x)'),
        (cot(x), '1/math.tan(x)'),
        (coth(x),
         '(math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x))'),
        (asec(x), 'math.acos(1/x)'),
        (acsch(x), 'math.log(math.sqrt(1 + x**(-2)) + 1/x)'),
        (Assignment(x, 2), 'x = 2'),
        (Piecewise((1, Eq(x, 0)), (2, x>6)),
         '((1) if (x == 0) else (2) if (x > 6) else None)'),
        (Piecewise((2, Le(x, 0)), (3, Gt(x, 0)), evaluate=False),
         '((2) if (x <= 0) else'
         ' (3) if (x > 0) else None)'),
        (sign(x), '(0.0 if x == 0 else math.copysign(1, x))'),
        (p[0, 1], 'p[0, 1]'),
        (KroneckerDelta(x,y), '(1 if x == y else 0)'),
        ((2,3), "(2, 3)"),
        ([2,3], "[2, 3]"),
        (Min(x, y), "min(x, y)"),
        (Max(x, y), "max(x, y)"),
    ]
    for expression, expected in remaining_cases:
        assert printer.doprint(expression) == expected
def test_PythonCodePrinter_standard():
    """Check the default ``standard`` and that 'python4' is rejected."""
    assert PythonCodePrinter().standard == 'python3'

    def build_with_python4():
        return PythonCodePrinter({'standard':'python4'})

    raises(ValueError, build_with_python4)
def test_MpmathPrinter():
    """Check the mpmath names MpmathPrinter emits for functions and constants."""
    printer = MpmathPrinter()
    cases = [
        (sign(x), 'mpmath.sign(x)'),
        (Rational(1, 2), 'mpmath.mpf(1)/mpmath.mpf(2)'),
        (S.Exp1, 'mpmath.e'),
        (S.Pi, 'mpmath.pi'),
        (S.GoldenRatio, 'mpmath.phi'),
        (S.EulerGamma, 'mpmath.euler'),
        (S.NaN, 'mpmath.nan'),
        (S.Infinity, 'mpmath.inf'),
        (S.NegativeInfinity, 'mpmath.ninf'),
        (loggamma(x), 'mpmath.loggamma(x)'),
    ]
    for expression, expected in cases:
        assert printer.doprint(expression) == expected
def test_NumPyPrinter():
    """Check NumPyPrinter output for matrix expressions, powers and constants."""
    from sympy.core.function import Lambda
    from sympy.matrices.expressions.adjoint import Adjoint
    from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf)
    from sympy.matrices.expressions.funcmatrix import FunctionMatrix
    from sympy.matrices.expressions.hadamard import HadamardProduct
    from sympy.matrices.expressions.kronecker import KroneckerProduct
    from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix)
    from sympy.abc import a, b

    printer = NumPyPrinter()
    assert printer.doprint(sign(x)) == 'numpy.sign(x)'

    A = MatrixSymbol("A", 2, 2)
    B = MatrixSymbol("B", 2, 2)
    C = MatrixSymbol("C", 1, 5)
    D = MatrixSymbol("D", 3, 4)
    u = MatrixSymbol('x', 2, 1)
    v = MatrixSymbol('y', 2, 1)

    matrix_cases = [
        (A**(-1), "numpy.linalg.inv(A)"),
        (A**5, "numpy.linalg.matrix_power(A, 5)"),
        (Identity(3), "numpy.eye(3)"),
        (MatrixSolve(A, u), 'numpy.linalg.solve(A, x)'),
        (MatrixSolve(A, u) + v, 'numpy.linalg.solve(A, x) + y'),
        (ZeroMatrix(2, 3), "numpy.zeros((2, 3))"),
        (OneMatrix(2, 3), "numpy.ones((2, 3))"),
        (FunctionMatrix(4, 5, Lambda((a, b), a + b)),
         "numpy.fromfunction(lambda a, b: a + b, (4, 5))"),
        (HadamardProduct(A, B), "numpy.multiply(A, B)"),
        (KroneckerProduct(A, B), "numpy.kron(A, B)"),
        (Adjoint(A), "numpy.conjugate(numpy.transpose(A))"),
        (DiagonalOf(A), "numpy.reshape(numpy.diag(A), (-1, 1))"),
        (DiagMatrix(C), "numpy.diagflat(C)"),
        (DiagonalMatrix(D), "numpy.multiply(D, numpy.eye(3, 4))"),
    ]
    for expression, expected in matrix_cases:
        assert printer.doprint(expression) == expected

    # Workaround for numpy negative integer power errors
    power_cases = [
        (x**-1, 'x**(-1.0)'),
        (x**-2, 'x**(-2.0)'),
        (Pow(2, -1, evaluate=False), "2**(-1.0)"),
    ]
    for expression, expected in power_cases:
        assert printer.doprint(expression) == expected

    constant_cases = [
        (S.Exp1, 'numpy.e'),
        (S.Pi, 'numpy.pi'),
        (S.EulerGamma, 'numpy.euler_gamma'),
        (S.NaN, 'numpy.nan'),
        (S.Infinity, 'numpy.PINF'),
        (S.NegativeInfinity, 'numpy.NINF'),
    ]
    for expression, expected in constant_cases:
        assert printer.doprint(expression) == expected
def test_issue_18770():
    """Evaluate lambdified Min/Max expressions with numpy (skipped without numpy)."""
    numpy = import_module('numpy')
    if not numpy:
        skip("numpy not installed.")

    from sympy.functions.elementary.miscellaneous import (Max, Min)
    from sympy.utilities.lambdify import lambdify

    smallest = lambdify(x, Min(0.1*x + 3, x + 1, 0.5*x + 1), "numpy")
    assert (smallest(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all()
    assert smallest(4) == 3

    largest = lambdify(x, Max(x**2, x**3), "numpy")
    assert (largest(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all()
    assert largest(4) == 64
def test_SciPyPrinter():
    """Check SciPyPrinter output and which modules it registers while printing."""
    printer = SciPyPrinter()

    # Printing acos registers numpy, and no scipy module yet.
    assert 'numpy' not in printer.module_imports
    assert printer.doprint(acos(x)) == 'numpy.arccos(x)'
    assert 'numpy' in printer.module_imports
    assert all(not name.startswith('scipy') for name in printer.module_imports)

    # After printing a sparse matrix, scipy.sparse is registered.
    sparse = SparseMatrix(2, 5, {(0, 1): 3})
    expected_sparse = 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
    assert printer.doprint(sparse) == expected_sparse
    assert 'scipy.sparse' in printer.module_imports

    constant_cases = [
        (S.GoldenRatio, 'scipy.constants.golden_ratio'),
        (S.Pi, 'scipy.constants.pi'),
        (S.Exp1, 'numpy.e'),
    ]
    for constant, expected in constant_cases:
        assert printer.doprint(constant) == expected
def test_pycode_reserved_words():
    """Check how pycode treats symbols named after Python keywords."""
    kw_if, kw_else = symbols('if else')
    total = kw_if + kw_else

    # With error_on_reserved=True the keyword names raise ValueError.
    raises(ValueError, lambda: pycode(total, error_on_reserved=True))

    # By default the names come out with a trailing underscore, in either order.
    assert pycode(total) in ('else_ + if_', 'if_ + else_')
def test_issue_20762():
    """Make sure pycode removes curly braces from subscripted variables."""
    a_b, b, a_11 = symbols('a_{b} b a_{11}')
    for product, expected in ((a_b*b, 'a_b*b'), (a_11*b, 'a_11*b')):
        assert pycode(product) == expected
def test_sqrt():
    """Check ``_print_Pow`` on square roots with and without ``rational``."""
    root = sqrt(x)
    inverse_root = 1/sqrt(x)

    printer = PythonCodePrinter()
    assert printer._print_Pow(root, rational=False) == 'math.sqrt(x)'
    assert printer._print_Pow(inverse_root, rational=False) == '1/math.sqrt(x)'

    printer = PythonCodePrinter({'standard' : 'python3'})
    assert printer._print_Pow(root, rational=True) == 'x**(1/2)'
    assert printer._print_Pow(inverse_root, rational=True) == 'x**(-1/2)'

    # Each remaining printer is checked in both modes.
    per_printer = [
        (MpmathPrinter, 'mpmath.sqrt(x)', "x**(mpmath.mpf(1)/mpmath.mpf(2))"),
        (NumPyPrinter, 'numpy.sqrt(x)', 'x**(1/2)'),
        (SciPyPrinter, 'numpy.sqrt(x)', 'x**(1/2)'),
        (SymPyPrinter, 'sympy.sqrt(x)', 'x**(1/2)'),
    ]
    for printer_cls, as_function, as_power in per_printer:
        printer = printer_cls()
        assert printer._print_Pow(root, rational=False) == as_function
        assert printer._print_Pow(root, rational=True) == as_power
def test_frac():
    """Check how each printer renders ``frac(x)``."""
    from sympy.functions.elementary.integers import frac

    fractional_part = frac(x)
    expected_by_printer = [
        (NumPyPrinter, 'numpy.mod(x, 1)'),
        (SciPyPrinter, 'numpy.mod(x, 1)'),
        (PythonCodePrinter, 'x % 1'),
        (MpmathPrinter, 'mpmath.frac(x)'),
        (SymPyPrinter, 'sympy.functions.elementary.integers.frac(x)'),
    ]
    for printer_cls, expected in expected_by_printer:
        assert printer_cls().doprint(fractional_part) == expected
class CustomPrintedObject(Expr):
    """Expr subclass defining ``_numpycode`` and ``_mpmathcode`` print methods."""

    def _mpmathcode(self, printer):
        """Return the fixed string 'mpmath'."""
        return 'mpmath'

    def _numpycode(self, printer):
        """Return the fixed string 'numpy'."""
        return 'numpy'
def test_printmethod():
    """Check that each printer returns the string from the object's print method."""
    custom = CustomPrintedObject()
    for printer_cls, expected in ((NumPyPrinter, 'numpy'), (MpmathPrinter, 'mpmath')):
        assert printer_cls().doprint(custom) == expected
def test_codegen_ast_nodes():
    """Check that the codegen ``none`` node prints as 'None'."""
    printed = pycode(none)
    assert printed == 'None'
def test_issue_14283():
    """Check PythonCodePrinter output for ``zoo`` and ``-oo``."""
    printer = PythonCodePrinter()
    for value, expected in ((zoo, "math.nan"), (-oo, "float('-inf')")):
        assert printer.doprint(value) == expected
def test_NumPyPrinter_print_seq():
    """Check the string NumPyPrinter._print_seq returns for ``range(2)``."""
    printed = NumPyPrinter()._print_seq(range(2))
    assert printed == '(0, 1,)'
def test_issue_16535_16536():
    """Check printing of lowergamma/uppergamma across printers."""
    from sympy.functions.special.gamma_functions import (lowergamma, uppergamma)

    a = symbols('a')
    lower = lowergamma(a, x)
    upper = uppergamma(a, x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(lower) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
    assert scipy_printer.doprint(upper) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'

    # The other printers emit a "Not supported" marker for both functions.
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (lower, upper):
            assert "Not supported" in printer.doprint(expression)
def test_Integral():
    """Check SciPy/mpmath printing of definite integrals and rejection of others."""
    from sympy.functions.elementary.exponential import exp
    from sympy.integrals.integrals import Integral

    single = Integral(exp(-x), (x, 0, oo))
    double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
    indefinite = Integral(x**2, x)
    evaluateat = Integral(x**2, (x, 1))

    expectations = [
        (SciPyPrinter(),
         'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.PINF)[0]',
         'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'),
        (MpmathPrinter(),
         'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))',
         'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'),
    ]
    for printer, expected_single, expected_double in expectations:
        assert printer.doprint(single) == expected_single
        assert printer.doprint(double) == expected_double
        # Integrals without both limits raise NotImplementedError.
        raises(NotImplementedError, lambda: printer.doprint(indefinite))
        raises(NotImplementedError, lambda: printer.doprint(evaluateat))
def test_fresnel_integrals():
    """Check printing of fresnelc/fresnels across printers."""
    from sympy.functions.special.error_functions import (fresnelc, fresnels)

    cosine_integral = fresnelc(x)
    sine_integral = fresnels(x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(cosine_integral) == 'scipy.special.fresnel(x)[1]'
    assert scipy_printer.doprint(sine_integral) == 'scipy.special.fresnel(x)[0]'

    # NumPy and plain-Python printers emit a "Not supported" marker.
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (cosine_integral, sine_integral):
            assert "Not supported" in printer.doprint(expression)

    mpmath_printer = MpmathPrinter()
    assert mpmath_printer.doprint(cosine_integral) == 'mpmath.fresnelc(x)'
    assert mpmath_printer.doprint(sine_integral) == 'mpmath.fresnels(x)'
def test_beta():
    """Check how each printer renders ``beta(x, y)``."""
    from sympy.functions.special.beta_functions import beta

    beta_xy = beta(x, y)
    gamma_form = 'math.gamma(x)*math.gamma(y)/math.gamma(x + y)'

    expected_by_printer = [
        (SciPyPrinter(), 'scipy.special.beta(x, y)'),
        (NumPyPrinter(), gamma_form),
        (PythonCodePrinter(), gamma_form),
        (PythonCodePrinter({'allow_unknown_functions': True}), gamma_form),
        (MpmathPrinter(), 'mpmath.beta(x, y)'),
    ]
    for printer, expected in expected_by_printer:
        assert printer.doprint(beta_xy) == expected
def test_airy():
    """Check printing of airyai/airybi across printers."""
    from sympy.functions.special.bessel import (airyai, airybi)

    ai = airyai(x)
    bi = airybi(x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(ai) == 'scipy.special.airy(x)[0]'
    assert scipy_printer.doprint(bi) == 'scipy.special.airy(x)[2]'

    # NumPy and plain-Python printers emit a "Not supported" marker.
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (ai, bi):
            assert "Not supported" in printer.doprint(expression)
def test_airy_prime():
    """Check printing of airyaiprime/airybiprime across printers."""
    from sympy.functions.special.bessel import (airyaiprime, airybiprime)

    ai_prime = airyaiprime(x)
    bi_prime = airybiprime(x)

    scipy_printer = SciPyPrinter()
    assert scipy_printer.doprint(ai_prime) == 'scipy.special.airy(x)[1]'
    assert scipy_printer.doprint(bi_prime) == 'scipy.special.airy(x)[3]'

    # NumPy and plain-Python printers emit a "Not supported" marker.
    for printer_cls in (NumPyPrinter, PythonCodePrinter):
        printer = printer_cls()
        for expression in (ai_prime, bi_prime):
            assert "Not supported" in printer.doprint(expression)
def test_numerical_accuracy_functions():
    """Check SciPyPrinter output for expm1, log1p and cosm1."""
    printer = SciPyPrinter()
    cases = [
        (expm1(x), 'numpy.expm1(x)'),
        (log1p(x), 'numpy.log1p(x)'),
        (cosm1(x), 'scipy.special.cosm1(x)'),
    ]
    for expression, expected in cases:
        assert printer.doprint(expression) == expected
def test_array_printer():
    """Check NumPy and TensorFlow printing of array expressions."""
    A = ArraySymbol('A', (4,4,6,6,6))
    indexed = IndexedBase('I')
    i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5))

    # The same expressions, in the same order, are printed by both printers.
    expressions = [
        ZeroArray(5),
        OneArray(5),
        ArrayContraction(A, [2,3]),
        indexed,
        ArrayDiagonal(A, [2,3,4]),
        ArrayDiagonal(A, [0,1], [2,3]),
        ArrayContraction(A, [2], [3]),
        Assignment(indexed[i,j,k], indexed[i,j,k]),
    ]

    numpy_expected = [
        'numpy.zeros((5,))',
        'numpy.ones((5,))',
        'numpy.einsum("abccd->abd", A)',
        'I',
        'numpy.einsum("abccc->abc", A)',
        'numpy.einsum("aabbc->cab", A)',
        'numpy.einsum("abcde->abe", A)',
        'I = I',
    ]
    tensorflow_expected = [
        'tensorflow.zeros((5,))',
        'tensorflow.ones((5,))',
        'tensorflow.linalg.einsum("abccd->abd", A)',
        'I',
        'tensorflow.linalg.einsum("abccc->abc", A)',
        'tensorflow.linalg.einsum("aabbc->cab", A)',
        'tensorflow.linalg.einsum("abcde->abe", A)',
        'I = I',
    ]

    printer = NumPyPrinter()
    for expression, expected in zip(expressions, numpy_expected):
        assert printer.doprint(expression) == expected

    printer = TensorflowPrinter()
    for expression, expected in zip(expressions, tensorflow_expected):
        assert printer.doprint(expression) == expected
|