test_scalarinherit.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. """ Test printing of scalar types.
  2. """
  3. import pytest
  4. import numpy as np
  5. from numpy.testing import assert_, assert_raises
  6. class A:
  7. pass
  8. class B(A, np.float64):
  9. pass
  10. class C(B):
  11. pass
  12. class D(C, B):
  13. pass
  14. class B0(np.float64, A):
  15. pass
  16. class C0(B0):
  17. pass
  18. class HasNew:
  19. def __new__(cls, *args, **kwargs):
  20. return cls, args, kwargs
  21. class B1(np.float64, HasNew):
  22. pass
  23. class TestInherit:
  24. def test_init(self):
  25. x = B(1.0)
  26. assert_(str(x) == '1.0')
  27. y = C(2.0)
  28. assert_(str(y) == '2.0')
  29. z = D(3.0)
  30. assert_(str(z) == '3.0')
  31. def test_init2(self):
  32. x = B0(1.0)
  33. assert_(str(x) == '1.0')
  34. y = C0(2.0)
  35. assert_(str(y) == '2.0')
  36. def test_gh_15395(self):
  37. # HasNew is the second base, so `np.float64` should have priority
  38. x = B1(1.0)
  39. assert_(str(x) == '1.0')
  40. # previously caused RecursionError!?
  41. with pytest.raises(TypeError):
  42. B1(1.0, 2.0)
  43. class TestCharacter:
  44. def test_char_radd(self):
  45. # GH issue 9620, reached gentype_add and raise TypeError
  46. np_s = np.string_('abc')
  47. np_u = np.unicode_('abc')
  48. s = b'def'
  49. u = 'def'
  50. assert_(np_s.__radd__(np_s) is NotImplemented)
  51. assert_(np_s.__radd__(np_u) is NotImplemented)
  52. assert_(np_s.__radd__(s) is NotImplemented)
  53. assert_(np_s.__radd__(u) is NotImplemented)
  54. assert_(np_u.__radd__(np_s) is NotImplemented)
  55. assert_(np_u.__radd__(np_u) is NotImplemented)
  56. assert_(np_u.__radd__(s) is NotImplemented)
  57. assert_(np_u.__radd__(u) is NotImplemented)
  58. assert_(s + np_s == b'defabc')
  59. assert_(u + np_u == 'defabc')
  60. class MyStr(str, np.generic):
  61. # would segfault
  62. pass
  63. with assert_raises(TypeError):
  64. # Previously worked, but gave completely wrong result
  65. ret = s + MyStr('abc')
  66. class MyBytes(bytes, np.generic):
  67. # would segfault
  68. pass
  69. ret = s + MyBytes(b'abc')
  70. assert(type(ret) is type(s))
  71. assert ret == b"defabc"
  72. def test_char_repeat(self):
  73. np_s = np.string_('abc')
  74. np_u = np.unicode_('abc')
  75. res_s = b'abc' * 5
  76. res_u = 'abc' * 5
  77. assert_(np_s * 5 == res_s)
  78. assert_(np_u * 5 == res_u)