123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398 |
- """
- General helpers required for `tqdm.std`.
- """
- import os
- import re
- import sys
- from functools import partial, partialmethod, wraps
- from inspect import signature
- # TODO consider using wcswidth third-party package for 0-width characters
- from unicodedata import east_asian_width
- from warnings import warn
- from weakref import proxy
# Aliases for builtins (presumably kept for backwards compatibility).
_range = range
_unich = chr
_unicode = str
_basestring = str

# Platform flags derived from the `sys.platform` prefix.
CUR_OS = sys.platform
IS_WIN = CUR_OS.startswith(('win32', 'cygwin'))
IS_NIX = CUR_OS.startswith(('aix', 'linux', 'darwin'))

# Matches ANSI CSI escape sequences (`ESC [ <digits/semicolons> <letter>`).
RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")

# `colorama` is only imported on Windows platforms; in every other case
# (including a failed import) the name is bound to `None`.
colorama = None
if IS_WIN:
    try:
        import colorama
    except ImportError:
        colorama = None
    else:
        try:
            colorama.init(strip=False)
        except TypeError:
            # `init()` rejected the `strip` keyword: retry with defaults
            colorama.init()
def envwrap(prefix, types=None, is_method=False):
    """
    Override parameter defaults via `os.environ[prefix + param_name]`.
    Maps UPPER_CASE env vars to lower_case param names.
    camelCase isn't supported (because Windows ignores case).

    Precedence (highest first):

    - call (`foo(a=3)`)
    - environ (`FOO_A=2`)
    - signature (`def foo(a=1)`)

    The environment is read once, when `envwrap(...)` itself is called.

    Parameters
    ----------
    prefix  : str
        Env var prefix, e.g. "FOO_"
    types  : dict, optional
        Fallback mappings `{'param_name': type, ...}` if types cannot be
        inferred from function signature (i.e. the parameter has no
        annotation and its default is either missing or `None`).
        Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
    is_method  : bool, optional
        Whether to use `functools.partialmethod`.
        If False (default), use `functools.partial`.

    Returns
    -------
    callable
        Decorator which returns `partial(func, **overrides)`
        (or `partialmethod(func, **overrides)`).

    Examples
    --------
    ```
    $ cat foo.py
    from tqdm.utils import envwrap
    @envwrap("FOO_")
    def test(a=1, b=2, c=3):
        print(f"received: a={a}, b={b}, c={c}")

    $ FOO_A=42 FOO_C=1337 python -c 'import foo; foo.test(c=99)'
    received: a=42, b=2, c=99
    ```
    """
    if types is None:
        types = {}
    i = len(prefix)
    env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
    part = partialmethod if is_method else partial

    def wrap(func):
        params = signature(func).parameters
        # ignore unknown env vars
        overrides = {k: v for k, v in env_overrides.items() if k in params}
        # infer overrides' `type`s
        for k in overrides:
            param = params[k]
            if param.annotation is not param.empty:  # typehints
                # try each member of e.g. a `Union`; keep the first that converts
                for typ in getattr(param.annotation, '__args__', (param.annotation,)):
                    try:
                        overrides[k] = typ(overrides[k])
                    except Exception:
                        pass
                    else:
                        break
            elif param.default is not None and param.default is not param.empty:
                # type of default value.
                # A missing default is the `param.empty` sentinel class:
                # `type(param.empty)(value)` would be `type(value)`, i.e. it
                # would replace the value with the class `str`.
                overrides[k] = type(param.default)(overrides[k])
            else:
                try:  # `types` fallback
                    overrides[k] = types[k](overrides[k])
                except KeyError:  # keep unconverted (`str`)
                    pass
        return part(func, **overrides)
    return wrap
class FormatReplace:
    """
    Object whose formatted representation is always `replace`, whatever
    format spec is requested. Each formatting call is counted in
    `format_called`.

    >>> a = FormatReplace('something')
    >>> "{:5d}".format(a)
    'something'
    """  # NOQA: P102
    def __init__(self, replace=''):
        self.format_called = 0
        self.replace = replace

    def __format__(self, _):
        # the format spec is deliberately ignored
        self.format_called = self.format_called + 1
        return self.replace
class Comparable:
    """
    Comparison mixin: every operator is derived from the child's
    `self._comparable` attr/@property.
    """
    def __lt__(self, other):
        mine = self._comparable
        theirs = other._comparable
        return mine < theirs

    def __le__(self, other):
        below = self < other
        return below if below else self == other

    def __eq__(self, other):
        mine = self._comparable
        theirs = other._comparable
        return mine == theirs

    def __ne__(self, other):
        return not (self == other)

    def __gt__(self, other):
        return not (self <= other)

    def __ge__(self, other):
        return not (self < other)
class ObjectWrapper(object):
    """
    Transparent proxy: attribute reads that miss on the wrapper, and all
    plain attribute writes, are forwarded to the wrapped object.
    Use `wrapper_getattr`/`wrapper_setattr` to access attributes stored on
    the wrapper itself.
    """
    def __getattr__(self, name):
        # Only called when normal lookup fails. If `_wrapped` itself is
        # missing (e.g. instance created without `__init__`), forwarding
        # would look up `self._wrapped` again and recurse without end.
        if name == '_wrapped':
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    def __setattr__(self, name, value):
        return setattr(self._wrapped, name, value)

    def wrapper_getattr(self, name):
        """Actual `self.getattr` rather than self._wrapped.getattr"""
        try:
            # `object` defines `__getattribute__` (not `__getattr__`)
            return object.__getattribute__(self, name)
        except AttributeError:
            # not stored on the wrapper: fall back to regular lookup,
            # which ends up forwarding to the wrapped object
            return getattr(self, name)

    def wrapper_setattr(self, name, value):
        """Actual `self.setattr` rather than self._wrapped.setattr"""
        return object.__setattr__(self, name, value)

    def __init__(self, wrapped):
        """
        Thin wrapper around a given object
        """
        self.wrapper_setattr('_wrapped', wrapped)
class SimpleTextIOWrapper(ObjectWrapper):
    """
    Wrapper which only alters `.write()`: the given value is encoded with
    the stored `encoding` before being handed to the wrapped object's
    `.write()` method.
    """
    # pylint: disable=too-few-public-methods
    def __init__(self, wrapped, encoding):
        super().__init__(wrapped)
        # stored on the wrapper itself, not on the wrapped object
        self.wrapper_setattr('encoding', encoding)

    def write(self, s):
        """
        Encode `s` and pass to the wrapped object's `.write()` method.
        """
        sink = self._wrapped.write
        encoded = s.encode(self.wrapper_getattr('encoding'))
        return sink(encoded)

    def __eq__(self, other):
        # compare the underlying objects, unwrapping `other` if needed
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
class DisableOnWriteError(ObjectWrapper):
    """
    Disable the given `tqdm_instance` upon `write()` or `flush()` errors.
    """
    @staticmethod
    def disable_on_exception(tqdm_instance, func):
        """
        Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`
        (or a `ValueError` whose message mentions 'closed').
        Any other `OSError`/`ValueError` is re-raised.
        """
        # weak reference: the returned closure does not keep the bar alive
        tqdm_instance = proxy(tqdm_instance)

        def silence():
            try:
                tqdm_instance.miniters = float('inf')
            except ReferenceError:
                # the instance has already been garbage collected
                pass

        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except OSError as e:
                if e.errno != 5:
                    raise
                silence()
            except ValueError as e:
                if 'closed' not in str(e):
                    raise
                silence()
        return inner

    def __init__(self, wrapped, tqdm_instance):
        super().__init__(wrapped)
        # shadow whichever of `write`/`flush` the wrapped object provides
        for method in ('write', 'flush'):
            if hasattr(wrapped, method):
                guarded = self.disable_on_exception(tqdm_instance, getattr(wrapped, method))
                self.wrapper_setattr(method, guarded)

    def __eq__(self, other):
        # compare the underlying objects, unwrapping `other` if needed
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
class CallbackIOWrapper(ObjectWrapper):
    def __init__(self, callback, stream, method="read"):
        """
        Wrap a given `file`-like object's `read()` or `write()` to report
        lengths to the given `callback`

        Raises `KeyError` if `method` is neither "read" nor "write".
        """
        super().__init__(stream)
        func = getattr(stream, method)
        if method == "read":
            @wraps(func)
            def reporting_read(*args, **kwargs):
                chunk = func(*args, **kwargs)
                # report after the read so the actual size is known
                callback(len(chunk))
                return chunk
            self.wrapper_setattr('read', reporting_read)
        elif method == "write":
            @wraps(func)
            def reporting_write(data, *args, **kwargs):
                outcome = func(data, *args, **kwargs)
                callback(len(data))
                return outcome
            self.wrapper_setattr('write', reporting_write)
        else:
            raise KeyError("Can only wrap read/write methods")
def _is_utf(encoding):
    """
    Return whether `encoding` can encode the block characters
    U+2588/U+2589. If the probe fails for a reason other than
    `UnicodeEncodeError`, fall back to checking the encoding's name.
    """
    probe = u'\u2588\u2589'
    try:
        probe.encode(encoding)
        return True
    except UnicodeEncodeError:
        return False
    except Exception:
        # e.g. unknown codec or non-string `encoding`: judge by name below
        pass
    try:
        return encoding.lower().startswith('utf-') or ('U8' == encoding)
    except Exception:
        return False
def _supports_unicode(fp):
    """
    Return `_is_utf(fp.encoding)`, or False if an `AttributeError` is
    raised (e.g. `fp` has no `encoding` attribute).
    """
    try:
        verdict = _is_utf(fp.encoding)
    except AttributeError:
        verdict = False
    return verdict
def _is_ascii(s):
    """
    For a `str`: True if no character has a code point above 255.
    For anything else: the result of `_supports_unicode(s)`.
    """
    if not isinstance(s, str):
        return _supports_unicode(s)
    return all(ord(ch) <= 255 for ch in s)
def _screen_shape_wrapper():  # pragma: no cover
    """
    Return a function which returns console dimensions (width, height).
    Supported: linux, osx, windows, cygwin.

    Returns None when the platform matches neither `IS_WIN` nor `IS_NIX`.
    """
    shape_func = None
    if IS_WIN:
        # prefer the native console query; `tput` is only a fallback
        # should the former be unavailable (`None`)
        shape_func = _screen_shape_windows or _screen_shape_tput
    if IS_NIX:
        shape_func = _screen_shape_linux
    return shape_func
def _screen_shape_windows(fp):  # pragma: no cover
    """
    Query the console screen buffer via `kernel32` for `fp`
    (stdin, stdout, or otherwise assumed stderr).

    Returns `(right - left, bottom - top)` of the window rectangle, or
    `(None, None)` if anything goes wrong.
    """
    try:
        import struct
        from ctypes import create_string_buffer, windll
        from sys import stdin, stdout

        # standard handle ids: -10 stdin, -11 stdout, -12 stderr (assumed)
        io_handle = -10 if fp == stdin else (-11 if fp == stdout else -12)

        handle = windll.kernel32.GetStdHandle(io_handle)
        csbi = create_string_buffer(22)
        if windll.kernel32.GetConsoleScreenBufferInfo(handle, csbi):
            fields = struct.unpack("hhhhHhhhhhh", csbi.raw)
            # fields 5..8 hold the window rectangle
            left, top, right, bottom = fields[5:9]
            return right - left, bottom - top  # +1
    except Exception:  # nosec
        pass
    return None, None
def _screen_shape_tput(*_):  # pragma: no cover
    """
    cygwin xterm (windows)

    Ask `tput` for the terminal size. Positional arguments are ignored.

    Returns
    -------
    `[cols - 1, lines - 1]` on success, `(None, None)` if `tput` is
    missing, fails, or prints something which is not an integer.
    """
    try:
        # `check_output` returns what `tput` printed; `check_call` would
        # only return the exit status (always 0 on success)
        from subprocess import check_output  # nosec
        return [int(check_output(['tput', i])) - 1
                for i in ('cols', 'lines')]
    except Exception:  # nosec
        pass
    return None, None
def _screen_shape_linux(fp):  # pragma: no cover
    """
    Return `(cols, rows)` for `fp` using the `TIOCGWINSZ` ioctl.

    If the ioctl fails, fall back to the `COLUMNS`/`LINES` env vars
    (each minus one). Returns `(None, None)` when the required modules
    cannot be imported or the env vars are missing/not integers.
    """
    try:
        from array import array
        from fcntl import ioctl
        from termios import TIOCGWINSZ
    except ImportError:
        return None, None

    try:
        winsize = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))
        rows, cols = winsize[:2]
        return cols, rows
    except Exception:
        pass

    try:
        return [int(os.environ[name]) - 1 for name in ("COLUMNS", "LINES")]
    except (KeyError, ValueError):
        return None, None
def _environ_cols_wrapper():  # pragma: no cover
    """
    Return a function which returns console width.
    Supported: linux, osx, windows, cygwin.

    Deprecated: emits a `DeprecationWarning` pointing to
    `_screen_shape_wrapper()`. Returns None if no shape function exists.
    """
    message = ("Use `_screen_shape_wrapper()(file)[0]` instead of"
               " `_environ_cols_wrapper()(file)`")
    warn(message, DeprecationWarning, stacklevel=2)
    shape = _screen_shape_wrapper()
    if shape:
        @wraps(shape)
        def width_only(fp):
            # first element of the (width, height) pair
            return shape(fp)[0]
        return width_only
    return None
def _term_move_up():  # pragma: no cover
    """
    Return the ANSI "cursor up" sequence, or an empty string when
    `os.name` is 'nt' and `colorama` is None.
    """
    if os.name != 'nt' or colorama is not None:
        return '\x1b[A'
    return ''
def _text_width(s):
    """
    Width of `str(s)` in terminal cells: characters whose East Asian
    width is 'F' or 'W' count as two cells, all others as one.
    """
    text = str(s)
    # one cell per character, plus one extra for each wide character
    extra_cells = sum(1 for ch in text if east_asian_width(ch) in 'FW')
    return len(text) + extra_cells
def disp_len(data):
    """
    Returns the real on-screen length of a string which may contain
    ANSI control codes and wide chars.
    """
    # escape sequences occupy no screen cells: drop them before measuring
    visible = RE_ANSI.sub('', data)
    return _text_width(visible)
def disp_trim(data, length):
    """
    Trim a string which may contain ANSI control characters.

    Parameters
    ----------
    data  : str
        Text to trim; may contain ANSI escape sequences and wide chars.
    length  : int
        Maximum on-screen width (as measured by `disp_len`) of the result.

    Returns
    -------
    str
        `data` shortened from the right. If the input contained ANSI
        sequences and the result still does, the result ends with the
        ANSI reset sequence `"\\033[0m"`.
    """
    if len(data) == disp_len(data):
        # no ANSI codes or wide chars: a plain slice is enough
        return data[:length]

    ansi_present = bool(RE_ANSI.search(data))
    # Carefully delete one char at a time.
    # The `data` guard stops the loop once nothing is left: with a negative
    # `length`, `disp_len('') > length` stays true and `''[:-1]` is still
    # `''`, so the loop would otherwise never terminate.
    while data and disp_len(data) > length:
        data = data[:-1]
    if ansi_present and RE_ANSI.search(data):
        # assume ANSI reset is required
        return data if data.endswith("\033[0m") else data + "\033[0m"
    return data