test_basic_ops.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. import mpmath
  2. from mpmath import *
  3. from mpmath.libmp import *
  4. import random
  5. import sys
  6. try:
  7. long = long
  8. except NameError:
  9. long = int
  10. def test_type_compare():
  11. assert mpf(2) == mpc(2,0)
  12. assert mpf(0) == mpc(0)
  13. assert mpf(2) != mpc(2, 0.00001)
  14. assert mpf(2) == 2.0
  15. assert mpf(2) != 3.0
  16. assert mpf(2) == 2
  17. assert mpf(2) != '2.0'
  18. assert mpc(2) != '2.0'
  19. def test_add():
  20. assert mpf(2.5) + mpf(3) == 5.5
  21. assert mpf(2.5) + 3 == 5.5
  22. assert mpf(2.5) + 3.0 == 5.5
  23. assert 3 + mpf(2.5) == 5.5
  24. assert 3.0 + mpf(2.5) == 5.5
  25. assert (3+0j) + mpf(2.5) == 5.5
  26. assert mpc(2.5) + mpf(3) == 5.5
  27. assert mpc(2.5) + 3 == 5.5
  28. assert mpc(2.5) + 3.0 == 5.5
  29. assert mpc(2.5) + (3+0j) == 5.5
  30. assert 3 + mpc(2.5) == 5.5
  31. assert 3.0 + mpc(2.5) == 5.5
  32. assert (3+0j) + mpc(2.5) == 5.5
  33. def test_sub():
  34. assert mpf(2.5) - mpf(3) == -0.5
  35. assert mpf(2.5) - 3 == -0.5
  36. assert mpf(2.5) - 3.0 == -0.5
  37. assert 3 - mpf(2.5) == 0.5
  38. assert 3.0 - mpf(2.5) == 0.5
  39. assert (3+0j) - mpf(2.5) == 0.5
  40. assert mpc(2.5) - mpf(3) == -0.5
  41. assert mpc(2.5) - 3 == -0.5
  42. assert mpc(2.5) - 3.0 == -0.5
  43. assert mpc(2.5) - (3+0j) == -0.5
  44. assert 3 - mpc(2.5) == 0.5
  45. assert 3.0 - mpc(2.5) == 0.5
  46. assert (3+0j) - mpc(2.5) == 0.5
  47. def test_mul():
  48. assert mpf(2.5) * mpf(3) == 7.5
  49. assert mpf(2.5) * 3 == 7.5
  50. assert mpf(2.5) * 3.0 == 7.5
  51. assert 3 * mpf(2.5) == 7.5
  52. assert 3.0 * mpf(2.5) == 7.5
  53. assert (3+0j) * mpf(2.5) == 7.5
  54. assert mpc(2.5) * mpf(3) == 7.5
  55. assert mpc(2.5) * 3 == 7.5
  56. assert mpc(2.5) * 3.0 == 7.5
  57. assert mpc(2.5) * (3+0j) == 7.5
  58. assert 3 * mpc(2.5) == 7.5
  59. assert 3.0 * mpc(2.5) == 7.5
  60. assert (3+0j) * mpc(2.5) == 7.5
  61. def test_div():
  62. assert mpf(6) / mpf(3) == 2.0
  63. assert mpf(6) / 3 == 2.0
  64. assert mpf(6) / 3.0 == 2.0
  65. assert 6 / mpf(3) == 2.0
  66. assert 6.0 / mpf(3) == 2.0
  67. assert (6+0j) / mpf(3.0) == 2.0
  68. assert mpc(6) / mpf(3) == 2.0
  69. assert mpc(6) / 3 == 2.0
  70. assert mpc(6) / 3.0 == 2.0
  71. assert mpc(6) / (3+0j) == 2.0
  72. assert 6 / mpc(3) == 2.0
  73. assert 6.0 / mpc(3) == 2.0
  74. assert (6+0j) / mpc(3) == 2.0
  75. def test_pow():
  76. assert mpf(6) ** mpf(3) == 216.0
  77. assert mpf(6) ** 3 == 216.0
  78. assert mpf(6) ** 3.0 == 216.0
  79. assert 6 ** mpf(3) == 216.0
  80. assert 6.0 ** mpf(3) == 216.0
  81. assert (6+0j) ** mpf(3.0) == 216.0
  82. assert mpc(6) ** mpf(3) == 216.0
  83. assert mpc(6) ** 3 == 216.0
  84. assert mpc(6) ** 3.0 == 216.0
  85. assert mpc(6) ** (3+0j) == 216.0
  86. assert 6 ** mpc(3) == 216.0
  87. assert 6.0 ** mpc(3) == 216.0
  88. assert (6+0j) ** mpc(3) == 216.0
  89. def test_mixed_misc():
  90. assert 1 + mpf(3) == mpf(3) + 1 == 4
  91. assert 1 - mpf(3) == -(mpf(3) - 1) == -2
  92. assert 3 * mpf(2) == mpf(2) * 3 == 6
  93. assert 6 / mpf(2) == mpf(6) / 2 == 3
  94. assert 1.0 + mpf(3) == mpf(3) + 1.0 == 4
  95. assert 1.0 - mpf(3) == -(mpf(3) - 1.0) == -2
  96. assert 3.0 * mpf(2) == mpf(2) * 3.0 == 6
  97. assert 6.0 / mpf(2) == mpf(6) / 2.0 == 3
  98. def test_add_misc():
  99. mp.dps = 15
  100. assert mpf(4) + mpf(-70) == -66
  101. assert mpf(1) + mpf(1.1)/80 == 1 + 1.1/80
  102. assert mpf((1, 10000000000)) + mpf(3) == mpf((1, 10000000000))
  103. assert mpf(3) + mpf((1, 10000000000)) == mpf((1, 10000000000))
  104. assert mpf((1, -10000000000)) + mpf(3) == mpf(3)
  105. assert mpf(3) + mpf((1, -10000000000)) == mpf(3)
  106. assert mpf(1) + 1e-15 != 1
  107. assert mpf(1) + 1e-20 == 1
  108. assert mpf(1.07e-22) + 0 == mpf(1.07e-22)
  109. assert mpf(0) + mpf(1.07e-22) == mpf(1.07e-22)
  110. def test_complex_misc():
  111. # many more tests needed
  112. assert 1 + mpc(2) == 3
  113. assert not mpc(2).ae(2 + 1e-13)
  114. assert mpc(2+1e-15j).ae(2)
  115. def test_complex_zeros():
  116. for a in [0,2]:
  117. for b in [0,3]:
  118. for c in [0,4]:
  119. for d in [0,5]:
  120. assert mpc(a,b)*mpc(c,d) == complex(a,b)*complex(c,d)
  121. def test_hash():
  122. for i in range(-256, 256):
  123. assert hash(mpf(i)) == hash(i)
  124. assert hash(mpf(0.5)) == hash(0.5)
  125. assert hash(mpc(2,3)) == hash(2+3j)
  126. # Check that this doesn't fail
  127. assert hash(inf)
  128. # Check that overflow doesn't assign equal hashes to large numbers
  129. assert hash(mpf('1e1000')) != hash('1e10000')
  130. assert hash(mpc(100,'1e1000')) != hash(mpc(200,'1e1000'))
  131. from mpmath.rational import mpq
  132. assert hash(mp.mpq(1,3))
  133. assert hash(mp.mpq(0,1)) == 0
  134. assert hash(mp.mpq(-1,1)) == hash(-1)
  135. assert hash(mp.mpq(1,1)) == hash(1)
  136. assert hash(mp.mpq(5,1)) == hash(5)
  137. assert hash(mp.mpq(1,2)) == hash(0.5)
  138. if sys.version_info >= (3, 2):
  139. assert hash(mpf(1)*2**2000) == hash(2**2000)
  140. assert hash(mpf(1)/2**2000) == hash(mpq(1,2**2000))
  141. # Advanced rounding test
  142. def test_add_rounding():
  143. mp.dps = 15
  144. a = from_float(1e-50)
  145. assert mpf_sub(mpf_add(fone, a, 53, round_up), fone, 53, round_up) == from_float(2.2204460492503131e-16)
  146. assert mpf_sub(fone, a, 53, round_up) == fone
  147. assert mpf_sub(fone, mpf_sub(fone, a, 53, round_down), 53, round_down) == from_float(1.1102230246251565e-16)
  148. assert mpf_add(fone, a, 53, round_down) == fone
  149. def test_almost_equal():
  150. assert mpf(1.2).ae(mpf(1.20000001), 1e-7)
  151. assert not mpf(1.2).ae(mpf(1.20000001), 1e-9)
  152. assert not mpf(-0.7818314824680298).ae(mpf(-0.774695868667929))
  153. def test_arithmetic_functions():
  154. import operator
  155. ops = [(operator.add, fadd), (operator.sub, fsub), (operator.mul, fmul),
  156. (operator.truediv, fdiv)]
  157. a = mpf(0.27)
  158. b = mpf(1.13)
  159. c = mpc(0.51+2.16j)
  160. d = mpc(1.08-0.99j)
  161. for x in [a,b,c,d]:
  162. for y in [a,b,c,d]:
  163. for op, fop in ops:
  164. if fop is not fdiv:
  165. mp.prec = 200
  166. z0 = op(x,y)
  167. mp.prec = 60
  168. z1 = op(x,y)
  169. mp.prec = 53
  170. z2 = op(x,y)
  171. assert fop(x, y, prec=60) == z1
  172. assert fop(x, y) == z2
  173. if fop is not fdiv:
  174. assert fop(x, y, prec=inf) == z0
  175. assert fop(x, y, dps=inf) == z0
  176. assert fop(x, y, exact=True) == z0
  177. assert fneg(fneg(z1, exact=True), prec=inf) == z1
  178. assert fneg(z1) == -(+z1)
  179. mp.dps = 15
  180. def test_exact_integer_arithmetic():
  181. # XXX: re-fix this so that all operations are tested with all rounding modes
  182. random.seed(0)
  183. for prec in [6, 10, 25, 40, 100, 250, 725]:
  184. for rounding in ['d', 'u', 'f', 'c', 'n']:
  185. mp.dps = prec
  186. M = 10**(prec-2)
  187. M2 = 10**(prec//2-2)
  188. for i in range(10):
  189. a = random.randint(-M, M)
  190. b = random.randint(-M, M)
  191. assert mpf(a, rounding=rounding) == a
  192. assert int(mpf(a, rounding=rounding)) == a
  193. assert int(mpf(str(a), rounding=rounding)) == a
  194. assert mpf(a) + mpf(b) == a + b
  195. assert mpf(a) - mpf(b) == a - b
  196. assert -mpf(a) == -a
  197. a = random.randint(-M2, M2)
  198. b = random.randint(-M2, M2)
  199. assert mpf(a) * mpf(b) == a*b
  200. assert mpf_mul(from_int(a), from_int(b), mp.prec, rounding) == from_int(a*b)
  201. mp.dps = 15
  202. def test_odd_int_bug():
  203. assert to_int(from_int(3), round_nearest) == 3
  204. def test_str_1000_digits():
  205. mp.dps = 1001
  206. # last digit may be wrong
  207. assert str(mpf(2)**0.5)[-10:-1] == '9518488472'[:9]
  208. assert str(pi)[-10:-1] == '2164201989'[:9]
  209. mp.dps = 15
  210. def test_str_10000_digits():
  211. mp.dps = 10001
  212. # last digit may be wrong
  213. assert str(mpf(2)**0.5)[-10:-1] == '5873258351'[:9]
  214. assert str(pi)[-10:-1] == '5256375678'[:9]
  215. mp.dps = 15
  216. def test_monitor():
  217. f = lambda x: x**2
  218. a = []
  219. b = []
  220. g = monitor(f, a.append, b.append)
  221. assert g(3) == 9
  222. assert g(4) == 16
  223. assert a[0] == ((3,), {})
  224. assert b[0] == 9
  225. def test_nint_distance():
  226. assert nint_distance(mpf(-3)) == (-3, -inf)
  227. assert nint_distance(mpc(-3)) == (-3, -inf)
  228. assert nint_distance(mpf(-3.1)) == (-3, -3)
  229. assert nint_distance(mpf(-3.01)) == (-3, -6)
  230. assert nint_distance(mpf(-3.001)) == (-3, -9)
  231. assert nint_distance(mpf(-3.0001)) == (-3, -13)
  232. assert nint_distance(mpf(-2.9)) == (-3, -3)
  233. assert nint_distance(mpf(-2.99)) == (-3, -6)
  234. assert nint_distance(mpf(-2.999)) == (-3, -9)
  235. assert nint_distance(mpf(-2.9999)) == (-3, -13)
  236. assert nint_distance(mpc(-3+0.1j)) == (-3, -3)
  237. assert nint_distance(mpc(-3+0.01j)) == (-3, -6)
  238. assert nint_distance(mpc(-3.1+0.1j)) == (-3, -3)
  239. assert nint_distance(mpc(-3.01+0.01j)) == (-3, -6)
  240. assert nint_distance(mpc(-3.001+0.001j)) == (-3, -9)
  241. assert nint_distance(mpf(0)) == (0, -inf)
  242. assert nint_distance(mpf(0.01)) == (0, -6)
  243. assert nint_distance(mpf('1e-100')) == (0, -332)
  244. def test_floor_ceil_nint_frac():
  245. mp.dps = 15
  246. for n in range(-10,10):
  247. assert floor(n) == n
  248. assert floor(n+0.5) == n
  249. assert ceil(n) == n
  250. assert ceil(n+0.5) == n+1
  251. assert nint(n) == n
  252. # nint rounds to even
  253. if n % 2 == 1:
  254. assert nint(n+0.5) == n+1
  255. else:
  256. assert nint(n+0.5) == n
  257. assert floor(inf) == inf
  258. assert floor(ninf) == ninf
  259. assert isnan(floor(nan))
  260. assert ceil(inf) == inf
  261. assert ceil(ninf) == ninf
  262. assert isnan(ceil(nan))
  263. assert nint(inf) == inf
  264. assert nint(ninf) == ninf
  265. assert isnan(nint(nan))
  266. assert floor(0.1) == 0
  267. assert floor(0.9) == 0
  268. assert floor(-0.1) == -1
  269. assert floor(-0.9) == -1
  270. assert floor(10000000000.1) == 10000000000
  271. assert floor(10000000000.9) == 10000000000
  272. assert floor(-10000000000.1) == -10000000000-1
  273. assert floor(-10000000000.9) == -10000000000-1
  274. assert floor(1e-100) == 0
  275. assert floor(-1e-100) == -1
  276. assert floor(1e100) == 1e100
  277. assert floor(-1e100) == -1e100
  278. assert ceil(0.1) == 1
  279. assert ceil(0.9) == 1
  280. assert ceil(-0.1) == 0
  281. assert ceil(-0.9) == 0
  282. assert ceil(10000000000.1) == 10000000000+1
  283. assert ceil(10000000000.9) == 10000000000+1
  284. assert ceil(-10000000000.1) == -10000000000
  285. assert ceil(-10000000000.9) == -10000000000
  286. assert ceil(1e-100) == 1
  287. assert ceil(-1e-100) == 0
  288. assert ceil(1e100) == 1e100
  289. assert ceil(-1e100) == -1e100
  290. assert nint(0.1) == 0
  291. assert nint(0.9) == 1
  292. assert nint(-0.1) == 0
  293. assert nint(-0.9) == -1
  294. assert nint(10000000000.1) == 10000000000
  295. assert nint(10000000000.9) == 10000000000+1
  296. assert nint(-10000000000.1) == -10000000000
  297. assert nint(-10000000000.9) == -10000000000-1
  298. assert nint(1e-100) == 0
  299. assert nint(-1e-100) == 0
  300. assert nint(1e100) == 1e100
  301. assert nint(-1e100) == -1e100
  302. assert floor(3.2+4.6j) == 3+4j
  303. assert ceil(3.2+4.6j) == 4+5j
  304. assert nint(3.2+4.6j) == 3+5j
  305. for n in range(-10,10):
  306. assert frac(n) == 0
  307. assert frac(0.25) == 0.25
  308. assert frac(1.25) == 0.25
  309. assert frac(2.25) == 0.25
  310. assert frac(-0.25) == 0.75
  311. assert frac(-1.25) == 0.75
  312. assert frac(-2.25) == 0.75
  313. assert frac('1e100000000000000') == 0
  314. u = mpf('1e-100000000000000')
  315. assert frac(u) == u
  316. assert frac(-u) == 1 # rounding!
  317. u = mpf('1e-400')
  318. assert frac(-u, prec=0) == fsub(1, u, exact=True)
  319. assert frac(3.25+4.75j) == 0.25+0.75j
  320. def test_isnan_etc():
  321. from mpmath.rational import mpq
  322. assert isnan(nan) == True
  323. assert isnan(3) == False
  324. assert isnan(mpf(3)) == False
  325. assert isnan(inf) == False
  326. assert isnan(mpc(2,nan)) == True
  327. assert isnan(mpc(2,nan)) == True
  328. assert isnan(mpc(nan,nan)) == True
  329. assert isnan(mpc(2,2)) == False
  330. assert isnan(mpc(nan,inf)) == True
  331. assert isnan(mpc(inf,inf)) == False
  332. assert isnan(mpq((3,2))) == False
  333. assert isnan(mpq((0,1))) == False
  334. assert isinf(inf) == True
  335. assert isinf(-inf) == True
  336. assert isinf(3) == False
  337. assert isinf(nan) == False
  338. assert isinf(3+4j) == False
  339. assert isinf(mpc(inf)) == True
  340. assert isinf(mpc(3,inf)) == True
  341. assert isinf(mpc(inf,3)) == True
  342. assert isinf(mpc(inf,inf)) == True
  343. assert isinf(mpc(nan,inf)) == True
  344. assert isinf(mpc(inf,nan)) == True
  345. assert isinf(mpc(nan,nan)) == False
  346. assert isinf(mpq((3,2))) == False
  347. assert isinf(mpq((0,1))) == False
  348. assert isnormal(3) == True
  349. assert isnormal(3.5) == True
  350. assert isnormal(mpf(3.5)) == True
  351. assert isnormal(0) == False
  352. assert isnormal(mpf(0)) == False
  353. assert isnormal(0.0) == False
  354. assert isnormal(inf) == False
  355. assert isnormal(-inf) == False
  356. assert isnormal(nan) == False
  357. assert isnormal(float(inf)) == False
  358. assert isnormal(mpc(0,0)) == False
  359. assert isnormal(mpc(3,0)) == True
  360. assert isnormal(mpc(0,3)) == True
  361. assert isnormal(mpc(3,3)) == True
  362. assert isnormal(mpc(0,nan)) == False
  363. assert isnormal(mpc(0,inf)) == False
  364. assert isnormal(mpc(3,nan)) == False
  365. assert isnormal(mpc(3,inf)) == False
  366. assert isnormal(mpc(3,-inf)) == False
  367. assert isnormal(mpc(nan,0)) == False
  368. assert isnormal(mpc(inf,0)) == False
  369. assert isnormal(mpc(nan,3)) == False
  370. assert isnormal(mpc(inf,3)) == False
  371. assert isnormal(mpc(inf,nan)) == False
  372. assert isnormal(mpc(nan,inf)) == False
  373. assert isnormal(mpc(nan,nan)) == False
  374. assert isnormal(mpc(inf,inf)) == False
  375. assert isnormal(mpq((3,2))) == True
  376. assert isnormal(mpq((0,1))) == False
  377. assert isint(3) == True
  378. assert isint(0) == True
  379. assert isint(long(3)) == True
  380. assert isint(long(0)) == True
  381. assert isint(mpf(3)) == True
  382. assert isint(mpf(0)) == True
  383. assert isint(mpf(-3)) == True
  384. assert isint(mpf(3.2)) == False
  385. assert isint(3.2) == False
  386. assert isint(nan) == False
  387. assert isint(inf) == False
  388. assert isint(-inf) == False
  389. assert isint(mpc(0)) == True
  390. assert isint(mpc(3)) == True
  391. assert isint(mpc(3.2)) == False
  392. assert isint(mpc(3,inf)) == False
  393. assert isint(mpc(inf)) == False
  394. assert isint(mpc(3,2)) == False
  395. assert isint(mpc(0,2)) == False
  396. assert isint(mpc(3,2),gaussian=True) == True
  397. assert isint(mpc(3,0),gaussian=True) == True
  398. assert isint(mpc(0,3),gaussian=True) == True
  399. assert isint(3+4j) == False
  400. assert isint(3+4j, gaussian=True) == True
  401. assert isint(3+0j) == True
  402. assert isint(mpq((3,2))) == False
  403. assert isint(mpq((3,9))) == False
  404. assert isint(mpq((9,3))) == True
  405. assert isint(mpq((0,4))) == True
  406. assert isint(mpq((1,1))) == True
  407. assert isint(mpq((-1,1))) == True
  408. assert mp.isnpint(0) == True
  409. assert mp.isnpint(1) == False
  410. assert mp.isnpint(-1) == True
  411. assert mp.isnpint(-1.1) == False
  412. assert mp.isnpint(-1.0) == True
  413. assert mp.isnpint(mp.mpq(1,2)) == False
  414. assert mp.isnpint(mp.mpq(-1,2)) == False
  415. assert mp.isnpint(mp.mpq(-3,1)) == True
  416. assert mp.isnpint(mp.mpq(0,1)) == True
  417. assert mp.isnpint(mp.mpq(1,1)) == False
  418. assert mp.isnpint(0+0j) == True
  419. assert mp.isnpint(-1+0j) == True
  420. assert mp.isnpint(-1.1+0j) == False
  421. assert mp.isnpint(-1+0.1j) == False
  422. assert mp.isnpint(0+0.1j) == False
  423. def test_issue_438():
  424. assert mpf(finf) == mpf('inf')
  425. assert mpf(fninf) == mpf('-inf')
  426. assert mpf(fnan)._mpf_ == mpf('nan')._mpf_