__init__.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. Thin wrappers around common functions.
  3. Subpackages contain potentially unstable extensions.
  4. """
  5. from warnings import warn
  6. from ..auto import tqdm as tqdm_auto
  7. from ..std import TqdmDeprecationWarning, tqdm
  8. from ..utils import ObjectWrapper
  9. __author__ = {"github.com/": ["casperdcl"]}
  10. __all__ = ['tenumerate', 'tzip', 'tmap']
  11. class DummyTqdmFile(ObjectWrapper):
  12. """Dummy file-like that will write to tqdm"""
  13. def __init__(self, wrapped):
  14. super(DummyTqdmFile, self).__init__(wrapped)
  15. self._buf = []
  16. def write(self, x, nolock=False):
  17. nl = b"\n" if isinstance(x, bytes) else "\n"
  18. pre, sep, post = x.rpartition(nl)
  19. if sep:
  20. blank = type(nl)()
  21. tqdm.write(blank.join(self._buf + [pre, sep]),
  22. end=blank, file=self._wrapped, nolock=nolock)
  23. self._buf = [post]
  24. else:
  25. self._buf.append(x)
  26. def __del__(self):
  27. if self._buf:
  28. blank = type(self._buf[0])()
  29. try:
  30. tqdm.write(blank.join(self._buf), end=blank, file=self._wrapped)
  31. except (OSError, ValueError):
  32. pass
  33. def builtin_iterable(func):
  34. """Returns `func`"""
  35. warn("This function has no effect, and will be removed in tqdm==5.0.0",
  36. TqdmDeprecationWarning, stacklevel=2)
  37. return func
  38. def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto, **tqdm_kwargs):
  39. """
  40. Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
  41. Parameters
  42. ----------
  43. tqdm_class : [default: tqdm.auto.tqdm].
  44. """
  45. try:
  46. import numpy as np
  47. except ImportError:
  48. pass
  49. else:
  50. if isinstance(iterable, np.ndarray):
  51. return tqdm_class(np.ndenumerate(iterable), total=total or iterable.size,
  52. **tqdm_kwargs)
  53. return enumerate(tqdm_class(iterable, total=total, **tqdm_kwargs), start)
  54. def tzip(iter1, *iter2plus, **tqdm_kwargs):
  55. """
  56. Equivalent of builtin `zip`.
  57. Parameters
  58. ----------
  59. tqdm_class : [default: tqdm.auto.tqdm].
  60. """
  61. kwargs = tqdm_kwargs.copy()
  62. tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
  63. for i in zip(tqdm_class(iter1, **kwargs), *iter2plus):
  64. yield i
  65. def tmap(function, *sequences, **tqdm_kwargs):
  66. """
  67. Equivalent of builtin `map`.
  68. Parameters
  69. ----------
  70. tqdm_class : [default: tqdm.auto.tqdm].
  71. """
  72. for i in tzip(*sequences, **tqdm_kwargs):
  73. yield function(*i)