123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- from sympy.concrete.summations import Sum
- from sympy.core.mod import Mod
- from sympy.core.relational import (Equality, Unequality)
- from sympy.functions.elementary.miscellaneous import sqrt
- from sympy.functions.elementary.piecewise import Piecewise
- from sympy.matrices.expressions.blockmatrix import BlockMatrix
- from sympy.matrices.expressions.matexpr import MatrixSymbol
- from sympy.matrices.expressions.special import Identity
- from sympy.utilities.lambdify import lambdify
- from sympy.abc import x, i, j, a, b, c, d
- from sympy.core import Pow
- from sympy.codegen.matrix_nodes import MatrixSolve
- from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
- from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
- from sympy.tensor.array import Array
- from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
- PermuteDims, ArrayDiagonal
- from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions
- from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
- from sympy.testing.pytest import skip, raises
- from sympy.external import import_module
# Unlike NumPy, which aggressively promotes operands to double precision,
# JAX uses single precision by default. Double precision can be enabled in
# JAX, but it must be configured explicitly before the call to `import jax`
# and is not fully supported. The tests here are therefore adapted from the
# tests in test_numpy.py, differing only in that they assert the accuracy of
# lambdified functions to single precision.
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
# `jax` doubles as the availability flag: the tests below call
# skip("JAX not installed") when it is falsy.
jax = import_module('jax')

if jax:
    # Floating-point info for the dtype JAX assigns to an array created
    # without an explicit dtype; its `eps` is the tolerance used by the
    # numerical assertions in this module.
    default_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype)
    JAX_DEFAULT_EPSILON = default_float_info.eps
def test_jax_piecewise_regression():
    """
    JaxPrinter must print the conditions and choices of a Piecewise() as
    lists passed to ``jax.numpy.select``.

    This mirrors the NumPy printer regression test, where the choicelist
    had to be printed as a list to stay compatible with numpy 1.8.
    See gh-9747 and gh-9749 for details.
    """
    jax_printer = JaxPrinter()
    piecewise_expr = Piecewise((1, x < 0), (0, True))

    expected_code = 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)'
    assert jax_printer.doprint(piecewise_expr) == expected_code

    # Printing must also have recorded every jax.numpy name it used.
    expected_imports = {'jax.numpy': {'select', 'less', 'nan'}}
    assert jax_printer.module_imports == expected_imports
def test_jax_logaddexp():
    """Check the code JaxPrinter emits for logaddexp and logaddexp2."""
    cases = [
        (logaddexp(a, b), 'jax.numpy.logaddexp(a, b)'),
        (logaddexp2(a, b), 'jax.numpy.logaddexp2(a, b)'),
    ]
    for node, expected_code in cases:
        # A fresh printer per case, so no state is shared between prints.
        assert JaxPrinter().doprint(node) == expected_code
def test_jax_sum():
    """Compare lambdified single-index Sum expressions with Python's sum()."""
    if not jax:
        skip("JAX not installed")

    # Both cases share the same bounds and sample points.
    lower, upper = 0, 10
    points = jax.numpy.linspace(-1, +1, 10)

    power_sum = lambdify((a, b, x), Sum(x ** i, (i, a, b)), 'jax')
    assert jax.numpy.allclose(
        power_sum(lower, upper, points),
        sum(points ** k for k in range(lower, upper + 1)),
    )

    scaled_sum = lambdify((a, b, x), Sum(i * x, (i, a, b)), 'jax')
    assert jax.numpy.allclose(
        scaled_sum(lower, upper, points),
        sum(k * points for k in range(lower, upper + 1)),
    )
def test_jax_multiple_sums():
    """Compare a lambdified two-index Sum with nested Python summation."""
    if not jax:
        skip("JAX not installed")

    double_sum = Sum((x + j) * i, (i, a, b), (j, c, d))
    evaluate = lambdify((a, b, c, d, x), double_sum, 'jax')

    i_lo, i_hi = 0, 10
    j_lo, j_hi = 11, 21
    points = jax.numpy.linspace(-1, +1, 10)

    actual = evaluate(i_lo, i_hi, j_lo, j_hi, points)
    reference = sum(
        (points + jj) * ii
        for ii in range(i_lo, i_hi + 1)
        for jj in range(j_lo, j_hi + 1)
    )
    assert jax.numpy.allclose(actual, reference)
def test_jax_codegen_einsum():
    """A matrix product converted to array form must match jax.numpy.matmul."""
    if not jax:
        skip("JAX not installed")

    lhs_sym = MatrixSymbol("M", 2, 2)
    rhs_sym = MatrixSymbol("N", 2, 2)
    array_expr = convert_matrix_to_array(lhs_sym * rhs_sym)
    product_fn = lambdify((lhs_sym, rhs_sym), array_expr, 'jax')

    lhs = jax.numpy.array([[1, 2], [3, 4]])
    rhs = jax.numpy.array([[1, -2], [-1, 3]])

    matches = product_fn(lhs, rhs) == jax.numpy.matmul(lhs, rhs)
    assert matches.all()
def test_jax_codegen_extra():
    """
    Check lambdified array expressions (ArrayTensorProduct, ArrayAdd,
    PermuteDims, ArrayDiagonal) against the equivalent jax.numpy calls.
    """
    if not jax:
        skip("JAX not installed")

    M = MatrixSymbol("M", 2, 2)
    N = MatrixSymbol("N", 2, 2)
    P = MatrixSymbol("P", 2, 2)
    Q = MatrixSymbol("Q", 2, 2)
    matrix_symbols = (M, N, P, Q)

    ma = jax.numpy.array([[1, 2], [3, 4]])
    mb = jax.numpy.array([[1, -2], [-1, 3]])
    mc = jax.numpy.array([[2, 0], [1, 2]])
    md = jax.numpy.array([[1, -1], [4, 7]])
    matrix_values = (ma, mb, mc, md)

    # Reference tensor product of ma and mb, reused by several checks below.
    outer = jax.numpy.einsum(ma, [0, 1], mb, [2, 3])

    tensor_product = lambdify((M, N), ArrayTensorProduct(M, N), 'jax')
    assert (tensor_product(ma, mb) == outer).all()

    # ArrayAdd with two, three and four operands.
    for count in (2, 3, 4):
        operands = matrix_symbols[:count]
        values = matrix_values[:count]
        add_fn = lambdify(operands, ArrayAdd(*operands), 'jax')
        expected = values[0]
        for value in values[1:]:
            expected = expected + value
        assert (add_fn(*values) == expected).all()

    transpose_fn = lambdify((M,), PermuteDims(M, [1, 0]), 'jax')
    assert (transpose_fn(ma) == ma.T).all()

    permuted = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
    permute_fn = lambdify((M, N), permuted, 'jax')
    expected = jax.numpy.transpose(outer, (1, 2, 3, 0))
    assert (permute_fn(ma, mb) == expected).all()

    diagonal = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
    diagonal_fn = lambdify((M, N), diagonal, 'jax')
    expected = jax.numpy.diagonal(outer, axis1=1, axis2=2)
    assert (diagonal_fn(ma, mb) == expected).all()
def test_jax_relational():
    """Evaluate lambdified relational expressions element-wise on [0, 1, 2]."""
    if not jax:
        skip("JAX not installed")

    samples = jax.numpy.array([0, 1, 2])

    # (expression, expected element-wise truth values for `samples`)
    cases = [
        (Equality(x, 1), [False, True, False]),
        (Unequality(x, 1), [True, False, True]),
        (x < 1, [True, False, False]),
        (x <= 1, [True, True, False]),
        (x > 1, [False, False, True]),
        (x >= 1, [False, True, True]),
        # Multi-condition expressions
        ((x >= 1) & (x < 2), [False, True, False]),
        ((x >= 1) | (x < 2), [True, True, True]),
    ]

    for relation, expected in cases:
        predicate = lambdify((x,), relation, 'jax')
        assert jax.numpy.array_equal(predicate(samples), expected)
def test_jax_mod():
    """Evaluate a lambdified Mod with scalar and array divisors."""
    if not jax:
        skip("JAX not installed")

    modulo = lambdify((a, b), Mod(a, b), 'jax')

    # (dividend, divisor, expected remainders)
    cases = [
        (jax.numpy.array([0, 1, 2, 3]), 2, [0, 1, 0, 1]),
        (jax.numpy.array([0, 1, 2, 3]), jax.numpy.array([2, 2, 2, 2]), [0, 1, 0, 1]),
        (jax.numpy.array([2, 3, 4, 5]), jax.numpy.array([2, 3, 4, 5]), [0, 0, 0, 0]),
    ]

    for dividend, divisor, expected in cases:
        assert jax.numpy.array_equal(modulo(dividend, divisor), expected)
def test_jax_pow():
    """An unevaluated Pow(2, -1) must lambdify to a function returning 0.5."""
    if not jax:
        skip('JAX not installed')

    # evaluate=False keeps the Pow node intact for the printer.
    one_half = Pow(2, -1, evaluate=False)
    compute = lambdify([], one_half, 'jax')
    assert compute() == 0.5
def test_jax_expm1():
    """Check lambdified expm1 near zero to within the default JAX epsilon."""
    if not jax:
        skip("JAX not installed")

    expm1_fn = lambdify((a,), expm1(a), 'jax')
    arg = 1e-10
    # The reference value is arg + 5e-21; the tolerance scales with arg.
    error = abs(expm1_fn(arg) - arg - 5e-21)
    assert error <= arg * JAX_DEFAULT_EPSILON
def test_jax_log1p():
    """Check lambdified log1p for a tiny argument against the argument itself."""
    if not jax:
        skip("JAX not installed")

    log1p_fn = lambdify((a,), log1p(a), 'jax')
    # NOTE(review): 1e-99 is far below the smallest positive float32 value,
    # so under single precision this input presumably rounds to zero --
    # confirm that the check is still meaningful.
    arg = 1e-99
    error = abs(log1p_fn(arg) - arg)
    assert error <= arg * JAX_DEFAULT_EPSILON
def test_jax_hypot():
    """hypot(3, 4) evaluated through lambdify must equal 5 within epsilon."""
    if not jax:
        skip("JAX not installed")

    hypot_fn = lambdify((a, b), hypot(a, b), 'jax')
    error = abs(hypot_fn(3, 4) - 5)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_log10():
    """log10(100) evaluated through lambdify must equal 2 within epsilon."""
    if not jax:
        skip("JAX not installed")

    log10_fn = lambdify((a,), log10(a), 'jax')
    error = abs(log10_fn(100) - 2)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_exp2():
    """exp2(5) evaluated through lambdify must equal 32 within epsilon."""
    if not jax:
        skip("JAX not installed")

    exp2_fn = lambdify((a,), exp2(a), 'jax')
    error = abs(exp2_fn(5) - 32)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_log2():
    """log2(256) evaluated through lambdify must equal 8 within epsilon."""
    if not jax:
        skip("JAX not installed")

    log2_fn = lambdify((a,), log2(a), 'jax')
    error = abs(log2_fn(256) - 8)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_Sqrt():
    """The codegen Sqrt node applied to 4 must evaluate to 2 within epsilon."""
    if not jax:
        skip("JAX not installed")

    sqrt_node_fn = lambdify((a,), Sqrt(a), 'jax')
    error = abs(sqrt_node_fn(4) - 2)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_sqrt():
    """The elementary sqrt function applied to 4 must evaluate to 2 within epsilon."""
    if not jax:
        skip("JAX not installed")

    sqrt_fn = lambdify((a,), sqrt(a), 'jax')
    error = abs(sqrt_fn(4) - 2)
    assert error <= JAX_DEFAULT_EPSILON
def test_jax_matsolve():
    """MatrixSolve(M, x) + x must agree numerically with M**(-1)*x + x."""
    if not jax:
        skip("JAX not installed")

    mat = MatrixSymbol("M", 3, 3)
    vec = MatrixSymbol("x", 3, 1)

    inverse_expr = mat**(-1) * vec + vec
    solve_expr = MatrixSolve(mat, vec) + vec

    via_inverse = lambdify((mat, vec), inverse_expr, 'jax')
    via_solve = lambdify((mat, vec), solve_expr, 'jax')

    m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
    # Sanity check on the fixture: the sample matrix has full rank.
    assert jax.numpy.linalg.matrix_rank(m0) == 3

    x0 = jax.numpy.array([3, 4, 5])
    assert jax.numpy.allclose(via_solve(m0, x0), via_inverse(m0, x0))
def test_16857():
    """A 2x2 BlockMatrix of matrix symbols must print as jax.numpy.block."""
    if not jax:
        skip("JAX not installed")

    top_left = MatrixSymbol('a_1', 10, 3)
    top_right = MatrixSymbol('a_2', 10, 3)
    bottom_left = MatrixSymbol('a_3', 10, 3)
    bottom_right = MatrixSymbol('a_4', 10, 3)

    block = BlockMatrix([[top_left, top_right], [bottom_left, bottom_right]])
    assert block.shape == (20, 6)

    expected_code = 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])'
    assert JaxPrinter().doprint(block) == expected_code
def test_issue_17006():
    """
    Adding Identity to a matrix symbol lambdifies for a concrete size and
    raises NotImplementedError for a symbolic size.
    """
    if not jax:
        skip("JAX not installed")

    # Concrete 2x2 case: M + I evaluated numerically.
    mat = MatrixSymbol("M", 2, 2)
    add_identity = lambdify(mat, mat + Identity(2), 'jax')

    given = jax.numpy.array([[1, 2], [3, 4]])
    expected = jax.numpy.array([[2, 2], [3, 5]])
    assert (add_identity(given) == expected).all()

    # Symbolic n x n case: lambdify is expected to refuse.
    from sympy.core.symbol import symbols

    size = symbols('n', integer=True)
    symbolic_mat = MatrixSymbol("M", size, size)

    def lambdify_symbolic():
        return lambdify(symbolic_mat, symbolic_mat + Identity(size), 'jax')

    raises(NotImplementedError, lambdify_symbolic)
def test_jax_array():
    """Check the code JaxPrinter emits for a nested and a flat Array."""
    nested = Array(((1, 2), (3, 5)))
    flat = Array((1, 2))

    nested_code = JaxPrinter().doprint(nested)
    assert nested_code == 'jax.numpy.array([[1, 2], [3, 5]])'

    flat_code = JaxPrinter().doprint(flat)
    assert flat_code == 'jax.numpy.array((1, 2))'
def test_jax_known_funcs_consts():
    """Spot-check entries of the JAX known-constants and known-functions tables."""
    expected_constants = {
        'NaN': 'jax.numpy.nan',
        'EulerGamma': 'jax.numpy.euler_gamma',
    }
    for sympy_name, jax_name in expected_constants.items():
        assert _jax_known_constants[sympy_name] == jax_name

    expected_functions = {
        'acos': 'jax.numpy.arccos',
        'log': 'jax.numpy.log',
    }
    for sympy_name, jax_name in expected_functions.items():
        assert _jax_known_functions[sympy_name] == jax_name
def test_jax_print_methods():
    """JaxPrinter instances must expose the _print_acos and _print_log methods."""
    printer = JaxPrinter()
    for method_name in ('_print_acos', '_print_log'):
        assert hasattr(printer, method_name)
|