ops.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Iterator,
  5. NamedTuple,
  6. )
  7. from pandas._typing import ArrayLike
  8. if TYPE_CHECKING:
  9. from pandas._libs.internals import BlockPlacement
  10. from pandas.core.internals.blocks import Block
  11. from pandas.core.internals.managers import BlockManager
  class BlockPairInfo(NamedTuple):
      """
      One aligned left/right pair of block values, as yielded by ``_iter_block_pairs``.

      Fields are positional, so the tuple can be unpacked directly in the order
      declared below.
      """

      # values from the left block, sliced/squeezed to line up with ``rvals``
      lvals: ArrayLike
      # values from the right block, squeezed to 1D when the left side is 1D
      rvals: ArrayLike
      # mgr_locs of the left block the pair was built from
      locs: BlockPlacement
      # True when the left block's values are 1D (extension-array case)
      left_ea: bool
      # True when the right block's values are 1D (extension-array case)
      right_ea: bool
      # the right-hand block the pair was built from
      rblk: Block
  def _iter_block_pairs(
      left: BlockManager, right: BlockManager
  ) -> Iterator[BlockPairInfo]:
      """
      Yield aligned value pairs for every left block and its matching right block(s).

      For each block of ``left``, the same positions are taken from ``right`` with
      ``_slice_take_blocks_ax0(..., only_slice=True)``. Each right block obtained
      that way is paired with the left block, and both are passed through
      ``_get_same_shape_values`` so the two value arrays have matching shapes.

      The caller is expected to have already checked the parent DataFrames with
      ``rframe._indexed_same(lframe)``.
      """
      for left_blk in left.blocks:
          left_locs = left_blk.mgr_locs
          is_left_1d = left_blk.values.ndim == 1

          # Invariant (not checked, for performance): a 1D left block spans a
          # single location and maps to exactly one right block with one row.
          for right_blk in right._slice_take_blocks_ax0(
              left_locs.indexer, only_slice=True
          ):
              is_right_1d = right_blk.values.ndim == 1
              lvals, rvals = _get_same_shape_values(
                  left_blk, right_blk, is_left_1d, is_right_1d
              )
              yield BlockPairInfo(
                  lvals, rvals, left_locs, is_left_1d, is_right_1d, right_blk
              )
  def operate_blockwise(
      left: BlockManager, right: BlockManager, array_op
  ) -> BlockManager:
      """
      Apply ``array_op`` block-by-block to two aligned BlockManagers.

      Parameters
      ----------
      left, right : BlockManager
          Managers of the parent DataFrames. The caller is expected to have
          already checked ``rframe._indexed_same(lframe)``.
      array_op : callable
          Called as ``array_op(lvals, rvals)`` on each aligned pair of values.

      Returns
      -------
      BlockManager
          A new manager of the same type as ``right``. It is built from the
          result blocks with ``right.axes`` and ``verify_integrity=False``.
      """
      result_blocks: list[Block] = []
      for pair in _iter_block_pairs(left, right):
          result = array_op(pair.lvals, pair.rvals)
          # In the 1D-left / 2D-right case both operands were squeezed to 1D.
          # Reshape the result back to a single row when it supports that.
          if pair.left_ea and not pair.right_ea and hasattr(result, "reshape"):
              result = result.reshape(1, -1)
          new_blocks = pair.rblk._split_op_result(result)

          # Invariant (not checked, for performance): if either side is 1D
          # this yields a single block; otherwise the result has lvals' shape.

          _reset_block_mgr_locs(new_blocks, pair.locs)
          result_blocks.extend(new_blocks)

      # Invariant (not checked, for performance): the result blocks' locations
      # cover 0..len(left.items)-1, each position exactly once.
      return type(right)(
          tuple(result_blocks), axes=right.axes, verify_integrity=False
      )
  65. def _reset_block_mgr_locs(nbs: list[Block], locs) -> None:
  66. """
  67. Reset mgr_locs to correspond to our original DataFrame.
  68. """
  69. for nb in nbs:
  70. nblocs = locs[nb.mgr_locs.indexer]
  71. nb.mgr_locs = nblocs
  72. # Assertions are disabled for performance, but should hold:
  73. # assert len(nblocs) == nb.shape[0], (len(nblocs), nb.shape)
  74. # assert all(x in locs.as_array for x in nb.mgr_locs.as_array)
  75. def _get_same_shape_values(
  76. lblk: Block, rblk: Block, left_ea: bool, right_ea: bool
  77. ) -> tuple[ArrayLike, ArrayLike]:
  78. """
  79. Slice lblk.values to align with rblk. Squeeze if we have EAs.
  80. """
  81. lvals = lblk.values
  82. rvals = rblk.values
  83. # Require that the indexing into lvals be slice-like
  84. assert rblk.mgr_locs.is_slice_like, rblk.mgr_locs
  85. # TODO(EA2D): with 2D EAs only this first clause would be needed
  86. if not (left_ea or right_ea):
  87. # error: No overload variant of "__getitem__" of "ExtensionArray" matches
  88. # argument type "Tuple[Union[ndarray, slice], slice]"
  89. lvals = lvals[rblk.mgr_locs.indexer, :] # type: ignore[call-overload]
  90. assert lvals.shape == rvals.shape, (lvals.shape, rvals.shape)
  91. elif left_ea and right_ea:
  92. assert lvals.shape == rvals.shape, (lvals.shape, rvals.shape)
  93. elif right_ea:
  94. # lvals are 2D, rvals are 1D
  95. # error: No overload variant of "__getitem__" of "ExtensionArray" matches
  96. # argument type "Tuple[Union[ndarray, slice], slice]"
  97. lvals = lvals[rblk.mgr_locs.indexer, :] # type: ignore[call-overload]
  98. assert lvals.shape[0] == 1, lvals.shape
  99. lvals = lvals[0, :]
  100. else:
  101. # lvals are 1D, rvals are 2D
  102. assert rvals.shape[0] == 1, rvals.shape
  103. # error: No overload variant of "__getitem__" of "ExtensionArray" matches
  104. # argument type "Tuple[int, slice]"
  105. rvals = rvals[0, :] # type: ignore[call-overload]
  106. return lvals, rvals
  def blockwise_all(left: BlockManager, right: BlockManager, op) -> bool:
      """
      Blockwise `all` reduction.

      Evaluate ``op(lvals, rvals)`` on each aligned block pair from
      ``_iter_block_pairs``.
      - Returns False as soon as one result is falsy; later pairs are not
        evaluated.
      - Returns True when every result is truthy, including when there are
        no pairs at all.
      """
      return all(
          op(pair.lvals, pair.rvals) for pair in _iter_block_pairs(left, right)
      )