test_array_comprehension.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from sympy.tensor.array.array_comprehension import ArrayComprehension, ArrayComprehensionMap
  2. from sympy.tensor.array import ImmutableDenseNDimArray
  3. from sympy.abc import i, j, k, l
  4. from sympy.testing.pytest import raises, warns
  5. from sympy.utilities.exceptions import SymPyDeprecationWarning
  6. from sympy.matrices import Matrix
  7. def test_array_comprehension():
  8. a = ArrayComprehension(i*j, (i, 1, 3), (j, 2, 4))
  9. b = ArrayComprehension(i, (i, 1, j+1))
  10. c = ArrayComprehension(i+j+k+l, (i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5))
  11. d = ArrayComprehension(k, (i, 1, 5))
  12. e = ArrayComprehension(i, (j, k+1, k+5))
  13. assert a.doit().tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]]
  14. assert a.shape == (3, 3)
  15. assert a.is_shape_numeric == True
  16. assert a.tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]]
  17. assert a.tomatrix() == Matrix([
  18. [2, 3, 4],
  19. [4, 6, 8],
  20. [6, 9, 12]])
  21. assert len(a) == 9
  22. assert isinstance(b.doit(), ArrayComprehension)
  23. assert isinstance(a.doit(), ImmutableDenseNDimArray)
  24. assert b.subs(j, 3) == ArrayComprehension(i, (i, 1, 4))
  25. assert b.free_symbols == {j}
  26. assert b.shape == (j + 1,)
  27. assert b.rank() == 1
  28. assert b.is_shape_numeric == False
  29. assert c.free_symbols == set()
  30. assert c.function == i + j + k + l
  31. assert c.limits == ((i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5))
  32. assert c.doit().tolist() == [[[[4, 5, 6, 7, 8], [5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11]],
  33. [[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]],
  34. [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]]],
  35. [[[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]],
  36. [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]],
  37. [[7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13], [10, 11, 12, 13, 14]]]]
  38. assert c.free_symbols == set()
  39. assert c.variables == [i, j, k, l]
  40. assert c.bound_symbols == [i, j, k, l]
  41. assert d.doit().tolist() == [k, k, k, k, k]
  42. assert len(e) == 5
  43. raises(TypeError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, [1, 3, 2])))
  44. raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, 1)))
  45. raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+1)))
  46. raises(ValueError, lambda: len(ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+4))))
  47. raises(TypeError, lambda: ArrayComprehension(i*j, (i, 0, i + 1.5), (j, 0, 2)))
  48. raises(ValueError, lambda: b.tolist())
  49. raises(ValueError, lambda: b.tomatrix())
  50. raises(ValueError, lambda: c.tomatrix())
  51. def test_arraycomprehensionmap():
  52. a = ArrayComprehensionMap(lambda i: i+1, (i, 1, 5))
  53. assert a.doit().tolist() == [2, 3, 4, 5, 6]
  54. assert a.shape == (5,)
  55. assert a.is_shape_numeric
  56. assert a.tolist() == [2, 3, 4, 5, 6]
  57. assert len(a) == 5
  58. assert isinstance(a.doit(), ImmutableDenseNDimArray)
  59. expr = ArrayComprehensionMap(lambda i: i+1, (i, 1, k))
  60. assert expr.doit() == expr
  61. assert expr.subs(k, 4) == ArrayComprehensionMap(lambda i: i+1, (i, 1, 4))
  62. assert expr.subs(k, 4).doit() == ImmutableDenseNDimArray([2, 3, 4, 5])
  63. b = ArrayComprehensionMap(lambda i: i+1, (i, 1, 2), (i, 1, 3), (i, 1, 4), (i, 1, 5))
  64. assert b.doit().tolist() == [[[[2, 3, 4, 5, 6], [3, 5, 7, 9, 11], [4, 7, 10, 13, 16], [5, 9, 13, 17, 21]],
  65. [[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]],
  66. [[4, 7, 10, 13, 16], [7, 13, 19, 25, 31], [10, 19, 28, 37, 46], [13, 25, 37, 49, 61]]],
  67. [[[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]],
  68. [[5, 9, 13, 17, 21], [9, 17, 25, 33, 41], [13, 25, 37, 49, 61], [17, 33, 49, 65, 81]],
  69. [[7, 13, 19, 25, 31], [13, 25, 37, 49, 61], [19, 37, 55, 73, 91], [25, 49, 73, 97, 121]]]]
  70. # tests about lambda expression
  71. assert ArrayComprehensionMap(lambda: 3, (i, 1, 5)).doit().tolist() == [3, 3, 3, 3, 3]
  72. assert ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)).doit().tolist() == [2, 3, 4, 5, 6]
  73. raises(ValueError, lambda: ArrayComprehensionMap(i*j, (i, 1, 3), (j, 2, 4)))
  74. # The use of a function here triggers a deprecation warning from sympify()
  75. with warns(SymPyDeprecationWarning, test_stacklevel=False):
  76. a = ArrayComprehensionMap(lambda i, j: i+j, (i, 1, 5))
  77. raises(ValueError, lambda: a.doit())