# test_extint128.py (5.5 KB)
  1. import itertools
  2. import contextlib
  3. import operator
  4. import pytest
  5. import numpy as np
  6. import numpy.core._multiarray_tests as mt
  7. from numpy.testing import assert_raises, assert_equal
  8. INT64_MAX = np.iinfo(np.int64).max
  9. INT64_MIN = np.iinfo(np.int64).min
  10. INT64_MID = 2**32
  11. # int128 is not two's complement, the sign bit is separate
  12. INT128_MAX = 2**128 - 1
  13. INT128_MIN = -INT128_MAX
  14. INT128_MID = 2**64
  15. INT64_VALUES = (
  16. [INT64_MIN + j for j in range(20)] +
  17. [INT64_MAX - j for j in range(20)] +
  18. [INT64_MID + j for j in range(-20, 20)] +
  19. [2*INT64_MID + j for j in range(-20, 20)] +
  20. [INT64_MID//2 + j for j in range(-20, 20)] +
  21. list(range(-70, 70))
  22. )
  23. INT128_VALUES = (
  24. [INT128_MIN + j for j in range(20)] +
  25. [INT128_MAX - j for j in range(20)] +
  26. [INT128_MID + j for j in range(-20, 20)] +
  27. [2*INT128_MID + j for j in range(-20, 20)] +
  28. [INT128_MID//2 + j for j in range(-20, 20)] +
  29. list(range(-70, 70)) +
  30. [False] # negative zero
  31. )
  32. INT64_POS_VALUES = [x for x in INT64_VALUES if x > 0]
  33. @contextlib.contextmanager
  34. def exc_iter(*args):
  35. """
  36. Iterate over Cartesian product of *args, and if an exception is raised,
  37. add information of the current iterate.
  38. """
  39. value = [None]
  40. def iterate():
  41. for v in itertools.product(*args):
  42. value[0] = v
  43. yield v
  44. try:
  45. yield iterate()
  46. except Exception:
  47. import traceback
  48. msg = "At: %r\n%s" % (repr(value[0]),
  49. traceback.format_exc())
  50. raise AssertionError(msg)
  51. def test_safe_binop():
  52. # Test checked arithmetic routines
  53. ops = [
  54. (operator.add, 1),
  55. (operator.sub, 2),
  56. (operator.mul, 3)
  57. ]
  58. with exc_iter(ops, INT64_VALUES, INT64_VALUES) as it:
  59. for xop, a, b in it:
  60. pyop, op = xop
  61. c = pyop(a, b)
  62. if not (INT64_MIN <= c <= INT64_MAX):
  63. assert_raises(OverflowError, mt.extint_safe_binop, a, b, op)
  64. else:
  65. d = mt.extint_safe_binop(a, b, op)
  66. if c != d:
  67. # assert_equal is slow
  68. assert_equal(d, c)
  69. def test_to_128():
  70. with exc_iter(INT64_VALUES) as it:
  71. for a, in it:
  72. b = mt.extint_to_128(a)
  73. if a != b:
  74. assert_equal(b, a)
  75. def test_to_64():
  76. with exc_iter(INT128_VALUES) as it:
  77. for a, in it:
  78. if not (INT64_MIN <= a <= INT64_MAX):
  79. assert_raises(OverflowError, mt.extint_to_64, a)
  80. else:
  81. b = mt.extint_to_64(a)
  82. if a != b:
  83. assert_equal(b, a)
  84. def test_mul_64_64():
  85. with exc_iter(INT64_VALUES, INT64_VALUES) as it:
  86. for a, b in it:
  87. c = a * b
  88. d = mt.extint_mul_64_64(a, b)
  89. if c != d:
  90. assert_equal(d, c)
  91. def test_add_128():
  92. with exc_iter(INT128_VALUES, INT128_VALUES) as it:
  93. for a, b in it:
  94. c = a + b
  95. if not (INT128_MIN <= c <= INT128_MAX):
  96. assert_raises(OverflowError, mt.extint_add_128, a, b)
  97. else:
  98. d = mt.extint_add_128(a, b)
  99. if c != d:
  100. assert_equal(d, c)
  101. def test_sub_128():
  102. with exc_iter(INT128_VALUES, INT128_VALUES) as it:
  103. for a, b in it:
  104. c = a - b
  105. if not (INT128_MIN <= c <= INT128_MAX):
  106. assert_raises(OverflowError, mt.extint_sub_128, a, b)
  107. else:
  108. d = mt.extint_sub_128(a, b)
  109. if c != d:
  110. assert_equal(d, c)
  111. def test_neg_128():
  112. with exc_iter(INT128_VALUES) as it:
  113. for a, in it:
  114. b = -a
  115. c = mt.extint_neg_128(a)
  116. if b != c:
  117. assert_equal(c, b)
  118. def test_shl_128():
  119. with exc_iter(INT128_VALUES) as it:
  120. for a, in it:
  121. if a < 0:
  122. b = -(((-a) << 1) & (2**128-1))
  123. else:
  124. b = (a << 1) & (2**128-1)
  125. c = mt.extint_shl_128(a)
  126. if b != c:
  127. assert_equal(c, b)
  128. def test_shr_128():
  129. with exc_iter(INT128_VALUES) as it:
  130. for a, in it:
  131. if a < 0:
  132. b = -((-a) >> 1)
  133. else:
  134. b = a >> 1
  135. c = mt.extint_shr_128(a)
  136. if b != c:
  137. assert_equal(c, b)
  138. def test_gt_128():
  139. with exc_iter(INT128_VALUES, INT128_VALUES) as it:
  140. for a, b in it:
  141. c = a > b
  142. d = mt.extint_gt_128(a, b)
  143. if c != d:
  144. assert_equal(d, c)
  145. @pytest.mark.slow
  146. def test_divmod_128_64():
  147. with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
  148. for a, b in it:
  149. if a >= 0:
  150. c, cr = divmod(a, b)
  151. else:
  152. c, cr = divmod(-a, b)
  153. c = -c
  154. cr = -cr
  155. d, dr = mt.extint_divmod_128_64(a, b)
  156. if c != d or d != dr or b*d + dr != a:
  157. assert_equal(d, c)
  158. assert_equal(dr, cr)
  159. assert_equal(b*d + dr, a)
  160. def test_floordiv_128_64():
  161. with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
  162. for a, b in it:
  163. c = a // b
  164. d = mt.extint_floordiv_128_64(a, b)
  165. if c != d:
  166. assert_equal(d, c)
  167. def test_ceildiv_128_64():
  168. with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
  169. for a, b in it:
  170. c = (a + b - 1) // b
  171. d = mt.extint_ceildiv_128_64(a, b)
  172. if c != d:
  173. assert_equal(d, c)