asyncio.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. Asynchronous progressbar decorator for iterators.
  3. Includes a default `range` iterator printing to `stderr`.
  4. Usage:
  5. >>> from tqdm.asyncio import trange, tqdm
  6. >>> async for i in trange(10):
  7. ... ...
  8. """
  9. import asyncio
  10. from sys import version_info
  11. from .std import tqdm as std_tqdm
  12. __author__ = {"github.com/": ["casperdcl"]}
  13. __all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
  14. class tqdm_asyncio(std_tqdm):
  15. """
  16. Asynchronous-friendly version of tqdm.
  17. """
  18. def __init__(self, iterable=None, *args, **kwargs):
  19. super(tqdm_asyncio, self).__init__(iterable, *args, **kwargs)
  20. self.iterable_awaitable = False
  21. if iterable is not None:
  22. if hasattr(iterable, "__anext__"):
  23. self.iterable_next = iterable.__anext__
  24. self.iterable_awaitable = True
  25. elif hasattr(iterable, "__next__"):
  26. self.iterable_next = iterable.__next__
  27. else:
  28. self.iterable_iterator = iter(iterable)
  29. self.iterable_next = self.iterable_iterator.__next__
  30. def __aiter__(self):
  31. return self
  32. async def __anext__(self):
  33. try:
  34. if self.iterable_awaitable:
  35. res = await self.iterable_next()
  36. else:
  37. res = self.iterable_next()
  38. self.update()
  39. return res
  40. except StopIteration:
  41. self.close()
  42. raise StopAsyncIteration
  43. except BaseException:
  44. self.close()
  45. raise
  46. def send(self, *args, **kwargs):
  47. return self.iterable.send(*args, **kwargs)
  48. @classmethod
  49. def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
  50. """
  51. Wrapper for `asyncio.as_completed`.
  52. """
  53. if total is None:
  54. total = len(fs)
  55. kwargs = {}
  56. if version_info[:2] < (3, 10):
  57. kwargs['loop'] = loop
  58. yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
  59. total=total, **tqdm_kwargs)
  60. @classmethod
  61. async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
  62. """
  63. Wrapper for `asyncio.gather`.
  64. """
  65. async def wrap_awaitable(i, f):
  66. return i, await f
  67. ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
  68. res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
  69. total=total, **tqdm_kwargs)]
  70. return [i for _, i in sorted(res)]
  71. def tarange(*args, **kwargs):
  72. """
  73. A shortcut for `tqdm.asyncio.tqdm(range(*args), **kwargs)`.
  74. """
  75. return tqdm_asyncio(range(*args), **kwargs)
  76. # Aliases
  77. tqdm = tqdm_asyncio
  78. trange = tarange