# test_multi_thread.py (3.7 KB)
  1. """
  2. Tests multithreading behaviour for reading and
  3. parsing files for each parser defined in parsers.py
  4. """
  5. from contextlib import ExitStack
  6. from io import BytesIO
  7. from multiprocessing.pool import ThreadPool
  8. import numpy as np
  9. import pytest
  10. import pandas as pd
  11. from pandas import DataFrame
  12. import pandas._testing as tm
  13. # We'll probably always skip these for pyarrow
  14. # Maybe we'll add our own tests for pyarrow too
  15. pytestmark = pytest.mark.usefixtures("pyarrow_skip")
  16. def _construct_dataframe(num_rows):
  17. """
  18. Construct a DataFrame for testing.
  19. Parameters
  20. ----------
  21. num_rows : int
  22. The number of rows for our DataFrame.
  23. Returns
  24. -------
  25. df : DataFrame
  26. """
  27. df = DataFrame(np.random.rand(num_rows, 5), columns=list("abcde"))
  28. df["foo"] = "foo"
  29. df["bar"] = "bar"
  30. df["baz"] = "baz"
  31. df["date"] = pd.date_range("20000101 09:00:00", periods=num_rows, freq="s")
  32. df["int"] = np.arange(num_rows, dtype="int64")
  33. return df
  34. @pytest.mark.slow
  35. def test_multi_thread_string_io_read_csv(all_parsers):
  36. # see gh-11786
  37. parser = all_parsers
  38. max_row_range = 10000
  39. num_files = 100
  40. bytes_to_df = [
  41. "\n".join([f"{i:d},{i:d},{i:d}" for i in range(max_row_range)]).encode()
  42. for _ in range(num_files)
  43. ]
  44. # Read all files in many threads.
  45. with ExitStack() as stack:
  46. files = [stack.enter_context(BytesIO(b)) for b in bytes_to_df]
  47. pool = stack.enter_context(ThreadPool(8))
  48. results = pool.map(parser.read_csv, files)
  49. first_result = results[0]
  50. for result in results:
  51. tm.assert_frame_equal(first_result, result)
  52. def _generate_multi_thread_dataframe(parser, path, num_rows, num_tasks):
  53. """
  54. Generate a DataFrame via multi-thread.
  55. Parameters
  56. ----------
  57. parser : BaseParser
  58. The parser object to use for reading the data.
  59. path : str
  60. The location of the CSV file to read.
  61. num_rows : int
  62. The number of rows to read per task.
  63. num_tasks : int
  64. The number of tasks to use for reading this DataFrame.
  65. Returns
  66. -------
  67. df : DataFrame
  68. """
  69. def reader(arg):
  70. """
  71. Create a reader for part of the CSV.
  72. Parameters
  73. ----------
  74. arg : tuple
  75. A tuple of the following:
  76. * start : int
  77. The starting row to start for parsing CSV
  78. * nrows : int
  79. The number of rows to read.
  80. Returns
  81. -------
  82. df : DataFrame
  83. """
  84. start, nrows = arg
  85. if not start:
  86. return parser.read_csv(
  87. path, index_col=0, header=0, nrows=nrows, parse_dates=["date"]
  88. )
  89. return parser.read_csv(
  90. path,
  91. index_col=0,
  92. header=None,
  93. skiprows=int(start) + 1,
  94. nrows=nrows,
  95. parse_dates=[9],
  96. )
  97. tasks = [
  98. (num_rows * i // num_tasks, num_rows // num_tasks) for i in range(num_tasks)
  99. ]
  100. with ThreadPool(processes=num_tasks) as pool:
  101. results = pool.map(reader, tasks)
  102. header = results[0].columns
  103. for r in results[1:]:
  104. r.columns = header
  105. final_dataframe = pd.concat(results)
  106. return final_dataframe
  107. @pytest.mark.slow
  108. def test_multi_thread_path_multipart_read_csv(all_parsers):
  109. # see gh-11786
  110. num_tasks = 4
  111. num_rows = 100000
  112. parser = all_parsers
  113. file_name = "__thread_pool_reader__.csv"
  114. df = _construct_dataframe(num_rows)
  115. with tm.ensure_clean(file_name) as path:
  116. df.to_csv(path)
  117. final_dataframe = _generate_multi_thread_dataframe(
  118. parser, path, num_rows, num_tasks
  119. )
  120. tm.assert_frame_equal(df, final_dataframe)