concurrent.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. """
  2. Thin wrappers around `concurrent.futures`.
  3. """
  4. from contextlib import contextmanager
  5. from operator import length_hint
  6. from os import cpu_count
  7. from ..auto import tqdm as tqdm_auto
  8. from ..std import TqdmWarning
  9. __author__ = {"github.com/": ["casperdcl"]}
  10. __all__ = ['thread_map', 'process_map']
  11. @contextmanager
  12. def ensure_lock(tqdm_class, lock_name=""):
  13. """get (create if necessary) and then restore `tqdm_class`'s lock"""
  14. old_lock = getattr(tqdm_class, '_lock', None) # don't create a new lock
  15. lock = old_lock or tqdm_class.get_lock() # maybe create a new lock
  16. lock = getattr(lock, lock_name, lock) # maybe subtype
  17. tqdm_class.set_lock(lock)
  18. yield lock
  19. if old_lock is None:
  20. del tqdm_class._lock
  21. else:
  22. tqdm_class.set_lock(old_lock)
  23. def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
  24. """
  25. Implementation of `thread_map` and `process_map`.
  26. Parameters
  27. ----------
  28. tqdm_class : [default: tqdm.auto.tqdm].
  29. max_workers : [default: min(32, cpu_count() + 4)].
  30. chunksize : [default: 1].
  31. lock_name : [default: "":str].
  32. """
  33. kwargs = tqdm_kwargs.copy()
  34. if "total" not in kwargs:
  35. kwargs["total"] = length_hint(iterables[0])
  36. tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
  37. max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
  38. chunksize = kwargs.pop("chunksize", 1)
  39. lock_name = kwargs.pop("lock_name", "")
  40. with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
  41. # share lock in case workers are already using `tqdm`
  42. with PoolExecutor(max_workers=max_workers, initializer=tqdm_class.set_lock,
  43. initargs=(lk,)) as ex:
  44. return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
  45. def thread_map(fn, *iterables, **tqdm_kwargs):
  46. """
  47. Equivalent of `list(map(fn, *iterables))`
  48. driven by `concurrent.futures.ThreadPoolExecutor`.
  49. Parameters
  50. ----------
  51. tqdm_class : optional
  52. `tqdm` class to use for bars [default: tqdm.auto.tqdm].
  53. max_workers : int, optional
  54. Maximum number of workers to spawn; passed to
  55. `concurrent.futures.ThreadPoolExecutor.__init__`.
  56. [default: max(32, cpu_count() + 4)].
  57. """
  58. from concurrent.futures import ThreadPoolExecutor
  59. return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
  60. def process_map(fn, *iterables, **tqdm_kwargs):
  61. """
  62. Equivalent of `list(map(fn, *iterables))`
  63. driven by `concurrent.futures.ProcessPoolExecutor`.
  64. Parameters
  65. ----------
  66. tqdm_class : optional
  67. `tqdm` class to use for bars [default: tqdm.auto.tqdm].
  68. max_workers : int, optional
  69. Maximum number of workers to spawn; passed to
  70. `concurrent.futures.ProcessPoolExecutor.__init__`.
  71. [default: min(32, cpu_count() + 4)].
  72. chunksize : int, optional
  73. Size of chunks sent to worker processes; passed to
  74. `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
  75. lock_name : str, optional
  76. Member of `tqdm_class.get_lock()` to use [default: mp_lock].
  77. """
  78. from concurrent.futures import ProcessPoolExecutor
  79. if iterables and "chunksize" not in tqdm_kwargs:
  80. # default `chunksize=1` has poor performance for large iterables
  81. # (most time spent dispatching items to workers).
  82. longest_iterable_len = max(map(length_hint, iterables))
  83. if longest_iterable_len > 1000:
  84. from warnings import warn
  85. warn("Iterable length %d > 1000 but `chunksize` is not set."
  86. " This may seriously degrade multiprocess performance."
  87. " Set `chunksize=1` or more." % longest_iterable_len,
  88. TqdmWarning, stacklevel=2)
  89. if "lock_name" not in tqdm_kwargs:
  90. tqdm_kwargs = tqdm_kwargs.copy()
  91. tqdm_kwargs["lock_name"] = "mp_lock"
  92. return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)