# test_dot.py (4.5 KB)
  1. from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
  2. dotedges, dotprint)
  3. from sympy.core.basic import Basic
  4. from sympy.core.expr import Expr
  5. from sympy.core.numbers import (Float, Integer)
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import (Symbol, symbols)
  8. from sympy.printing.repr import srepr
  9. from sympy.abc import x
  10. def test_purestr():
  11. assert purestr(Symbol('x')) == "Symbol('x')"
  12. assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))"
  13. assert purestr(Float(2)) == "Float('2.0', precision=53)"
  14. assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
  15. assert purestr(Basic(S(1), S(2)), with_args=True) == \
  16. ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)'))
  17. assert purestr(Float(2), with_args=True) == \
  18. ("Float('2.0', precision=53)", ())
  19. def test_styleof():
  20. styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
  21. (Expr, {'color': 'black'})]
  22. assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}
  23. assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
  24. def test_attrprint():
  25. assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \
  26. '"color"="blue", "shape"="ellipse"'
  27. def test_dotnode():
  28. assert dotnode(x, repeat=False) == \
  29. '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'
  30. assert dotnode(x+2, repeat=False) == \
  31. '"Add(Integer(2), Symbol(\'x\'))" ' \
  32. '["color"="black", "label"="Add", "shape"="ellipse"];', \
  33. dotnode(x+2,repeat=0)
  34. assert dotnode(x + x**2, repeat=False) == \
  35. '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \
  36. '["color"="black", "label"="Add", "shape"="ellipse"];'
  37. assert dotnode(x + x**2, repeat=True) == \
  38. '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \
  39. '["color"="black", "label"="Add", "shape"="ellipse"];'
  40. def test_dotedges():
  41. assert sorted(dotedges(x+2, repeat=False)) == [
  42. '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";',
  43. '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";'
  44. ]
  45. assert sorted(dotedges(x + 2, repeat=True)) == [
  46. '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";',
  47. '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";'
  48. ]
  49. def test_dotprint():
  50. text = dotprint(x+2, repeat=False)
  51. assert all(e in text for e in dotedges(x+2, repeat=False))
  52. assert all(
  53. n in text for n in [dotnode(expr, repeat=False)
  54. for expr in (x, Integer(2), x+2)])
  55. assert 'digraph' in text
  56. text = dotprint(x+x**2, repeat=False)
  57. assert all(e in text for e in dotedges(x+x**2, repeat=False))
  58. assert all(
  59. n in text for n in [dotnode(expr, repeat=False)
  60. for expr in (x, Integer(2), x**2)])
  61. assert 'digraph' in text
  62. text = dotprint(x+x**2, repeat=True)
  63. assert all(e in text for e in dotedges(x+x**2, repeat=True))
  64. assert all(
  65. n in text for n in [dotnode(expr, pos=())
  66. for expr in [x + x**2]])
  67. text = dotprint(x**x, repeat=True)
  68. assert all(e in text for e in dotedges(x**x, repeat=True))
  69. assert all(
  70. n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))])
  71. assert 'digraph' in text
  72. def test_dotprint_depth():
  73. text = dotprint(3*x+2, depth=1)
  74. assert dotnode(3*x+2) in text
  75. assert dotnode(x) not in text
  76. text = dotprint(3*x+2)
  77. assert "depth" not in text
  78. def test_Matrix_and_non_basics():
  79. from sympy.matrices.expressions.matexpr import MatrixSymbol
  80. n = Symbol('n')
  81. assert dotprint(MatrixSymbol('X', n, n)) == \
  82. """digraph{
  83. # Graph style
  84. "ordering"="out"
  85. "rankdir"="TD"
  86. #########
  87. # Nodes #
  88. #########
  89. "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
  90. "Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
  91. "Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
  92. "Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];
  93. #########
  94. # Edges #
  95. #########
  96. "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
  97. "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
  98. "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
  99. }"""
  100. def test_labelfunc():
  101. text = dotprint(x + 2, labelfunc=srepr)
  102. assert "Symbol('x')" in text
  103. assert "Integer(2)" in text
  104. def test_commutative():
  105. x, y = symbols('x y', commutative=False)
  106. assert dotprint(x + y) == dotprint(y + x)
  107. assert dotprint(x*y) != dotprint(y*x)