arrow_parser_wrapper.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from __future__ import annotations
  2. from pandas._typing import ReadBuffer
  3. from pandas.compat._optional import import_optional_dependency
  4. from pandas.core.dtypes.inference import is_integer
  5. import pandas as pd
  6. from pandas import DataFrame
  7. from pandas.io._util import _arrow_dtype_mapping
  8. from pandas.io.parsers.base_parser import ParserBase
  9. class ArrowParserWrapper(ParserBase):
  10. """
  11. Wrapper for the pyarrow engine for read_csv()
  12. """
  13. def __init__(self, src: ReadBuffer[bytes], **kwds) -> None:
  14. super().__init__(kwds)
  15. self.kwds = kwds
  16. self.src = src
  17. self._parse_kwds()
  18. def _parse_kwds(self):
  19. """
  20. Validates keywords before passing to pyarrow.
  21. """
  22. encoding: str | None = self.kwds.get("encoding")
  23. self.encoding = "utf-8" if encoding is None else encoding
  24. self.usecols, self.usecols_dtype = self._validate_usecols_arg(
  25. self.kwds["usecols"]
  26. )
  27. na_values = self.kwds["na_values"]
  28. if isinstance(na_values, dict):
  29. raise ValueError(
  30. "The pyarrow engine doesn't support passing a dict for na_values"
  31. )
  32. self.na_values = list(self.kwds["na_values"])
  33. def _get_pyarrow_options(self) -> None:
  34. """
  35. Rename some arguments to pass to pyarrow
  36. """
  37. mapping = {
  38. "usecols": "include_columns",
  39. "na_values": "null_values",
  40. "escapechar": "escape_char",
  41. "skip_blank_lines": "ignore_empty_lines",
  42. "decimal": "decimal_point",
  43. }
  44. for pandas_name, pyarrow_name in mapping.items():
  45. if pandas_name in self.kwds and self.kwds.get(pandas_name) is not None:
  46. self.kwds[pyarrow_name] = self.kwds.pop(pandas_name)
  47. self.parse_options = {
  48. option_name: option_value
  49. for option_name, option_value in self.kwds.items()
  50. if option_value is not None
  51. and option_name
  52. in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
  53. }
  54. self.convert_options = {
  55. option_name: option_value
  56. for option_name, option_value in self.kwds.items()
  57. if option_value is not None
  58. and option_name
  59. in (
  60. "include_columns",
  61. "null_values",
  62. "true_values",
  63. "false_values",
  64. "decimal_point",
  65. )
  66. }
  67. self.read_options = {
  68. "autogenerate_column_names": self.header is None,
  69. "skip_rows": self.header
  70. if self.header is not None
  71. else self.kwds["skiprows"],
  72. "encoding": self.encoding,
  73. }
  74. def _finalize_pandas_output(self, frame: DataFrame) -> DataFrame:
  75. """
  76. Processes data read in based on kwargs.
  77. Parameters
  78. ----------
  79. frame: DataFrame
  80. The DataFrame to process.
  81. Returns
  82. -------
  83. DataFrame
  84. The processed DataFrame.
  85. """
  86. num_cols = len(frame.columns)
  87. multi_index_named = True
  88. if self.header is None:
  89. if self.names is None:
  90. if self.header is None:
  91. self.names = range(num_cols)
  92. if len(self.names) != num_cols:
  93. # usecols is passed through to pyarrow, we only handle index col here
  94. # The only way self.names is not the same length as number of cols is
  95. # if we have int index_col. We should just pad the names(they will get
  96. # removed anyways) to expected length then.
  97. self.names = list(range(num_cols - len(self.names))) + self.names
  98. multi_index_named = False
  99. frame.columns = self.names
  100. # we only need the frame not the names
  101. frame.columns, frame = self._do_date_conversions(frame.columns, frame)
  102. if self.index_col is not None:
  103. for i, item in enumerate(self.index_col):
  104. if is_integer(item):
  105. self.index_col[i] = frame.columns[item]
  106. else:
  107. # String case
  108. if item not in frame.columns:
  109. raise ValueError(f"Index {item} invalid")
  110. frame.set_index(self.index_col, drop=True, inplace=True)
  111. # Clear names if headerless and no name given
  112. if self.header is None and not multi_index_named:
  113. frame.index.names = [None] * len(frame.index.names)
  114. if self.kwds.get("dtype") is not None:
  115. try:
  116. frame = frame.astype(self.kwds.get("dtype"))
  117. except TypeError as e:
  118. # GH#44901 reraise to keep api consistent
  119. raise ValueError(e)
  120. return frame
  121. def read(self) -> DataFrame:
  122. """
  123. Reads the contents of a CSV file into a DataFrame and
  124. processes it according to the kwargs passed in the
  125. constructor.
  126. Returns
  127. -------
  128. DataFrame
  129. The DataFrame created from the CSV file.
  130. """
  131. pyarrow_csv = import_optional_dependency("pyarrow.csv")
  132. self._get_pyarrow_options()
  133. table = pyarrow_csv.read_csv(
  134. self.src,
  135. read_options=pyarrow_csv.ReadOptions(**self.read_options),
  136. parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
  137. convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
  138. )
  139. if self.kwds["dtype_backend"] == "pyarrow":
  140. frame = table.to_pandas(types_mapper=pd.ArrowDtype)
  141. elif self.kwds["dtype_backend"] == "numpy_nullable":
  142. frame = table.to_pandas(types_mapper=_arrow_dtype_mapping().get)
  143. else:
  144. frame = table.to_pandas()
  145. return self._finalize_pandas_output(frame)