test_core.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import annotations
  2. from sympy.core.singleton import S
  3. from sympy.core.basic import Basic
  4. from sympy.strategies.core import (
  5. null_safe, exhaust, memoize, condition,
  6. chain, tryit, do_one, debug, switch, minimize)
  7. from io import StringIO
  8. def posdec(x: int) -> int:
  9. if x > 0:
  10. return x - 1
  11. return x
  12. def inc(x: int) -> int:
  13. return x + 1
  14. def dec(x: int) -> int:
  15. return x - 1
  16. def test_null_safe():
  17. def rl(expr: int) -> int | None:
  18. if expr == 1:
  19. return 2
  20. return None
  21. safe_rl = null_safe(rl)
  22. assert rl(1) == safe_rl(1)
  23. assert rl(3) is None
  24. assert safe_rl(3) == 3
  25. def test_exhaust():
  26. sink = exhaust(posdec)
  27. assert sink(5) == 0
  28. assert sink(10) == 0
  29. def test_memoize():
  30. rl = memoize(posdec)
  31. assert rl(5) == posdec(5)
  32. assert rl(5) == posdec(5)
  33. assert rl(-2) == posdec(-2)
  34. def test_condition():
  35. rl = condition(lambda x: x % 2 == 0, posdec)
  36. assert rl(5) == 5
  37. assert rl(4) == 3
  38. def test_chain():
  39. rl = chain(posdec, posdec)
  40. assert rl(5) == 3
  41. assert rl(1) == 0
  42. def test_tryit():
  43. def rl(expr: Basic) -> Basic:
  44. assert False
  45. safe_rl = tryit(rl, AssertionError)
  46. assert safe_rl(S(1)) == S(1)
  47. def test_do_one():
  48. rl = do_one(posdec, posdec)
  49. assert rl(5) == 4
  50. def rl1(x: int) -> int:
  51. if x == 1:
  52. return 2
  53. return x
  54. def rl2(x: int) -> int:
  55. if x == 2:
  56. return 3
  57. return x
  58. rule = do_one(rl1, rl2)
  59. assert rule(1) == 2
  60. assert rule(rule(1)) == 3
  61. def test_debug():
  62. file = StringIO()
  63. rl = debug(posdec, file)
  64. rl(5)
  65. log = file.getvalue()
  66. file.close()
  67. assert posdec.__name__ in log
  68. assert '5' in log
  69. assert '4' in log
  70. def test_switch():
  71. def key(x: int) -> int:
  72. return x % 3
  73. rl = switch(key, {0: inc, 1: dec})
  74. assert rl(3) == 4
  75. assert rl(4) == 3
  76. assert rl(5) == 5
  77. def test_minimize():
  78. def key(x: int) -> int:
  79. return -x
  80. rl = minimize(inc, dec)
  81. assert rl(4) == 3
  82. rl = minimize(inc, dec, objective=key)
  83. assert rl(4) == 5