test_matmul.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import operator
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Series,
  7. )
  8. import pandas._testing as tm
  9. class TestMatmul:
  10. def test_matmul(self):
  11. # matmul test is for GH#10259
  12. a = Series(np.random.randn(4), index=["p", "q", "r", "s"])
  13. b = DataFrame(
  14. np.random.randn(3, 4), index=["1", "2", "3"], columns=["p", "q", "r", "s"]
  15. ).T
  16. # Series @ DataFrame -> Series
  17. result = operator.matmul(a, b)
  18. expected = Series(np.dot(a.values, b.values), index=["1", "2", "3"])
  19. tm.assert_series_equal(result, expected)
  20. # DataFrame @ Series -> Series
  21. result = operator.matmul(b.T, a)
  22. expected = Series(np.dot(b.T.values, a.T.values), index=["1", "2", "3"])
  23. tm.assert_series_equal(result, expected)
  24. # Series @ Series -> scalar
  25. result = operator.matmul(a, a)
  26. expected = np.dot(a.values, a.values)
  27. tm.assert_almost_equal(result, expected)
  28. # GH#21530
  29. # vector (1D np.array) @ Series (__rmatmul__)
  30. result = operator.matmul(a.values, a)
  31. expected = np.dot(a.values, a.values)
  32. tm.assert_almost_equal(result, expected)
  33. # GH#21530
  34. # vector (1D list) @ Series (__rmatmul__)
  35. result = operator.matmul(a.values.tolist(), a)
  36. expected = np.dot(a.values, a.values)
  37. tm.assert_almost_equal(result, expected)
  38. # GH#21530
  39. # matrix (2D np.array) @ Series (__rmatmul__)
  40. result = operator.matmul(b.T.values, a)
  41. expected = np.dot(b.T.values, a.values)
  42. tm.assert_almost_equal(result, expected)
  43. # GH#21530
  44. # matrix (2D nested lists) @ Series (__rmatmul__)
  45. result = operator.matmul(b.T.values.tolist(), a)
  46. expected = np.dot(b.T.values, a.values)
  47. tm.assert_almost_equal(result, expected)
  48. # mixed dtype DataFrame @ Series
  49. a["p"] = int(a.p)
  50. result = operator.matmul(b.T, a)
  51. expected = Series(np.dot(b.T.values, a.T.values), index=["1", "2", "3"])
  52. tm.assert_series_equal(result, expected)
  53. # different dtypes DataFrame @ Series
  54. a = a.astype(int)
  55. result = operator.matmul(b.T, a)
  56. expected = Series(np.dot(b.T.values, a.T.values), index=["1", "2", "3"])
  57. tm.assert_series_equal(result, expected)
  58. msg = r"Dot product shape mismatch, \(4,\) vs \(3,\)"
  59. # exception raised is of type Exception
  60. with pytest.raises(Exception, match=msg):
  61. a.dot(a.values[:3])
  62. msg = "matrices are not aligned"
  63. with pytest.raises(ValueError, match=msg):
  64. a.dot(b.T)