123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- import itertools
- import contextlib
- import operator
- import pytest
- import numpy as np
- import numpy.core._multiarray_tests as mt
- from numpy.testing import assert_raises, assert_equal
# Bounds of a signed 64-bit integer, plus a mid-range pivot used to generate
# values around the 32-bit boundary.
_int64_info = np.iinfo(np.int64)
INT64_MAX = _int64_info.max
INT64_MIN = _int64_info.min
del _int64_info  # keep the module namespace identical to the original
INT64_MID = 1 << 32

# int128 is not two's complement, the sign bit is separate
INT128_MAX = (1 << 128) - 1
INT128_MIN = -INT128_MAX
INT128_MID = 1 << 64

# Sample points: both extremes, windows around MID, 2*MID and MID//2,
# and a window of small integers around zero.
INT64_VALUES = [
    *(INT64_MIN + j for j in range(20)),
    *(INT64_MAX - j for j in range(20)),
    *(INT64_MID + j for j in range(-20, 20)),
    *(2*INT64_MID + j for j in range(-20, 20)),
    *(INT64_MID//2 + j for j in range(-20, 20)),
    *range(-70, 70),
]

# Same layout as INT64_VALUES, built from the 128-bit limits.
INT128_VALUES = [
    *(INT128_MIN + j for j in range(20)),
    *(INT128_MAX - j for j in range(20)),
    *(INT128_MID + j for j in range(-20, 20)),
    *(2*INT128_MID + j for j in range(-20, 20)),
    *(INT128_MID//2 + j for j in range(-20, 20)),
    *range(-70, 70),
    False,  # negative zero
]

# Strictly positive 64-bit samples (used as divisors by the division tests).
INT64_POS_VALUES = [x for x in INT64_VALUES if x > 0]
@contextlib.contextmanager
def exc_iter(*args):
    """
    Context manager yielding an iterator over the Cartesian product of *args.

    The iterator remembers the combination it produced most recently. If the
    ``with`` body raises an ``Exception``, an ``AssertionError`` is raised
    instead whose message names that combination, followed by the formatted
    traceback of the original error.
    """
    latest = None

    def combos():
        nonlocal latest
        # Record each combination before handing it out so that the handler
        # below can report where the failure happened.
        for latest in itertools.product(*args):
            yield latest

    try:
        yield combos()
    except Exception:
        import traceback
        location = repr(latest)
        details = traceback.format_exc()
        raise AssertionError("At: %r\n%s" % (location, details))
def test_safe_binop():
    """Exercise the overflow-checked 64-bit add, subtract and multiply."""
    # Each entry pairs the Python reference operator with the integer code
    # handed to mt.extint_safe_binop as its third argument.
    checked_ops = [
        (operator.add, 1),
        (operator.sub, 2),
        (operator.mul, 3),
    ]

    with exc_iter(checked_ops, INT64_VALUES, INT64_VALUES) as cases:
        for (reference_op, opcode), lhs, rhs in cases:
            expected = reference_op(lhs, rhs)

            # Results that do not fit in int64 must be reported as overflow.
            if expected < INT64_MIN or expected > INT64_MAX:
                assert_raises(OverflowError, mt.extint_safe_binop,
                              lhs, rhs, opcode)
                continue

            actual = mt.extint_safe_binop(lhs, rhs, opcode)
            if expected != actual:
                # assert_equal is slow, so only call it on a mismatch
                assert_equal(actual, expected)
def test_to_128():
    """Check that mt.extint_to_128 returns each INT64_VALUES entry unchanged."""
    with exc_iter(INT64_VALUES) as cases:
        for (value,) in cases:
            converted = mt.extint_to_128(value)
            # Cheap comparison first; assert_equal only runs on a mismatch.
            if value != converted:
                assert_equal(converted, value)
def test_to_64():
    """
    Check mt.extint_to_64 over INT128_VALUES.

    Values outside [INT64_MIN, INT64_MAX] must raise OverflowError; all other
    values must come back unchanged.
    """
    with exc_iter(INT128_VALUES) as cases:
        for (value,) in cases:
            if value < INT64_MIN or value > INT64_MAX:
                assert_raises(OverflowError, mt.extint_to_64, value)
                continue

            narrowed = mt.extint_to_64(value)
            if value != narrowed:
                assert_equal(narrowed, value)
def test_mul_64_64():
    """
    Compare mt.extint_mul_64_64 with Python's exact integer product for every
    pair drawn from INT64_VALUES.
    """
    with exc_iter(INT64_VALUES, INT64_VALUES) as cases:
        for lhs, rhs in cases:
            expected = lhs * rhs
            actual = mt.extint_mul_64_64(lhs, rhs)
            # Cheap comparison first; assert_equal only runs on a mismatch.
            if expected != actual:
                assert_equal(actual, expected)
def test_add_128():
    """
    Check mt.extint_add_128 for every pair drawn from INT128_VALUES.

    Sums outside [INT128_MIN, INT128_MAX] must raise OverflowError; all other
    sums must equal Python's exact result.
    """
    with exc_iter(INT128_VALUES, INT128_VALUES) as cases:
        for lhs, rhs in cases:
            expected = lhs + rhs

            if expected < INT128_MIN or expected > INT128_MAX:
                assert_raises(OverflowError, mt.extint_add_128, lhs, rhs)
                continue

            actual = mt.extint_add_128(lhs, rhs)
            if expected != actual:
                assert_equal(actual, expected)
def test_sub_128():
    """
    Check mt.extint_sub_128 for every pair drawn from INT128_VALUES.

    Differences outside [INT128_MIN, INT128_MAX] must raise OverflowError;
    all other differences must equal Python's exact result.
    """
    with exc_iter(INT128_VALUES, INT128_VALUES) as cases:
        for lhs, rhs in cases:
            expected = lhs - rhs

            if expected < INT128_MIN or expected > INT128_MAX:
                assert_raises(OverflowError, mt.extint_sub_128, lhs, rhs)
                continue

            actual = mt.extint_sub_128(lhs, rhs)
            if expected != actual:
                assert_equal(actual, expected)
def test_neg_128():
    """Compare mt.extint_neg_128 with Python negation over INT128_VALUES."""
    with exc_iter(INT128_VALUES) as cases:
        for (value,) in cases:
            expected = -value
            actual = mt.extint_neg_128(value)
            # Cheap comparison first; assert_equal only runs on a mismatch.
            if expected != actual:
                assert_equal(actual, expected)
def test_shl_128():
    """
    Compare mt.extint_shl_128 with a one-bit left shift of the magnitude.

    The reference shifts the absolute value, keeps only the low 128 bits,
    and then restores the sign of the input.
    """
    magnitude_mask = 2**128 - 1

    with exc_iter(INT128_VALUES) as cases:
        for (value,) in cases:
            negative = value < 0
            magnitude = -value if negative else value
            shifted = (magnitude << 1) & magnitude_mask
            expected = -shifted if negative else shifted

            actual = mt.extint_shl_128(value)
            if expected != actual:
                assert_equal(actual, expected)
def test_shr_128():
    """
    Compare mt.extint_shr_128 with a one-bit right shift of the magnitude.

    The reference shifts the absolute value and then restores the sign of
    the input, so negative values round toward zero.
    """
    with exc_iter(INT128_VALUES) as cases:
        for (value,) in cases:
            negative = value < 0
            magnitude = -value if negative else value
            shifted = magnitude >> 1
            expected = -shifted if negative else shifted

            actual = mt.extint_shr_128(value)
            if expected != actual:
                assert_equal(actual, expected)
def test_gt_128():
    """
    Compare mt.extint_gt_128 with Python's ``>`` for every pair drawn from
    INT128_VALUES.
    """
    with exc_iter(INT128_VALUES, INT128_VALUES) as cases:
        for lhs, rhs in cases:
            expected = lhs > rhs
            actual = mt.extint_gt_128(lhs, rhs)
            # Cheap comparison first; assert_equal only runs on a mismatch.
            if expected != actual:
                assert_equal(actual, expected)
@pytest.mark.slow
def test_divmod_128_64():
    """
    Check mt.extint_divmod_128_64 for every INT128_VALUES dividend and every
    INT64_POS_VALUES divisor.

    The reference quotient and remainder are computed on the magnitude of
    the dividend and then given the dividend's sign, i.e. division truncates
    toward zero. The identity ``b*q + r == a`` is checked as well.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as it:
        for a, b in it:
            if a >= 0:
                c, cr = divmod(a, b)
            else:
                # Python's divmod floors, so divide the magnitude and
                # negate both results to get truncating semantics.
                c, cr = divmod(-a, b)
                c = -c
                cr = -cr

            d, dr = mt.extint_divmod_128_64(a, b)

            # assert_equal is slow, so only call it when a cheap comparison
            # has already found a mismatch. The remainder must be compared
            # with the reference remainder `cr`, not with the quotient `d`.
            if c != d or cr != dr or b*d + dr != a:
                assert_equal(d, c)
                assert_equal(dr, cr)
                assert_equal(b*d + dr, a)
def test_floordiv_128_64():
    """
    Compare mt.extint_floordiv_128_64 with Python's ``//`` for every
    INT128_VALUES dividend and every INT64_POS_VALUES divisor.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as cases:
        for dividend, divisor in cases:
            expected = dividend // divisor
            actual = mt.extint_floordiv_128_64(dividend, divisor)
            # Cheap comparison first; assert_equal only runs on a mismatch.
            if expected != actual:
                assert_equal(actual, expected)
def test_ceildiv_128_64():
    """
    Compare mt.extint_ceildiv_128_64 with a ceiling-division reference for
    every INT128_VALUES dividend and every INT64_POS_VALUES divisor.
    """
    with exc_iter(INT128_VALUES, INT64_POS_VALUES) as cases:
        for dividend, divisor in cases:
            # With a positive divisor, adding (divisor - 1) before flooring
            # rounds the quotient up.
            expected = (dividend + divisor - 1) // divisor
            actual = mt.extint_ceildiv_128_64(dividend, divisor)
            if expected != actual:
                assert_equal(actual, expected)
|