123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435 |
- from __future__ import annotations
- import bz2
- from functools import wraps
- import gzip
- import io
- import socket
- import tarfile
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- )
- import zipfile
- from pandas._typing import (
- FilePath,
- ReadPickleBuffer,
- )
- from pandas.compat import get_lzma_file
- from pandas.compat._optional import import_optional_dependency
- import pandas as pd
- from pandas._testing._random import rands
- from pandas._testing.contexts import ensure_clean
- from pandas.io.common import urlopen
- if TYPE_CHECKING:
- from pandas import (
- DataFrame,
- Series,
- )
- # skip tests on exceptions with these messages
- _network_error_messages = (
- # 'urlopen error timed out',
- # 'timeout: timed out',
- # 'socket.timeout: timed out',
- "timed out",
- "Server Hangup",
- "HTTP Error 503: Service Unavailable",
- "502: Proxy Error",
- "HTTP Error 502: internal error",
- "HTTP Error 502",
- "HTTP Error 503",
- "HTTP Error 403",
- "HTTP Error 400",
- "Temporary failure in name resolution",
- "Name or service not known",
- "Connection refused",
- "certificate verify",
- )
- # or this e.errno/e.reason.errno
- _network_errno_vals = (
- 101, # Network is unreachable
- 111, # Connection refused
- 110, # Connection timed out
- 104, # Connection reset Error
- 54, # Connection reset by peer
- 60, # urllib.error.URLError: [Errno 60] Connection timed out
- )
- # Both of the above shouldn't mask real issues such as 404's
- # or refused connections (changed DNS).
- # But some tests (test_data yahoo) contact incredibly flakey
- # servers.
- # and conditionally raise on exception types in _get_default_network_errors
- def _get_default_network_errors():
- # Lazy import for http.client & urllib.error
- # because it imports many things from the stdlib
- import http.client
- import urllib.error
- return (
- OSError,
- http.client.HTTPException,
- TimeoutError,
- urllib.error.URLError,
- socket.timeout,
- )
def optional_args(decorator):
    """
    Let a decorator be used both bare and with optional arguments.

    A call with exactly one positional argument that is callable, and no
    keyword arguments, is taken to be the bare form, i.e. something like
    this::

        @my_decorator
        def function(): pass

    Any other call is treated as supplying arguments and returns a
    decorator that is then applied to the function.

    In both cases the wrapped decorator ends up being invoked as
    ``decorator(f, *args, **kwargs)``.
    """

    @wraps(decorator)
    def wrapper(*args, **kwargs):
        # Bare usage: the single positional argument is the function itself,
        # so there are no extra arguments to forward.
        if not kwargs and len(args) == 1 and callable(args[0]):
            return decorator(args[0])

        # Parametrised usage: remember the arguments and wait for the
        # function to decorate.
        def dec(f):
            return decorator(f, *args, **kwargs)

        return dec

    return wrapper
# error: Untyped decorator makes function "network" untyped
@optional_args  # type: ignore[misc]
def network(
    t,
    url: str = "https://www.google.com",
    raise_on_error: bool = False,
    check_before_test: bool = False,
    error_classes=None,
    skip_errnos=_network_errno_vals,
    _skip_on_messages=_network_error_messages,
):
    """
    Label a test as requiring network connection and, if an error is
    encountered, only raise if it does not find a network connection.

    Using this decorator adds a contract to your test: you must assert
    that, under normal conditions, your test will ONLY fail if it does not
    have network connectivity.

    You can call this in 3 ways: as a standard decorator, with keyword
    arguments, or with a positional argument that is the url to check.

    Parameters
    ----------
    t : callable
        The test requiring network connectivity.
    url : str
        The url to test via ``pandas.io.common.urlopen`` to check
        for connectivity. Defaults to 'https://www.google.com'.
    raise_on_error : bool
        If True, the connectivity pre-check is disabled and any exception
        that does not match ``skip_errnos`` or ``_skip_on_messages`` is
        re-raised, even if it is an instance of ``error_classes``.
    check_before_test : bool
        If True, checks connectivity before running the test case.
    error_classes : tuple or Exception
        error classes to ignore. If not in ``error_classes``, raises the error.
        Defaults to the classes returned by ``_get_default_network_errors()``.
        Be careful about changing the error classes here.
    skip_errnos : iterable of int
        Any exception that has .errno or .reason.errno set to one
        of these values will be skipped with an appropriate
        message.
    _skip_on_messages: iterable of string
        any exception e for which one of the strings is
        a (case-insensitive) substring of str(e) will be skipped with an
        appropriate message. Intended to suppress errors where an errno
        isn't available.

    Notes
    -----
    * ``raise_on_error`` supersedes ``check_before_test``
    * The decorated callable is tagged with ``t.network = True``.

    Returns
    -------
    t : callable
        The decorated test ``t``, with checks for connectivity errors.

    Example
    -------

    Tests decorated with @network will fail if it's possible to make a network
    connection to another URL (defaults to google.com)::

      >>> from pandas import _testing as tm
      >>> @tm.network
      ... def test_network():
      ...     with pd.io.common.urlopen("rabbit://bonanza.com"):
      ...         pass
      >>> test_network()  # doctest: +SKIP
      Traceback
         ...
      URLError: <urlopen error unknown url type: rabbit>

      You can specify alternative URLs::

        >>> @tm.network("https://www.yahoo.com")
        ... def test_something_with_yahoo():
        ...     raise OSError("Failure Message")
        >>> test_something_with_yahoo()  # doctest: +SKIP
        Traceback (most recent call last):
            ...
        OSError: Failure Message

    If you set check_before_test, it will check the url first and not run the
    test on failure::

        >>> @tm.network("failing://url.blaher", check_before_test=True)
        ... def test_something():
        ...     print("I ran!")
        ...     raise ValueError("Failure")
        >>> test_something()  # doctest: +SKIP
        Traceback (most recent call last):
            ...

    Errors not related to networking will always be raised.
    """
    import pytest

    if error_classes is None:
        error_classes = _get_default_network_errors()

    # Tag the test so it can be identified as network-dependent.
    t.network = True

    @wraps(t)
    def wrapper(*args, **kwargs):
        if (
            check_before_test
            and not raise_on_error
            and not can_connect(url, error_classes)
        ):
            pytest.skip(
                f"May not have network connectivity because cannot connect to {url}"
            )
        try:
            return t(*args, **kwargs)
        except Exception as err:
            errno = getattr(err, "errno", None)
            # Wrapper exceptions such as urllib.error.URLError carry the
            # underlying error on ``.reason``; fall back to its errno. The
            # attribute check must be made on the exception itself, not on
            # the (possibly None) errno value.
            if not errno and hasattr(err, "reason"):
                # error: "Exception" has no attribute "reason"
                errno = getattr(err.reason, "errno", None)  # type: ignore[attr-defined]

            if errno in skip_errnos:
                pytest.skip(f"Skipping test due to known errno and error {err}")

            # Lower-case the message once rather than once per candidate.
            e_str = str(err).lower()
            if any(m.lower() in e_str for m in _skip_on_messages):
                pytest.skip(
                    f"Skipping test because exception message is known and error {err}"
                )

            if not isinstance(err, error_classes) or raise_on_error:
                raise
            pytest.skip(f"Skipping test due to lack of connectivity and error {err}")

    return wrapper
def can_connect(url, error_classes=None) -> bool:
    """
    Report whether ``url`` can be opened and answers with HTTP status 200.

    Parameters
    ----------
    url : str
        The URL to try to connect to.
    error_classes : tuple or Exception, default None
        Exception classes that are interpreted as "cannot connect". When
        None, the classes from ``_get_default_network_errors()`` are used.

    Returns
    -------
    connectable : bool
        True if the URL was opened without raising one of
        ``error_classes`` and the response status was 200, else False.
    """
    if error_classes is None:
        error_classes = _get_default_network_errors()

    try:
        # Timeout just in case rate-limiting is applied
        with urlopen(url, timeout=20) as response:
            connectable = response.status == 200
    except error_classes:
        connectable = False
    return connectable
- # ------------------------------------------------------------------
- # File-IO
def round_trip_pickle(
    obj: Any, path: FilePath | ReadPickleBuffer | None = None
) -> DataFrame | Series:
    """
    Write an object with ``pd.to_pickle`` and load it back with
    ``pd.read_pickle``.

    Parameters
    ----------
    obj : any object
        The object to pickle and then re-read.
    path : str, path object or file-like object, default None
        The path where the pickled object is written and then read.
        When None, a name of the form ``__<10 random chars>__.pickle``
        is generated.

    Returns
    -------
    pandas object
        The object obtained by re-reading the pickled ``obj``.
    """
    target = f"__{rands(10)}__.pickle" if path is None else path
    with ensure_clean(target) as temp_path:
        pd.to_pickle(obj, temp_path)
        restored = pd.read_pickle(temp_path)
    return restored
def round_trip_pathlib(writer, reader, path: str | None = None):
    """
    Write to, then read from, a temporary file addressed by a pathlib.Path.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ); called with the Path.
    reader : callable
        IO reading function (e.g. pd.read_csv ); called with the Path.
    path : str, default None
        The path where the object is written and then read.
        Defaults to ``"___pathlib___"``.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the file produced by ``writer``.
    """
    import pytest

    pathlib_mod = pytest.importorskip("pathlib")

    target = "___pathlib___" if path is None else path
    with ensure_clean(target) as tmp_name:
        writer(pathlib_mod.Path(tmp_name))
        return reader(pathlib_mod.Path(tmp_name))
def round_trip_localpath(writer, reader, path: str | None = None):
    """
    Write to, then read from, a temporary file addressed by a py.path
    LocalPath.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ); called with the
        LocalPath.
    reader : callable
        IO reading function (e.g. pd.read_csv ); called with the LocalPath.
    path : str, default None
        The path where the object is written and then read.
        Defaults to ``"___localpath___"``.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the file produced by ``writer``.
    """
    import pytest

    py_path = pytest.importorskip("py.path")

    target = "___localpath___" if path is None else path
    with ensure_clean(target) as tmp_name:
        writer(py_path.local(tmp_name))
        return reader(py_path.local(tmp_name))
def write_to_compressed(compression, path, data, dest: str = "test"):
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'tar', 'xz', 'zstd'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : bytes
        The data to write. Archives other than zip are opened in binary
        mode (and tar wraps the data in ``io.BytesIO``), so bytes are
        required there.
    dest : str, default "test"
        The name of the member inside the archive (ZIP and TAR only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    # Defaults shared by the single-stream formats (gzip, bz2, zstd, xz):
    # open in binary-write mode and call ``write(data)``.
    open_mode = "wb"
    writer_name = "write"
    writer_args: tuple[Any, ...] = (data,)
    opener: Callable

    if compression == "zip":
        opener, open_mode, writer_name = zipfile.ZipFile, "w", "writestr"
        writer_args = (dest, data)
    elif compression == "tar":
        opener, open_mode, writer_name = tarfile.TarFile, "w", "addfile"
        # addfile needs a TarInfo header with the payload size plus a
        # file-like object to read the payload from.
        member = tarfile.TarInfo(name=dest)
        payload = io.BytesIO(data)
        member.size = len(data)
        writer_args = (member, payload)
    elif compression == "gzip":
        opener = gzip.GzipFile
    elif compression == "bz2":
        opener = bz2.BZ2File
    elif compression == "zstd":
        opener = import_optional_dependency("zstandard").open
    elif compression == "xz":
        opener = get_lzma_file()
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    with opener(path, mode=open_mode) as handle:
        write = getattr(handle, writer_name)
        write(*writer_args)
- # ------------------------------------------------------------------
- # Plotting
def close(fignum=None) -> None:
    """
    Close matplotlib figures via ``matplotlib.pyplot.close``.

    Parameters
    ----------
    fignum : optional
        Passed to ``matplotlib.pyplot.close``. When None, every figure
        number reported by ``matplotlib.pyplot.get_fignums()`` is closed
        in turn.
    """
    # Imported lazily so matplotlib is only required when plotting helpers
    # are actually used.
    from matplotlib import pyplot as plt

    targets = plt.get_fignums() if fignum is None else [fignum]
    for target in targets:
        plt.close(target)
|