# test_matmul.py (2.8 KB)
  1. import operator
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Index,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. class TestMatMul:
  11. def test_matmul(self):
  12. # matmul test is for GH#10259
  13. a = DataFrame(
  14. np.random.randn(3, 4), index=["a", "b", "c"], columns=["p", "q", "r", "s"]
  15. )
  16. b = DataFrame(
  17. np.random.randn(4, 2), index=["p", "q", "r", "s"], columns=["one", "two"]
  18. )
  19. # DataFrame @ DataFrame
  20. result = operator.matmul(a, b)
  21. expected = DataFrame(
  22. np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
  23. )
  24. tm.assert_frame_equal(result, expected)
  25. # DataFrame @ Series
  26. result = operator.matmul(a, b.one)
  27. expected = Series(np.dot(a.values, b.one.values), index=["a", "b", "c"])
  28. tm.assert_series_equal(result, expected)
  29. # np.array @ DataFrame
  30. result = operator.matmul(a.values, b)
  31. assert isinstance(result, DataFrame)
  32. assert result.columns.equals(b.columns)
  33. assert result.index.equals(Index(range(3)))
  34. expected = np.dot(a.values, b.values)
  35. tm.assert_almost_equal(result.values, expected)
  36. # nested list @ DataFrame (__rmatmul__)
  37. result = operator.matmul(a.values.tolist(), b)
  38. expected = DataFrame(
  39. np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
  40. )
  41. tm.assert_almost_equal(result.values, expected.values)
  42. # mixed dtype DataFrame @ DataFrame
  43. a["q"] = a.q.round().astype(int)
  44. result = operator.matmul(a, b)
  45. expected = DataFrame(
  46. np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
  47. )
  48. tm.assert_frame_equal(result, expected)
  49. # different dtypes DataFrame @ DataFrame
  50. a = a.astype(int)
  51. result = operator.matmul(a, b)
  52. expected = DataFrame(
  53. np.dot(a.values, b.values), index=["a", "b", "c"], columns=["one", "two"]
  54. )
  55. tm.assert_frame_equal(result, expected)
  56. # unaligned
  57. df = DataFrame(np.random.randn(3, 4), index=[1, 2, 3], columns=range(4))
  58. df2 = DataFrame(np.random.randn(5, 3), index=range(5), columns=[1, 2, 3])
  59. with pytest.raises(ValueError, match="aligned"):
  60. operator.matmul(df, df2)
  61. def test_matmul_message_shapes(self):
  62. # GH#21581 exception message should reflect original shapes,
  63. # not transposed shapes
  64. a = np.random.rand(10, 4)
  65. b = np.random.rand(5, 3)
  66. df = DataFrame(b)
  67. msg = r"shapes \(10, 4\) and \(5, 3\) not aligned"
  68. with pytest.raises(ValueError, match=msg):
  69. a @ df
  70. with pytest.raises(ValueError, match=msg):
  71. a.tolist() @ df