  1. """
  2. Tests for the pandas custom headers in http(s) requests
  3. """
  4. import gzip
  5. import http.server
  6. from io import BytesIO
  7. import multiprocessing
  8. import socket
  9. import time
  10. import urllib.error
  11. import pytest
  12. from pandas.compat import is_ci_environment
  13. import pandas.util._test_decorators as td
  14. import pandas as pd
  15. import pandas._testing as tm
  16. pytestmark = pytest.mark.skipif(
  17. is_ci_environment(),
  18. reason="This test can hang in our CI min_versions build "
  19. "and leads to '##[error]The runner has "
  20. "received a shutdown signal...' in GHA. GH 45651",
  21. )


class BaseUserAgentResponder(http.server.BaseHTTPRequestHandler):
    """
    Base class for setting up a server that responds with a particular
    file format and the accompanying content-type headers.
    The interfaces of the different IO methods differ enough that a
    dedicated handler per format seemed the cleanest approach.
    """

    def start_processing_headers(self):
        """
        Shared logic at the start of a GET request: echo the request's
        User-Agent header back as the single cell of the response frame,
        so tests can assert on the header that was actually sent.
        """
        self.send_response(200)
        self.requested_from_user_agent = self.headers["User-Agent"]
        response_df = pd.DataFrame(
            {
                "header": [self.requested_from_user_agent],
            }
        )
        return response_df

    def gzip_bytes(self, response_bytes):
        """
        Some web servers will send back gzipped files to save bandwidth.
        """
        with BytesIO() as bio:
            with gzip.GzipFile(fileobj=bio, mode="w") as zipper:
                zipper.write(response_bytes)
            # read only after the GzipFile is closed so the gzip trailer
            # has been flushed to the buffer
            response_bytes = bio.getvalue()
        return response_bytes
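
    # Round-trip sanity sketch (illustrative only, not exercised by the tests;
    # `handler` stands in for any BaseUserAgentResponder instance):
    #
    #     compressed = handler.gzip_bytes(b"payload")
    #     assert gzip.decompress(compressed) == b"payload"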

    def write_back_bytes(self, response_bytes):
        """
        Shared logic at the end of a GET request.
        """
        self.wfile.write(response_bytes)


class CSVUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "text/csv")
        self.end_headers()
        response_bytes = response_df.to_csv(index=False).encode("utf-8")
        self.write_back_bytes(response_bytes)


class GzippedCSVUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "text/csv")
        self.send_header("Content-Encoding", "gzip")
        self.end_headers()
        response_bytes = response_df.to_csv(index=False).encode("utf-8")
        response_bytes = self.gzip_bytes(response_bytes)
        self.write_back_bytes(response_bytes)


class JSONUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        response_bytes = response_df.to_json().encode("utf-8")
        self.write_back_bytes(response_bytes)


class GzippedJSONUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Encoding", "gzip")
        self.end_headers()
        response_bytes = response_df.to_json().encode("utf-8")
        response_bytes = self.gzip_bytes(response_bytes)
        self.write_back_bytes(response_bytes)


class ParquetPyArrowUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/octet-stream")
        self.end_headers()
        response_bytes = response_df.to_parquet(index=False, engine="pyarrow")
        self.write_back_bytes(response_bytes)


class ParquetFastParquetUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/octet-stream")
        self.end_headers()
        # The fastparquet engine doesn't like writing to a buffer: it can do
        # so via a suitably set open_with function, but it then calls close()
        # itself and wipes the buffer. So write through fsspec's in-memory
        # filesystem instead and read the bytes back out.
        # Protected by an importorskip("fsspec") in the respective test.
        import fsspec

        response_df.to_parquet(
            "memory://fastparquet_user_agent.parquet",
            index=False,
            engine="fastparquet",
            compression=None,
        )
        with fsspec.open("memory://fastparquet_user_agent.parquet", "rb") as f:
            response_bytes = f.read()
        self.write_back_bytes(response_bytes)


class PickleUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/octet-stream")
        self.end_headers()
        bio = BytesIO()
        response_df.to_pickle(bio)
        response_bytes = bio.getvalue()
        self.write_back_bytes(response_bytes)


class StataUserAgentResponder(BaseUserAgentResponder):
    def do_GET(self):
        response_df = self.start_processing_headers()
        self.send_header("Content-Type", "application/octet-stream")
        self.end_headers()
        bio = BytesIO()
        response_df.to_stata(bio, write_index=False)
        response_bytes = bio.getvalue()
        self.write_back_bytes(response_bytes)


class AllHeaderCSVResponder(http.server.BaseHTTPRequestHandler):
    """
    Send all request headers back for checking the round trip.
    """

    def do_GET(self):
        # the default integer column labels 0 and 1 come back as the
        # strings "0" and "1" after the CSV round trip
        response_df = pd.DataFrame(self.headers.items())
        self.send_response(200)
        self.send_header("Content-Type", "text/csv")
        self.end_headers()
        response_bytes = response_df.to_csv(index=False).encode("utf-8")
        self.wfile.write(response_bytes)


def wait_until_ready(func):
    def inner(*args, **kwargs):
        while True:
            try:
                return func(*args, **kwargs)
            except urllib.error.URLError:
                # connection refused because the http server is still starting
                time.sleep(0.1)

    return inner
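

# A hedged usage sketch: wrap a reader so it retries until the server process
# has bound the port (`port` is a placeholder for the fixture value below):
#
#     read_csv_ready = wait_until_ready(pd.read_csv)
#     df = read_csv_ready(f"http://localhost:{port}")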


def process_server(responder, port):
    with http.server.HTTPServer(("localhost", port), responder) as server:
        # serve exactly one request, then shut down
        server.handle_request()
        server.server_close()
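

# A hedged manual-run sketch (port 8000 is an arbitrary example; the fixture
# below picks a free port instead):
#
#     p = multiprocessing.Process(
#         target=process_server, args=(CSVUserAgentResponder, 8000)
#     )
#     p.start()
#     pd.read_csv("http://localhost:8000")  # the server exits after one request
#     p.join()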


@pytest.fixture
def responder(request):
    """
    Fixture that starts a local http server in a separate process on localhost
    and yields the port.
    Running in a separate process instead of a thread allows the http server
    to be terminated/killed upon cleanup.
    """
    # Find an available port
    with socket.socket() as sock:
        sock.bind(("localhost", 0))
        port = sock.getsockname()[1]

    server_process = multiprocessing.Process(
        target=process_server, args=(request.param, port)
    )
    server_process.start()
    yield port
    # escalate on cleanup: wait for a clean exit, then terminate, then kill
    server_process.join(10)
    server_process.terminate()
    kill_time = 5
    wait_time = 0
    while server_process.is_alive():
        if wait_time > kill_time:
            server_process.kill()
            break
        wait_time += 0.1
        time.sleep(0.1)
    server_process.close()


@pytest.mark.parametrize(
    "responder, read_method, parquet_engine",
    [
        (CSVUserAgentResponder, pd.read_csv, None),
        (JSONUserAgentResponder, pd.read_json, None),
        (ParquetPyArrowUserAgentResponder, pd.read_parquet, "pyarrow"),
        pytest.param(
            ParquetFastParquetUserAgentResponder,
            pd.read_parquet,
            "fastparquet",
            # TODO(ArrayManager) fastparquet
            marks=[
                td.skip_array_manager_not_yet_implemented,
            ],
        ),
        (PickleUserAgentResponder, pd.read_pickle, None),
        (StataUserAgentResponder, pd.read_stata, None),
        (GzippedCSVUserAgentResponder, pd.read_csv, None),
        (GzippedJSONUserAgentResponder, pd.read_json, None),
    ],
    indirect=["responder"],
)
def test_server_and_default_headers(responder, read_method, parquet_engine):
    if parquet_engine is not None:
        pytest.importorskip(parquet_engine)
        if parquet_engine == "fastparquet":
            pytest.importorskip("fsspec")
    read_method = wait_until_ready(read_method)
    if parquet_engine is None:
        df_http = read_method(f"http://localhost:{responder}")
    else:
        df_http = read_method(f"http://localhost:{responder}", engine=parquet_engine)
    assert not df_http.empty


@pytest.mark.parametrize(
    "responder, read_method, parquet_engine",
    [
        (CSVUserAgentResponder, pd.read_csv, None),
        (JSONUserAgentResponder, pd.read_json, None),
        (ParquetPyArrowUserAgentResponder, pd.read_parquet, "pyarrow"),
        pytest.param(
            ParquetFastParquetUserAgentResponder,
            pd.read_parquet,
            "fastparquet",
            # TODO(ArrayManager) fastparquet
            marks=[
                td.skip_array_manager_not_yet_implemented,
            ],
        ),
        (PickleUserAgentResponder, pd.read_pickle, None),
        (StataUserAgentResponder, pd.read_stata, None),
        (GzippedCSVUserAgentResponder, pd.read_csv, None),
        (GzippedJSONUserAgentResponder, pd.read_json, None),
    ],
    indirect=["responder"],
)
def test_server_and_custom_headers(responder, read_method, parquet_engine):
    if parquet_engine is not None:
        pytest.importorskip(parquet_engine)
        if parquet_engine == "fastparquet":
            pytest.importorskip("fsspec")

    custom_user_agent = "Super Cool One"
    df_true = pd.DataFrame({"header": [custom_user_agent]})

    read_method = wait_until_ready(read_method)
    if parquet_engine is None:
        df_http = read_method(
            f"http://localhost:{responder}",
            storage_options={"User-Agent": custom_user_agent},
        )
    else:
        df_http = read_method(
            f"http://localhost:{responder}",
            storage_options={"User-Agent": custom_user_agent},
            engine=parquet_engine,
        )

    tm.assert_frame_equal(df_true, df_http)


@pytest.mark.parametrize(
    "responder, read_method",
    [
        (AllHeaderCSVResponder, pd.read_csv),
    ],
    indirect=["responder"],
)
def test_server_and_all_custom_headers(responder, read_method):
    custom_user_agent = "Super Cool One"
    custom_auth_token = "Super Secret One"
    storage_options = {
        "User-Agent": custom_user_agent,
        "Auth": custom_auth_token,
    }
    read_method = wait_until_ready(read_method)
    df_http = read_method(
        f"http://localhost:{responder}",
        storage_options=storage_options,
    )

    # the server echoes every request header; keep only the two custom ones
    df_http = df_http[df_http["0"].isin(storage_options.keys())]
    df_http = df_http.sort_values(["0"]).reset_index()
    df_http = df_http[["0", "1"]]

    keys = list(storage_options.keys())
    df_true = pd.DataFrame({"0": keys, "1": [storage_options[k] for k in keys]})
    df_true = df_true.sort_values(["0"])
    df_true = df_true.reset_index().drop(["index"], axis=1)

    tm.assert_frame_equal(df_true, df_http)


@pytest.mark.parametrize(
    "engine",
    [
        "pyarrow",
        "fastparquet",
    ],
)
def test_to_parquet_to_disk_with_storage_options(engine):
    headers = {
        "User-Agent": "custom",
        "Auth": "other_custom",
    }

    pytest.importorskip(engine)

    true_df = pd.DataFrame({"column_name": ["column_value"]})
    msg = (
        "storage_options passed with file object or non-fsspec file path|"
        "storage_options passed with buffer, or non-supported URL"
    )
    with pytest.raises(ValueError, match=msg):
        true_df.to_parquet("/tmp/junk.parquet", storage_options=headers, engine=engine)