test_cnodes.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from sympy.core.symbol import symbols
  2. from sympy.printing.codeprinter import ccode
  3. from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
  4. from sympy.codegen.cnodes import (
  5. alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
  6. sizeof, union, struct
  7. )
  8. x, y = symbols('x y')
  9. def test_alignof():
  10. ax = alignof(x)
  11. assert ccode(ax) == 'alignof(x)'
  12. assert ax.func(*ax.args) == ax
  13. def test_CommaOperator():
  14. expr = CommaOperator(PreIncrement(x), 2*x)
  15. assert ccode(expr) == '(++(x), 2*x)'
  16. assert expr.func(*expr.args) == expr
  17. def test_goto_Label():
  18. s = 'early_exit'
  19. g = goto(s)
  20. assert g.func(*g.args) == g
  21. assert g != goto('foobar')
  22. assert ccode(g) == 'goto early_exit'
  23. l1 = Label(s)
  24. assert ccode(l1) == 'early_exit:'
  25. assert l1 == Label('early_exit')
  26. assert l1 != Label('foobar')
  27. body = [PreIncrement(x)]
  28. l2 = Label(s, body)
  29. assert l2.name == String("early_exit")
  30. assert l2.body == CodeBlock(PreIncrement(x))
  31. assert ccode(l2) == ("early_exit:\n"
  32. "++(x);")
  33. body = [PreIncrement(x), PreDecrement(y)]
  34. l2 = Label(s, body)
  35. assert l2.name == String("early_exit")
  36. assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
  37. assert ccode(l2) == ("early_exit:\n"
  38. "{\n ++(x);\n --(y);\n}")
  39. def test_PreDecrement():
  40. p = PreDecrement(x)
  41. assert p.func(*p.args) == p
  42. assert ccode(p) == '--(x)'
  43. def test_PostDecrement():
  44. p = PostDecrement(x)
  45. assert p.func(*p.args) == p
  46. assert ccode(p) == '(x)--'
  47. def test_PreIncrement():
  48. p = PreIncrement(x)
  49. assert p.func(*p.args) == p
  50. assert ccode(p) == '++(x)'
  51. def test_PostIncrement():
  52. p = PostIncrement(x)
  53. assert p.func(*p.args) == p
  54. assert ccode(p) == '(x)++'
  55. def test_sizeof():
  56. typename = 'unsigned int'
  57. sz = sizeof(typename)
  58. assert ccode(sz) == 'sizeof(%s)' % typename
  59. assert sz.func(*sz.args) == sz
  60. assert not sz.is_Atom
  61. assert sz.atoms() == {String('unsigned int'), String('sizeof')}
  62. def test_struct():
  63. vx, vy = Variable(x, type=float64), Variable(y, type=float64)
  64. s = struct('vec2', [vx, vy])
  65. assert s.func(*s.args) == s
  66. assert s == struct('vec2', (vx, vy))
  67. assert s != struct('vec2', (vy, vx))
  68. assert str(s.name) == 'vec2'
  69. assert len(s.declarations) == 2
  70. assert all(isinstance(arg, Declaration) for arg in s.declarations)
  71. assert ccode(s) == (
  72. "struct vec2 {\n"
  73. " double x;\n"
  74. " double y;\n"
  75. "}")
  76. def test_union():
  77. vx, vy = Variable(x, type=float64), Variable(y, type=int64)
  78. u = union('dualuse', [vx, vy])
  79. assert u.func(*u.args) == u
  80. assert u == union('dualuse', (vx, vy))
  81. assert str(u.name) == 'dualuse'
  82. assert len(u.declarations) == 2
  83. assert all(isinstance(arg, Declaration) for arg in u.declarations)
  84. assert ccode(u) == (
  85. "union dualuse {\n"
  86. " double x;\n"
  87. " int64_t y;\n"
  88. "}")