test_strings.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import pytest
  2. import operator
  3. import numpy as np
  4. from numpy.testing import assert_array_equal
  5. COMPARISONS = [
  6. (operator.eq, np.equal, "=="),
  7. (operator.ne, np.not_equal, "!="),
  8. (operator.lt, np.less, "<"),
  9. (operator.le, np.less_equal, "<="),
  10. (operator.gt, np.greater, ">"),
  11. (operator.ge, np.greater_equal, ">="),
  12. ]
  13. @pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
  14. def test_mixed_string_comparison_ufuncs_fail(op, ufunc, sym):
  15. arr_string = np.array(["a", "b"], dtype="S")
  16. arr_unicode = np.array(["a", "c"], dtype="U")
  17. with pytest.raises(TypeError, match="did not contain a loop"):
  18. ufunc(arr_string, arr_unicode)
  19. with pytest.raises(TypeError, match="did not contain a loop"):
  20. ufunc(arr_unicode, arr_string)
  21. @pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
  22. def test_mixed_string_comparisons_ufuncs_with_cast(op, ufunc, sym):
  23. arr_string = np.array(["a", "b"], dtype="S")
  24. arr_unicode = np.array(["a", "c"], dtype="U")
  25. # While there is no loop, manual casting is acceptable:
  26. res1 = ufunc(arr_string, arr_unicode, signature="UU->?", casting="unsafe")
  27. res2 = ufunc(arr_string, arr_unicode, signature="SS->?", casting="unsafe")
  28. expected = op(arr_string.astype('U'), arr_unicode)
  29. assert_array_equal(res1, expected)
  30. assert_array_equal(res2, expected)
  31. @pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
  32. @pytest.mark.parametrize("dtypes", [
  33. ("S2", "S2"), ("S2", "S10"),
  34. ("<U1", "<U1"), ("<U1", ">U1"), (">U1", ">U1"),
  35. ("<U1", "<U10"), ("<U1", ">U10")])
  36. @pytest.mark.parametrize("aligned", [True, False])
  37. def test_string_comparisons(op, ufunc, sym, dtypes, aligned):
  38. # ensure native byte-order for the first view to stay within unicode range
  39. native_dt = np.dtype(dtypes[0]).newbyteorder("=")
  40. arr = np.arange(2**15).view(native_dt).astype(dtypes[0])
  41. if not aligned:
  42. # Make `arr` unaligned:
  43. new = np.zeros(arr.nbytes + 1, dtype=np.uint8)[1:].view(dtypes[0])
  44. new[...] = arr
  45. arr = new
  46. arr2 = arr.astype(dtypes[1], copy=True)
  47. np.random.shuffle(arr2)
  48. arr[0] = arr2[0] # make sure one matches
  49. expected = [op(d1, d2) for d1, d2 in zip(arr.tolist(), arr2.tolist())]
  50. assert_array_equal(op(arr, arr2), expected)
  51. assert_array_equal(ufunc(arr, arr2), expected)
  52. assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected)
  53. expected = [op(d2, d1) for d1, d2 in zip(arr.tolist(), arr2.tolist())]
  54. assert_array_equal(op(arr2, arr), expected)
  55. assert_array_equal(ufunc(arr2, arr), expected)
  56. assert_array_equal(np.compare_chararrays(arr2, arr, sym, False), expected)
  57. @pytest.mark.parametrize(["op", "ufunc", "sym"], COMPARISONS)
  58. @pytest.mark.parametrize("dtypes", [
  59. ("S2", "S2"), ("S2", "S10"), ("<U1", "<U1"), ("<U1", ">U10")])
  60. def test_string_comparisons_empty(op, ufunc, sym, dtypes):
  61. arr = np.empty((1, 0, 1, 5), dtype=dtypes[0])
  62. arr2 = np.empty((100, 1, 0, 1), dtype=dtypes[1])
  63. expected = np.empty(np.broadcast_shapes(arr.shape, arr2.shape), dtype=bool)
  64. assert_array_equal(op(arr, arr2), expected)
  65. assert_array_equal(ufunc(arr, arr2), expected)
  66. assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected)
  67. @pytest.mark.parametrize("str_dt", ["S", "U"])
  68. @pytest.mark.parametrize("float_dt", np.typecodes["AllFloat"])
  69. def test_float_to_string_cast(str_dt, float_dt):
  70. float_dt = np.dtype(float_dt)
  71. fi = np.finfo(float_dt)
  72. arr = np.array([np.nan, np.inf, -np.inf, fi.max, fi.min], dtype=float_dt)
  73. expected = ["nan", "inf", "-inf", repr(fi.max), repr(fi.min)]
  74. if float_dt.kind == 'c':
  75. expected = [f"({r}+0j)" for r in expected]
  76. res = arr.astype(str_dt)
  77. assert_array_equal(res, np.array(expected, dtype=str_dt))