_io.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. from __future__ import annotations
  2. import bz2
  3. from functools import wraps
  4. import gzip
  5. import io
  6. import socket
  7. import tarfile
  8. from typing import (
  9. TYPE_CHECKING,
  10. Any,
  11. Callable,
  12. )
  13. import zipfile
  14. from pandas._typing import (
  15. FilePath,
  16. ReadPickleBuffer,
  17. )
  18. from pandas.compat import get_lzma_file
  19. from pandas.compat._optional import import_optional_dependency
  20. import pandas as pd
  21. from pandas._testing._random import rands
  22. from pandas._testing.contexts import ensure_clean
  23. from pandas.io.common import urlopen
  24. if TYPE_CHECKING:
  25. from pandas import (
  26. DataFrame,
  27. Series,
  28. )
  29. # skip tests on exceptions with these messages
  30. _network_error_messages = (
  31. # 'urlopen error timed out',
  32. # 'timeout: timed out',
  33. # 'socket.timeout: timed out',
  34. "timed out",
  35. "Server Hangup",
  36. "HTTP Error 503: Service Unavailable",
  37. "502: Proxy Error",
  38. "HTTP Error 502: internal error",
  39. "HTTP Error 502",
  40. "HTTP Error 503",
  41. "HTTP Error 403",
  42. "HTTP Error 400",
  43. "Temporary failure in name resolution",
  44. "Name or service not known",
  45. "Connection refused",
  46. "certificate verify",
  47. )
  48. # or this e.errno/e.reason.errno
  49. _network_errno_vals = (
  50. 101, # Network is unreachable
  51. 111, # Connection refused
  52. 110, # Connection timed out
  53. 104, # Connection reset Error
  54. 54, # Connection reset by peer
  55. 60, # urllib.error.URLError: [Errno 60] Connection timed out
  56. )
  57. # Both of the above shouldn't mask real issues such as 404's
  58. # or refused connections (changed DNS).
  59. # But some tests (test_data yahoo) contact incredibly flakey
  60. # servers.
  61. # and conditionally raise on exception types in _get_default_network_errors
  62. def _get_default_network_errors():
  63. # Lazy import for http.client & urllib.error
  64. # because it imports many things from the stdlib
  65. import http.client
  66. import urllib.error
  67. return (
  68. OSError,
  69. http.client.HTTPException,
  70. TimeoutError,
  71. urllib.error.URLError,
  72. socket.timeout,
  73. )
  74. def optional_args(decorator):
  75. """
  76. allows a decorator to take optional positional and keyword arguments.
  77. Assumes that taking a single, callable, positional argument means that
  78. it is decorating a function, i.e. something like this::
  79. @my_decorator
  80. def function(): pass
  81. Calls decorator with decorator(f, *args, **kwargs)
  82. """
  83. @wraps(decorator)
  84. def wrapper(*args, **kwargs):
  85. def dec(f):
  86. return decorator(f, *args, **kwargs)
  87. is_decorating = not kwargs and len(args) == 1 and callable(args[0])
  88. if is_decorating:
  89. f = args[0]
  90. args = ()
  91. return dec(f)
  92. else:
  93. return dec
  94. return wrapper
  95. # error: Untyped decorator makes function "network" untyped
  96. @optional_args # type: ignore[misc]
  97. def network(
  98. t,
  99. url: str = "https://www.google.com",
  100. raise_on_error: bool = False,
  101. check_before_test: bool = False,
  102. error_classes=None,
  103. skip_errnos=_network_errno_vals,
  104. _skip_on_messages=_network_error_messages,
  105. ):
  106. """
  107. Label a test as requiring network connection and, if an error is
  108. encountered, only raise if it does not find a network connection.
  109. In comparison to ``network``, this assumes an added contract to your test:
  110. you must assert that, under normal conditions, your test will ONLY fail if
  111. it does not have network connectivity.
  112. You can call this in 3 ways: as a standard decorator, with keyword
  113. arguments, or with a positional argument that is the url to check.
  114. Parameters
  115. ----------
  116. t : callable
  117. The test requiring network connectivity.
  118. url : path
  119. The url to test via ``pandas.io.common.urlopen`` to check
  120. for connectivity. Defaults to 'https://www.google.com'.
  121. raise_on_error : bool
  122. If True, never catches errors.
  123. check_before_test : bool
  124. If True, checks connectivity before running the test case.
  125. error_classes : tuple or Exception
  126. error classes to ignore. If not in ``error_classes``, raises the error.
  127. defaults to OSError. Be careful about changing the error classes here.
  128. skip_errnos : iterable of int
  129. Any exception that has .errno or .reason.erno set to one
  130. of these values will be skipped with an appropriate
  131. message.
  132. _skip_on_messages: iterable of string
  133. any exception e for which one of the strings is
  134. a substring of str(e) will be skipped with an appropriate
  135. message. Intended to suppress errors where an errno isn't available.
  136. Notes
  137. -----
  138. * ``raise_on_error`` supersedes ``check_before_test``
  139. Returns
  140. -------
  141. t : callable
  142. The decorated test ``t``, with checks for connectivity errors.
  143. Example
  144. -------
  145. Tests decorated with @network will fail if it's possible to make a network
  146. connection to another URL (defaults to google.com)::
  147. >>> from pandas import _testing as tm
  148. >>> @tm.network
  149. ... def test_network():
  150. ... with pd.io.common.urlopen("rabbit://bonanza.com"):
  151. ... pass
  152. >>> test_network() # doctest: +SKIP
  153. Traceback
  154. ...
  155. URLError: <urlopen error unknown url type: rabbit>
  156. You can specify alternative URLs::
  157. >>> @tm.network("https://www.yahoo.com")
  158. ... def test_something_with_yahoo():
  159. ... raise OSError("Failure Message")
  160. >>> test_something_with_yahoo() # doctest: +SKIP
  161. Traceback (most recent call last):
  162. ...
  163. OSError: Failure Message
  164. If you set check_before_test, it will check the url first and not run the
  165. test on failure::
  166. >>> @tm.network("failing://url.blaher", check_before_test=True)
  167. ... def test_something():
  168. ... print("I ran!")
  169. ... raise ValueError("Failure")
  170. >>> test_something() # doctest: +SKIP
  171. Traceback (most recent call last):
  172. ...
  173. Errors not related to networking will always be raised.
  174. """
  175. import pytest
  176. if error_classes is None:
  177. error_classes = _get_default_network_errors()
  178. t.network = True
  179. @wraps(t)
  180. def wrapper(*args, **kwargs):
  181. if (
  182. check_before_test
  183. and not raise_on_error
  184. and not can_connect(url, error_classes)
  185. ):
  186. pytest.skip(
  187. f"May not have network connectivity because cannot connect to {url}"
  188. )
  189. try:
  190. return t(*args, **kwargs)
  191. except Exception as err:
  192. errno = getattr(err, "errno", None)
  193. if not errno and hasattr(errno, "reason"):
  194. # error: "Exception" has no attribute "reason"
  195. errno = getattr(err.reason, "errno", None) # type: ignore[attr-defined]
  196. if errno in skip_errnos:
  197. pytest.skip(f"Skipping test due to known errno and error {err}")
  198. e_str = str(err)
  199. if any(m.lower() in e_str.lower() for m in _skip_on_messages):
  200. pytest.skip(
  201. f"Skipping test because exception message is known and error {err}"
  202. )
  203. if not isinstance(err, error_classes) or raise_on_error:
  204. raise
  205. pytest.skip(f"Skipping test due to lack of connectivity and error {err}")
  206. return wrapper
  207. def can_connect(url, error_classes=None) -> bool:
  208. """
  209. Try to connect to the given url. True if succeeds, False if OSError
  210. raised
  211. Parameters
  212. ----------
  213. url : basestring
  214. The URL to try to connect to
  215. Returns
  216. -------
  217. connectable : bool
  218. Return True if no OSError (unable to connect) or URLError (bad url) was
  219. raised
  220. """
  221. if error_classes is None:
  222. error_classes = _get_default_network_errors()
  223. try:
  224. with urlopen(url, timeout=20) as response:
  225. # Timeout just in case rate-limiting is applied
  226. if response.status != 200:
  227. return False
  228. except error_classes:
  229. return False
  230. else:
  231. return True
  232. # ------------------------------------------------------------------
  233. # File-IO
  234. def round_trip_pickle(
  235. obj: Any, path: FilePath | ReadPickleBuffer | None = None
  236. ) -> DataFrame | Series:
  237. """
  238. Pickle an object and then read it again.
  239. Parameters
  240. ----------
  241. obj : any object
  242. The object to pickle and then re-read.
  243. path : str, path object or file-like object, default None
  244. The path where the pickled object is written and then read.
  245. Returns
  246. -------
  247. pandas object
  248. The original object that was pickled and then re-read.
  249. """
  250. _path = path
  251. if _path is None:
  252. _path = f"__{rands(10)}__.pickle"
  253. with ensure_clean(_path) as temp_path:
  254. pd.to_pickle(obj, temp_path)
  255. return pd.read_pickle(temp_path)
  256. def round_trip_pathlib(writer, reader, path: str | None = None):
  257. """
  258. Write an object to file specified by a pathlib.Path and read it back
  259. Parameters
  260. ----------
  261. writer : callable bound to pandas object
  262. IO writing function (e.g. DataFrame.to_csv )
  263. reader : callable
  264. IO reading function (e.g. pd.read_csv )
  265. path : str, default None
  266. The path where the object is written and then read.
  267. Returns
  268. -------
  269. pandas object
  270. The original object that was serialized and then re-read.
  271. """
  272. import pytest
  273. Path = pytest.importorskip("pathlib").Path
  274. if path is None:
  275. path = "___pathlib___"
  276. with ensure_clean(path) as path:
  277. writer(Path(path))
  278. obj = reader(Path(path))
  279. return obj
  280. def round_trip_localpath(writer, reader, path: str | None = None):
  281. """
  282. Write an object to file specified by a py.path LocalPath and read it back.
  283. Parameters
  284. ----------
  285. writer : callable bound to pandas object
  286. IO writing function (e.g. DataFrame.to_csv )
  287. reader : callable
  288. IO reading function (e.g. pd.read_csv )
  289. path : str, default None
  290. The path where the object is written and then read.
  291. Returns
  292. -------
  293. pandas object
  294. The original object that was serialized and then re-read.
  295. """
  296. import pytest
  297. LocalPath = pytest.importorskip("py.path").local
  298. if path is None:
  299. path = "___localpath___"
  300. with ensure_clean(path) as path:
  301. writer(LocalPath(path))
  302. obj = reader(LocalPath(path))
  303. return obj
  304. def write_to_compressed(compression, path, data, dest: str = "test"):
  305. """
  306. Write data to a compressed file.
  307. Parameters
  308. ----------
  309. compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'}
  310. The compression type to use.
  311. path : str
  312. The file path to write the data.
  313. data : str
  314. The data to write.
  315. dest : str, default "test"
  316. The destination file (for ZIP only)
  317. Raises
  318. ------
  319. ValueError : An invalid compression value was passed in.
  320. """
  321. args: tuple[Any, ...] = (data,)
  322. mode = "wb"
  323. method = "write"
  324. compress_method: Callable
  325. if compression == "zip":
  326. compress_method = zipfile.ZipFile
  327. mode = "w"
  328. args = (dest, data)
  329. method = "writestr"
  330. elif compression == "tar":
  331. compress_method = tarfile.TarFile
  332. mode = "w"
  333. file = tarfile.TarInfo(name=dest)
  334. bytes = io.BytesIO(data)
  335. file.size = len(data)
  336. args = (file, bytes)
  337. method = "addfile"
  338. elif compression == "gzip":
  339. compress_method = gzip.GzipFile
  340. elif compression == "bz2":
  341. compress_method = bz2.BZ2File
  342. elif compression == "zstd":
  343. compress_method = import_optional_dependency("zstandard").open
  344. elif compression == "xz":
  345. compress_method = get_lzma_file()
  346. else:
  347. raise ValueError(f"Unrecognized compression type: {compression}")
  348. with compress_method(path, mode=mode) as f:
  349. getattr(f, method)(*args)
  350. # ------------------------------------------------------------------
  351. # Plotting
  352. def close(fignum=None) -> None:
  353. from matplotlib.pyplot import (
  354. close as _close,
  355. get_fignums,
  356. )
  357. if fignum is None:
  358. for fignum in get_fignums():
  359. _close(fignum)
  360. else:
  361. _close(fignum)