utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. """
  2. General helpers required for `tqdm.std`.
  3. """
  4. import os
  5. import re
  6. import sys
  7. from functools import partial, partialmethod, wraps
  8. from inspect import signature
  9. # TODO consider using wcswidth third-party package for 0-width characters
  10. from unicodedata import east_asian_width
  11. from warnings import warn
  12. from weakref import proxy
  13. _range, _unich, _unicode, _basestring = range, chr, str, str
  14. CUR_OS = sys.platform
  15. IS_WIN = any(CUR_OS.startswith(i) for i in ['win32', 'cygwin'])
  16. IS_NIX = any(CUR_OS.startswith(i) for i in ['aix', 'linux', 'darwin'])
  17. RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")
  18. try:
  19. if IS_WIN:
  20. import colorama
  21. else:
  22. raise ImportError
  23. except ImportError:
  24. colorama = None
  25. else:
  26. try:
  27. colorama.init(strip=False)
  28. except TypeError:
  29. colorama.init()
  30. def envwrap(prefix, types=None, is_method=False):
  31. """
  32. Override parameter defaults via `os.environ[prefix + param_name]`.
  33. Maps UPPER_CASE env vars map to lower_case param names.
  34. camelCase isn't supported (because Windows ignores case).
  35. Precedence (highest first):
  36. - call (`foo(a=3)`)
  37. - environ (`FOO_A=2`)
  38. - signature (`def foo(a=1)`)
  39. Parameters
  40. ----------
  41. prefix : str
  42. Env var prefix, e.g. "FOO_"
  43. types : dict, optional
  44. Fallback mappings `{'param_name': type, ...}` if types cannot be
  45. inferred from function signature.
  46. Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
  47. is_method : bool, optional
  48. Whether to use `functools.partialmethod`. If (default: False) use `functools.partial`.
  49. Examples
  50. --------
  51. ```
  52. $ cat foo.py
  53. from tqdm.utils import envwrap
  54. @envwrap("FOO_")
  55. def test(a=1, b=2, c=3):
  56. print(f"received: a={a}, b={b}, c={c}")
  57. $ FOO_A=42 FOO_C=1337 python -c 'import foo; foo.test(c=99)'
  58. received: a=42, b=2, c=99
  59. ```
  60. """
  61. if types is None:
  62. types = {}
  63. i = len(prefix)
  64. env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
  65. part = partialmethod if is_method else partial
  66. def wrap(func):
  67. params = signature(func).parameters
  68. # ignore unknown env vars
  69. overrides = {k: v for k, v in env_overrides.items() if k in params}
  70. # infer overrides' `type`s
  71. for k in overrides:
  72. param = params[k]
  73. if param.annotation is not param.empty: # typehints
  74. for typ in getattr(param.annotation, '__args__', (param.annotation,)):
  75. try:
  76. overrides[k] = typ(overrides[k])
  77. except Exception:
  78. pass
  79. else:
  80. break
  81. elif param.default is not None: # type of default value
  82. overrides[k] = type(param.default)(overrides[k])
  83. else:
  84. try: # `types` fallback
  85. overrides[k] = types[k](overrides[k])
  86. except KeyError: # keep unconverted (`str`)
  87. pass
  88. return part(func, **overrides)
  89. return wrap
  90. class FormatReplace(object):
  91. """
  92. >>> a = FormatReplace('something')
  93. >>> "{:5d}".format(a)
  94. 'something'
  95. """ # NOQA: P102
  96. def __init__(self, replace=''):
  97. self.replace = replace
  98. self.format_called = 0
  99. def __format__(self, _):
  100. self.format_called += 1
  101. return self.replace
  102. class Comparable(object):
  103. """Assumes child has self._comparable attr/@property"""
  104. def __lt__(self, other):
  105. return self._comparable < other._comparable
  106. def __le__(self, other):
  107. return (self < other) or (self == other)
  108. def __eq__(self, other):
  109. return self._comparable == other._comparable
  110. def __ne__(self, other):
  111. return not self == other
  112. def __gt__(self, other):
  113. return not self <= other
  114. def __ge__(self, other):
  115. return not self < other
  116. class ObjectWrapper(object):
  117. def __getattr__(self, name):
  118. return getattr(self._wrapped, name)
  119. def __setattr__(self, name, value):
  120. return setattr(self._wrapped, name, value)
  121. def wrapper_getattr(self, name):
  122. """Actual `self.getattr` rather than self._wrapped.getattr"""
  123. try:
  124. return object.__getattr__(self, name)
  125. except AttributeError: # py2
  126. return getattr(self, name)
  127. def wrapper_setattr(self, name, value):
  128. """Actual `self.setattr` rather than self._wrapped.setattr"""
  129. return object.__setattr__(self, name, value)
  130. def __init__(self, wrapped):
  131. """
  132. Thin wrapper around a given object
  133. """
  134. self.wrapper_setattr('_wrapped', wrapped)
  135. class SimpleTextIOWrapper(ObjectWrapper):
  136. """
  137. Change only `.write()` of the wrapped object by encoding the passed
  138. value and passing the result to the wrapped object's `.write()` method.
  139. """
  140. # pylint: disable=too-few-public-methods
  141. def __init__(self, wrapped, encoding):
  142. super(SimpleTextIOWrapper, self).__init__(wrapped)
  143. self.wrapper_setattr('encoding', encoding)
  144. def write(self, s):
  145. """
  146. Encode `s` and pass to the wrapped object's `.write()` method.
  147. """
  148. return self._wrapped.write(s.encode(self.wrapper_getattr('encoding')))
  149. def __eq__(self, other):
  150. return self._wrapped == getattr(other, '_wrapped', other)
  151. class DisableOnWriteError(ObjectWrapper):
  152. """
  153. Disable the given `tqdm_instance` upon `write()` or `flush()` errors.
  154. """
  155. @staticmethod
  156. def disable_on_exception(tqdm_instance, func):
  157. """
  158. Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`.
  159. """
  160. tqdm_instance = proxy(tqdm_instance)
  161. def inner(*args, **kwargs):
  162. try:
  163. return func(*args, **kwargs)
  164. except OSError as e:
  165. if e.errno != 5:
  166. raise
  167. try:
  168. tqdm_instance.miniters = float('inf')
  169. except ReferenceError:
  170. pass
  171. except ValueError as e:
  172. if 'closed' not in str(e):
  173. raise
  174. try:
  175. tqdm_instance.miniters = float('inf')
  176. except ReferenceError:
  177. pass
  178. return inner
  179. def __init__(self, wrapped, tqdm_instance):
  180. super(DisableOnWriteError, self).__init__(wrapped)
  181. if hasattr(wrapped, 'write'):
  182. self.wrapper_setattr(
  183. 'write', self.disable_on_exception(tqdm_instance, wrapped.write))
  184. if hasattr(wrapped, 'flush'):
  185. self.wrapper_setattr(
  186. 'flush', self.disable_on_exception(tqdm_instance, wrapped.flush))
  187. def __eq__(self, other):
  188. return self._wrapped == getattr(other, '_wrapped', other)
  189. class CallbackIOWrapper(ObjectWrapper):
  190. def __init__(self, callback, stream, method="read"):
  191. """
  192. Wrap a given `file`-like object's `read()` or `write()` to report
  193. lengths to the given `callback`
  194. """
  195. super(CallbackIOWrapper, self).__init__(stream)
  196. func = getattr(stream, method)
  197. if method == "write":
  198. @wraps(func)
  199. def write(data, *args, **kwargs):
  200. res = func(data, *args, **kwargs)
  201. callback(len(data))
  202. return res
  203. self.wrapper_setattr('write', write)
  204. elif method == "read":
  205. @wraps(func)
  206. def read(*args, **kwargs):
  207. data = func(*args, **kwargs)
  208. callback(len(data))
  209. return data
  210. self.wrapper_setattr('read', read)
  211. else:
  212. raise KeyError("Can only wrap read/write methods")
  213. def _is_utf(encoding):
  214. try:
  215. u'\u2588\u2589'.encode(encoding)
  216. except UnicodeEncodeError:
  217. return False
  218. except Exception:
  219. try:
  220. return encoding.lower().startswith('utf-') or ('U8' == encoding)
  221. except Exception:
  222. return False
  223. else:
  224. return True
  225. def _supports_unicode(fp):
  226. try:
  227. return _is_utf(fp.encoding)
  228. except AttributeError:
  229. return False
  230. def _is_ascii(s):
  231. if isinstance(s, str):
  232. for c in s:
  233. if ord(c) > 255:
  234. return False
  235. return True
  236. return _supports_unicode(s)
  237. def _screen_shape_wrapper(): # pragma: no cover
  238. """
  239. Return a function which returns console dimensions (width, height).
  240. Supported: linux, osx, windows, cygwin.
  241. """
  242. _screen_shape = None
  243. if IS_WIN:
  244. _screen_shape = _screen_shape_windows
  245. if _screen_shape is None:
  246. _screen_shape = _screen_shape_tput
  247. if IS_NIX:
  248. _screen_shape = _screen_shape_linux
  249. return _screen_shape
  250. def _screen_shape_windows(fp): # pragma: no cover
  251. try:
  252. import struct
  253. from ctypes import create_string_buffer, windll
  254. from sys import stdin, stdout
  255. io_handle = -12 # assume stderr
  256. if fp == stdin:
  257. io_handle = -10
  258. elif fp == stdout:
  259. io_handle = -11
  260. h = windll.kernel32.GetStdHandle(io_handle)
  261. csbi = create_string_buffer(22)
  262. res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
  263. if res:
  264. (_bufx, _bufy, _curx, _cury, _wattr, left, top, right, bottom,
  265. _maxx, _maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
  266. return right - left, bottom - top # +1
  267. except Exception: # nosec
  268. pass
  269. return None, None
  270. def _screen_shape_tput(*_): # pragma: no cover
  271. """cygwin xterm (windows)"""
  272. try:
  273. import shlex
  274. from subprocess import check_call # nosec
  275. return [int(check_call(shlex.split('tput ' + i))) - 1
  276. for i in ('cols', 'lines')]
  277. except Exception: # nosec
  278. pass
  279. return None, None
  280. def _screen_shape_linux(fp): # pragma: no cover
  281. try:
  282. from array import array
  283. from fcntl import ioctl
  284. from termios import TIOCGWINSZ
  285. except ImportError:
  286. return None, None
  287. else:
  288. try:
  289. rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
  290. return cols, rows
  291. except Exception:
  292. try:
  293. return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
  294. except (KeyError, ValueError):
  295. return None, None
  296. def _environ_cols_wrapper(): # pragma: no cover
  297. """
  298. Return a function which returns console width.
  299. Supported: linux, osx, windows, cygwin.
  300. """
  301. warn("Use `_screen_shape_wrapper()(file)[0]` instead of"
  302. " `_environ_cols_wrapper()(file)`", DeprecationWarning, stacklevel=2)
  303. shape = _screen_shape_wrapper()
  304. if not shape:
  305. return None
  306. @wraps(shape)
  307. def inner(fp):
  308. return shape(fp)[0]
  309. return inner
  310. def _term_move_up(): # pragma: no cover
  311. return '' if (os.name == 'nt') and (colorama is None) else '\x1b[A'
  312. def _text_width(s):
  313. return sum(2 if east_asian_width(ch) in 'FW' else 1 for ch in str(s))
  314. def disp_len(data):
  315. """
  316. Returns the real on-screen length of a string which may contain
  317. ANSI control codes and wide chars.
  318. """
  319. return _text_width(RE_ANSI.sub('', data))
  320. def disp_trim(data, length):
  321. """
  322. Trim a string which may contain ANSI control characters.
  323. """
  324. if len(data) == disp_len(data):
  325. return data[:length]
  326. ansi_present = bool(RE_ANSI.search(data))
  327. while disp_len(data) > length: # carefully delete one char at a time
  328. data = data[:-1]
  329. if ansi_present and bool(RE_ANSI.search(data)):
  330. # assume ANSI reset is required
  331. return data if data.endswith("\033[0m") else data + "\033[0m"
  332. return data