_testutils.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import numpy as np
  2. class _FakeMatrix:
  3. def __init__(self, data):
  4. self._data = data
  5. self.__array_interface__ = data.__array_interface__
  6. class _FakeMatrix2:
  7. def __init__(self, data):
  8. self._data = data
  9. def __array__(self):
  10. return self._data
  11. def _get_array(shape, dtype):
  12. """
  13. Get a test array of given shape and data type.
  14. Returned NxN matrices are posdef, and 2xN are banded-posdef.
  15. """
  16. if len(shape) == 2 and shape[0] == 2:
  17. # yield a banded positive definite one
  18. x = np.zeros(shape, dtype=dtype)
  19. x[0, 1:] = -1
  20. x[1] = 2
  21. return x
  22. elif len(shape) == 2 and shape[0] == shape[1]:
  23. # always yield a positive definite matrix
  24. x = np.zeros(shape, dtype=dtype)
  25. j = np.arange(shape[0])
  26. x[j, j] = 2
  27. x[j[:-1], j[:-1]+1] = -1
  28. x[j[:-1]+1, j[:-1]] = -1
  29. return x
  30. else:
  31. np.random.seed(1234)
  32. return np.random.randn(*shape).astype(dtype)
  33. def _id(x):
  34. return x
  35. def assert_no_overwrite(call, shapes, dtypes=None):
  36. """
  37. Test that a call does not overwrite its input arguments
  38. """
  39. if dtypes is None:
  40. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  41. for dtype in dtypes:
  42. for order in ["C", "F"]:
  43. for faker in [_id, _FakeMatrix, _FakeMatrix2]:
  44. orig_inputs = [_get_array(s, dtype) for s in shapes]
  45. inputs = [faker(x.copy(order)) for x in orig_inputs]
  46. call(*inputs)
  47. msg = "call modified inputs [%r, %r]" % (dtype, faker)
  48. for a, b in zip(inputs, orig_inputs):
  49. np.testing.assert_equal(a, b, err_msg=msg)