- r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
- To support these two classes, in `./_utils` we define many utility methods and
- functions to be run in multiprocessing. E.g., the data loading worker loop is
- in `./_utils/worker.py`.
- """
- import functools
- import itertools
- import logging
- import os
- import queue
- import threading
- import warnings
- from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union
- import multiprocessing as python_multiprocessing
- import torch
- import torch.distributed as dist
- import torch.multiprocessing as multiprocessing
- import torch.utils.data.graph_settings
- from torch._utils import ExceptionWrapper
- from . import (
- IterDataPipe,
- MapDataPipe,
- IterableDataset,
- Sampler,
- SequentialSampler,
- RandomSampler,
- BatchSampler,
- Dataset,)
- from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper
- from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
- from . import _utils
- __all__ = [
- "DataLoader",
- "get_worker_info",
- "default_collate",
- "default_convert",
- ]
- T_co = TypeVar('T_co', covariant=True)
- T = TypeVar('T')
- _worker_init_fn_t = Callable[[int], None]
- _collate_fn_t = Callable[[List[T]], Any]
- default_collate: _collate_fn_t = _utils.collate.default_collate
- default_convert = _utils.collate.default_convert
- get_worker_info = _utils.worker.get_worker_info
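- # A minimal illustration of the re-exported collate helpers (a sketch only;
- # the sample values below are arbitrary):
- #
- #     default_collate([0, 1, 2])               # -> tensor([0, 1, 2])
- #     default_collate([{'x': 1}, {'x': 2}])    # -> {'x': tensor([1, 2])}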
- logger = logging.getLogger(__name__)
- class _DatasetKind:
- Map = 0
- Iterable = 1
- @staticmethod
- def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
- if kind == _DatasetKind.Map:
- return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
- else:
- return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
- class _InfiniteConstantSampler(Sampler):
- r"""Analogous to ``itertools.repeat(None, None)``.
- Used as sampler for :class:`~torch.utils.data.IterableDataset`.
- Args:
- data_source (Dataset): dataset to sample from
- """
- def __init__(self):
- super().__init__(None)
- def __iter__(self):
- while True:
- yield None
- def _get_distributed_settings():
- if dist.is_available() and dist.is_initialized():
- return dist.get_world_size(), dist.get_rank()
- else:
- return 1, 0
- def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
- info = torch.utils.data.get_worker_info()
- assert info is not None
- total_workers = info.num_workers
- datapipe = info.dataset
- assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
- # Shard the DataPipe graph twice: first across distributed ranks, then across
- # this process's workers, so each worker ends up with a disjoint slice of the
- # data (roughly 1 / (world_size * total_workers) of the elements).
- torch.utils.data.graph_settings.apply_sharding(
- datapipe, world_size, rank_id, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
- torch.utils.data.graph_settings.apply_sharding(
- datapipe, total_workers, worker_id, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING)
- if worker_init_fn is not None:
- worker_init_fn(worker_id)
- def _share_dist_seed(generator, pg):
- _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
- if isinstance(pg, dist.ProcessGroup):
- dist.broadcast(_shared_seed, src=0, group=pg)
- return _shared_seed.item()
- class DataLoader(Generic[T_co]):
- r"""
- Data loader. Combines a dataset and a sampler, and provides an iterable over
- the given dataset.
- The :class:`~torch.utils.data.DataLoader` supports both map-style and
- iterable-style datasets with single- or multi-process loading, customizing
- loading order and optional automatic batching (collation) and memory pinning.
- See :py:mod:`torch.utils.data` documentation page for more details.
- Args:
- dataset (Dataset): dataset from which to load the data.
- batch_size (int, optional): how many samples per batch to load
- (default: ``1``).
- shuffle (bool, optional): set to ``True`` to have the data reshuffled
- at every epoch (default: ``False``).
- sampler (Sampler or Iterable, optional): defines the strategy to draw
- samples from the dataset. Can be any ``Iterable`` with ``__len__``
- implemented. If specified, :attr:`shuffle` must not be specified.
- batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
- returns a batch of indices at a time. Mutually exclusive with
- :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
- and :attr:`drop_last`.
- num_workers (int, optional): how many subprocesses to use for data
- loading. ``0`` means that the data will be loaded in the main process.
- (default: ``0``)
- collate_fn (Callable, optional): merges a list of samples to form a
- mini-batch of Tensor(s). Used when using batched loading from a
- map-style dataset.
- pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
- into device/CUDA pinned memory before returning them. If your data elements
- are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
- see the example below.
- drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
- if the dataset size is not divisible by the batch size. If ``False`` and
- the size of dataset is not divisible by the batch size, then the last batch
- will be smaller. (default: ``False``)
- timeout (numeric, optional): if positive, the timeout value for collecting a batch
- from workers. Should always be non-negative. (default: ``0``)
- worker_init_fn (Callable, optional): If not ``None``, this will be called on each
- worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
- input, after seeding and before data loading. (default: ``None``)
- generator (torch.Generator, optional): If not ``None``, this RNG will be used
- by RandomSampler to generate random indexes and multiprocessing to generate
- `base_seed` for workers. (default: ``None``)
- prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
- in advance by each worker. ``2`` means there will be a total of
- 2 * num_workers batches prefetched across all workers. (default value depends
- on ``num_workers``: ``None`` if ``num_workers=0``, otherwise ``2``).
- persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
- the worker processes after a dataset has been consumed once. This allows the
- workers' `Dataset` instances to remain alive. (default: ``False``)
- pin_memory_device (str, optional): the data loader will copy Tensors
- into device pinned memory before returning them if pin_memory is set to true.
- .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
- cannot be an unpicklable object, e.g., a lambda function. See
- :ref:`multiprocessing-best-practices` on more details related
- to multiprocessing in PyTorch.
- .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
- When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
- it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
- rounding depending on :attr:`drop_last`, regardless of multi-process loading
- configurations. This represents the best guess PyTorch can make because PyTorch
- trusts user :attr:`dataset` code in correctly handling multi-process
- loading to avoid duplicate data.
- However, if sharding results in multiple workers having incomplete last batches,
- this estimate can still be inaccurate, because (1) an otherwise complete batch can
- be broken into multiple ones and (2) more than one batch worth of samples can be
- dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
- cases in general.
- See `Dataset Types`_ for more details on these two types of datasets and how
- :class:`~torch.utils.data.IterableDataset` interacts with
- `Multi-process data loading`_.
- .. warning:: See :ref:`reproducibility`, :ref:`dataloader-workers-random-seed`, and
- :ref:`data-loading-randomness` notes for random seed related questions.
- """
- dataset: Dataset[T_co]
- batch_size: Optional[int]
- num_workers: int
- pin_memory: bool
- drop_last: bool
- timeout: float
- sampler: Union[Sampler, Iterable]
- pin_memory_device: str
- prefetch_factor: Optional[int]
- _iterator : Optional['_BaseDataLoaderIter']
- __initialized = False
- def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
- shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
- batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
- num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
- pin_memory: bool = False, drop_last: bool = False,
- timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
- multiprocessing_context=None, generator=None,
- *, prefetch_factor: Optional[int] = None,
- persistent_workers: bool = False,
- pin_memory_device: str = ""):
- torch._C._log_api_usage_once("python.data_loader")
- if num_workers < 0:
- raise ValueError('num_workers option should be non-negative; '
- 'use num_workers=0 to disable multiprocessing.')
- if timeout < 0:
- raise ValueError('timeout option should be non-negative')
- if num_workers == 0 and prefetch_factor is not None:
- raise ValueError('prefetch_factor option could only be specified in multiprocessing. '
- 'Set num_workers > 0 to enable multiprocessing; otherwise leave prefetch_factor as None.')
- elif num_workers > 0 and prefetch_factor is None:
- prefetch_factor = 2
- elif prefetch_factor is not None and prefetch_factor < 0:
- raise ValueError('prefetch_factor option should be non-negative')
- if persistent_workers and num_workers == 0:
- raise ValueError('persistent_workers option needs num_workers > 0')
- self.dataset = dataset
- self.num_workers = num_workers
- self.prefetch_factor = prefetch_factor
- self.pin_memory = pin_memory
- self.pin_memory_device = pin_memory_device
- self.timeout = timeout
- self.worker_init_fn = worker_init_fn
- self.multiprocessing_context = multiprocessing_context
- # Wrap DataPipes so that they serialize cleanly when sent to worker processes.
- if isinstance(self.dataset, IterDataPipe):
- self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
- elif isinstance(self.dataset, MapDataPipe):
- self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
-
-
-
-
- if isinstance(dataset, IterableDataset):
- self._dataset_kind = _DatasetKind.Iterable
- # IterableDataset produces its own stream of samples, so the sampler-related
- # options do not apply: shuffle is only honored for IterDataPipe graphs (via
- # apply_shuffle_settings below), and sampler / batch_sampler must be left
- # unspecified.
- if isinstance(dataset, IterDataPipe):
- if shuffle is not None:
- dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
-
- elif shuffle not in {False, None}:
- raise ValueError(
- "DataLoader with IterableDataset: expected unspecified "
- "shuffle option, but got shuffle={}".format(shuffle))
- if sampler is not None:
-
- raise ValueError(
- "DataLoader with IterableDataset: expected unspecified "
- "sampler option, but got sampler={}".format(sampler))
- elif batch_sampler is not None:
-
- raise ValueError(
- "DataLoader with IterableDataset: expected unspecified "
- "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
- else:
- shuffle = bool(shuffle)
- self._dataset_kind = _DatasetKind.Map
- if sampler is not None and shuffle:
- raise ValueError('sampler option is mutually exclusive with '
- 'shuffle')
- if batch_sampler is not None:
-
- if batch_size != 1 or shuffle or sampler is not None or drop_last:
- raise ValueError('batch_sampler option is mutually exclusive '
- 'with batch_size, shuffle, sampler, and '
- 'drop_last')
- batch_size = None
- drop_last = False
- elif batch_size is None:
-
- if drop_last:
- raise ValueError('batch_size=None option disables auto-batching '
- 'and is mutually exclusive with drop_last')
- if sampler is None:
- if self._dataset_kind == _DatasetKind.Iterable:
-
- sampler = _InfiniteConstantSampler()
- else:
- if shuffle:
- sampler = RandomSampler(dataset, generator=generator)
- else:
- sampler = SequentialSampler(dataset)
- if batch_size is not None and batch_sampler is None:
-
- batch_sampler = BatchSampler(sampler, batch_size, drop_last)
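- # For example (sketch): with a SequentialSampler over 10 samples,
- # batch_size=4 and drop_last=False, the BatchSampler yields
- # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]; with drop_last=True the final
- # partial batch [8, 9] is dropped.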
- self.batch_size = batch_size
- self.drop_last = drop_last
- self.sampler = sampler
- self.batch_sampler = batch_sampler
- self.generator = generator
- if collate_fn is None:
- if self._auto_collation:
- collate_fn = _utils.collate.default_collate
- else:
- collate_fn = _utils.collate.default_convert
- self.collate_fn = collate_fn
- self.persistent_workers = persistent_workers
- self.__initialized = True
- self._IterableDataset_len_called = None
- self._iterator = None
- self.check_worker_number_rationality()
- torch.set_vital('Dataloader', 'enabled', 'True')
- def _get_iterator(self) -> '_BaseDataLoaderIter':
- if self.num_workers == 0:
- return _SingleProcessDataLoaderIter(self)
- else:
- self.check_worker_number_rationality()
- return _MultiProcessingDataLoaderIter(self)
- @property
- def multiprocessing_context(self):
- return self.__multiprocessing_context
- @multiprocessing_context.setter
- def multiprocessing_context(self, multiprocessing_context):
- if multiprocessing_context is not None:
- if self.num_workers > 0:
- if isinstance(multiprocessing_context, str):
- valid_start_methods = multiprocessing.get_all_start_methods()
- if multiprocessing_context not in valid_start_methods:
- raise ValueError(
- ('multiprocessing_context option '
- 'should specify a valid start method in {!r}, but got '
- 'multiprocessing_context={!r}').format(valid_start_methods, multiprocessing_context))
-
- multiprocessing_context = multiprocessing.get_context(multiprocessing_context)
- if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
- raise TypeError(('multiprocessing_context option should be a valid context '
- 'object or a string specifying the start method, but got '
- 'multiprocessing_context={}').format(multiprocessing_context))
- else:
- raise ValueError(('multiprocessing_context can only be used with '
- 'multi-process loading (num_workers > 0), but got '
- 'num_workers={}').format(self.num_workers))
- self.__multiprocessing_context = multiprocessing_context
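- # For example (sketch): both forms below are accepted when num_workers > 0,
- # subject to the start methods available on the platform:
- #
- #     DataLoader(ds, num_workers=2, multiprocessing_context='spawn')
- #     DataLoader(ds, num_workers=2,
- #                multiprocessing_context=torch.multiprocessing.get_context('spawn'))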
- def __setattr__(self, attr, val):
- if self.__initialized and attr in (
- 'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
- raise ValueError('{} attribute should not be set after {} is '
- 'initialized'.format(attr, self.__class__.__name__))
- super().__setattr__(attr, val)
-
-
- def __iter__(self) -> '_BaseDataLoaderIter':
- # With persistent workers, reuse (and reset) a single long-lived iterator
- # across epochs; otherwise create a fresh iterator on every __iter__ call.
- if self.persistent_workers and self.num_workers > 0:
- if self._iterator is None:
- self._iterator = self._get_iterator()
- else:
- self._iterator._reset(self)
- return self._iterator
- else:
- return self._get_iterator()
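- # For example (sketch, assuming `loader` was built with num_workers > 0):
- # with persistent_workers=True the two loops below reuse the same worker
- # processes; with the default (False) each loop spins workers up and down.
- #
- #     for batch in loader:  # epoch 1
- #         ...
- #     for batch in loader:  # epoch 2, same workers if persistent_workers=True
- #         ...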
- @property
- def _auto_collation(self):
- return self.batch_sampler is not None
- @property
- def _index_sampler(self):
- # The sampler whose indices are sent to the dataset fetcher: the
- # batch_sampler when auto-collation is enabled, otherwise the plain sampler.
- if self._auto_collation:
- return self.batch_sampler
- else:
- return self.sampler
- def __len__(self) -> int:
- if self._dataset_kind == _DatasetKind.Iterable:
- # For IterableDataset, delegate to len(dataset) (which may raise if the
- # dataset does not implement __len__) and adjust for batching. The value is
- # recorded so __next__ can warn if more samples are actually yielded; see
- # the warning about length estimates in the DataLoader docstring.
- length = self._IterableDataset_len_called = len(self.dataset)
- if self.batch_size is not None:
- from math import ceil
- if self.drop_last:
- length = length // self.batch_size
- else:
- length = ceil(length / self.batch_size)
- return length
- else:
- return len(self._index_sampler)
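- # Worked example (sketch): for an IterableDataset reporting len(dataset) == 10
- # and batch_size == 3, len(dataloader) is 10 // 3 == 3 with drop_last=True and
- # ceil(10 / 3) == 4 otherwise.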
- def check_worker_number_rationality(self):
- # Warn if num_workers exceeds the number of CPUs this process may use
- # (determined via os.sched_getaffinity when available, so cpusets are
- # respected; otherwise via os.cpu_count()).
- def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
- suggested_max_worker_msg = ((
- "Our suggested max number of worker in current system is {}{}, which is smaller "
- "than what this DataLoader is going to create.").format(
- num_worker_suggest,
- ("" if cpuset_checked else " (`cpuset` is not taken into account)"))
- ) if num_worker_suggest is not None else (
- "DataLoader is not able to compute a suggested max number of worker in current system.")
- warn_msg = (
- "This DataLoader will create {} worker processes in total. {} "
- "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
- "lower the worker number to avoid potential slowness/freeze if necessary.").format(
- num_worker_created,
- suggested_max_worker_msg)
- return warn_msg
- if not self.num_workers or self.num_workers == 0:
- return
-
- max_num_worker_suggest = None
- cpuset_checked = False
- if hasattr(os, 'sched_getaffinity'):
- try:
- max_num_worker_suggest = len(os.sched_getaffinity(0))
- cpuset_checked = True
- except Exception:
- pass
- if max_num_worker_suggest is None:
-
-
- cpu_count = os.cpu_count()
- if cpu_count is not None:
- max_num_worker_suggest = cpu_count
- if max_num_worker_suggest is None:
- warnings.warn(_create_warning_msg(
- max_num_worker_suggest,
- self.num_workers,
- cpuset_checked))
- return
- if self.num_workers > max_num_worker_suggest:
- warnings.warn(_create_warning_msg(
- max_num_worker_suggest,
- self.num_workers,
- cpuset_checked))
- class _BaseDataLoaderIter:
- def __init__(self, loader: DataLoader) -> None:
- self._dataset = loader.dataset
- self._shared_seed = None
- self._pg = None
- if isinstance(self._dataset, IterDataPipe):
- if dist.is_available() and dist.is_initialized():
- self._pg = dist.new_group(backend="gloo")
- self._shared_seed = _share_dist_seed(loader.generator, self._pg)
- shared_rng = torch.Generator()
- shared_rng.manual_seed(self._shared_seed)
- self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
- self._dataset_kind = loader._dataset_kind
- self._IterableDataset_len_called = loader._IterableDataset_len_called
- self._auto_collation = loader._auto_collation
- self._drop_last = loader.drop_last
- self._index_sampler = loader._index_sampler
- self._num_workers = loader.num_workers
- ws, rank = _get_distributed_settings()
- self._world_size = ws
- self._rank = rank
-
-
-
- if (len(loader.pin_memory_device) == 0):
- self._pin_memory = loader.pin_memory and torch.cuda.is_available()
- self._pin_memory_device = None
- else:
- if not loader.pin_memory:
- warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
- "please set pin_memory to true, if you need to use the device pin memory")
- warnings.warn(warn_msg)
- self._pin_memory = loader.pin_memory
- self._pin_memory_device = loader.pin_memory_device
- self._timeout = loader.timeout
- self._collate_fn = loader.collate_fn
- self._sampler_iter = iter(self._index_sampler)
- self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
- self._persistent_workers = loader.persistent_workers
- self._num_yielded = 0
- self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
- def __iter__(self) -> '_BaseDataLoaderIter':
- return self
- def _reset(self, loader, first_iter=False):
- self._sampler_iter = iter(self._index_sampler)
- self._num_yielded = 0
- self._IterableDataset_len_called = loader._IterableDataset_len_called
- if isinstance(self._dataset, IterDataPipe):
- self._shared_seed = _share_dist_seed(loader.generator, self._pg)
- shared_rng = torch.Generator()
- shared_rng.manual_seed(self._shared_seed)
- self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
- def _next_index(self):
- return next(self._sampler_iter)
- def _next_data(self):
- raise NotImplementedError
- def __next__(self) -> Any:
- with torch.autograd.profiler.record_function(self._profile_name):
- if self._sampler_iter is None:
-
- self._reset()
- data = self._next_data()
- self._num_yielded += 1
- if self._dataset_kind == _DatasetKind.Iterable and \
- self._IterableDataset_len_called is not None and \
- self._num_yielded > self._IterableDataset_len_called:
- warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
- "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
- self._num_yielded)
- if self._num_workers > 0:
- warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
- "IterableDataset replica at each worker. Please see "
- "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
- warnings.warn(warn_msg)
- return data
- def __len__(self) -> int:
- return len(self._index_sampler)
- def __getstate__(self):
- # DataLoader iterators are stateful (sampler iterators and, in the
- # multiprocessing case, worker processes, queues and threads) and are not
- # meant to be pickled.
- raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
- class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
- def __init__(self, loader):
- super().__init__(loader)
- assert self._timeout == 0
- assert self._num_workers == 0
-
-
- if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
- torch.utils.data.graph_settings.apply_sharding(
- self._dataset, self._world_size, self._rank, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
- self._dataset_fetcher = _DatasetKind.create_fetcher(
- self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
- def _next_data(self):
- index = self._next_index()
- data = self._dataset_fetcher.fetch(index)
- if self._pin_memory:
- data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
- return data
- class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
- r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
- # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
- # Each worker is a daemonic process fed by its own index queue and writing to
- # a shared result queue, optionally drained by a pin_memory thread. Orderly
- # shutdown sets `_workers_done_event`, sends a `None` sentinel on every index
- # queue, joins the pin_memory thread and the workers, and only terminates a
- # worker as a last resort (see _shutdown_workers below).
- def __init__(self, loader):
- super().__init__(loader)
- self._prefetch_factor = loader.prefetch_factor
- assert self._num_workers > 0
- assert self._prefetch_factor > 0
- if loader.multiprocessing_context is None:
- multiprocessing_context = multiprocessing
- else:
- multiprocessing_context = loader.multiprocessing_context
- self._worker_init_fn = loader.worker_init_fn
-
-
- if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
- self._worker_init_fn = functools.partial(
- _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank)
-
- self._worker_result_queue = multiprocessing_context.Queue()
- self._worker_pids_set = False
- self._shutdown = False
- self._workers_done_event = multiprocessing_context.Event()
- self._index_queues = []
- self._workers = []
- for i in range(self._num_workers):
-
- index_queue = multiprocessing_context.Queue()
-
-
- index_queue.cancel_join_thread()
- w = multiprocessing_context.Process(
- target=_utils.worker._worker_loop,
- args=(self._dataset_kind, self._dataset, index_queue,
- self._worker_result_queue, self._workers_done_event,
- self._auto_collation, self._collate_fn, self._drop_last,
- self._base_seed, self._worker_init_fn, i, self._num_workers,
- self._persistent_workers, self._shared_seed))
- w.daemon = True
- # Start the worker before recording it in the bookkeeping lists so that
- # cleanup code never joins a process that was never started.
- w.start()
- self._index_queues.append(index_queue)
- self._workers.append(w)
- if self._pin_memory:
- self._pin_memory_thread_done_event = threading.Event()
-
- self._data_queue = queue.Queue()
- if self._pin_memory_device == "xpu":
- current_device = torch.xpu.current_device()
- else:
- current_device = torch.cuda.current_device()
- pin_memory_thread = threading.Thread(
- target=_utils.pin_memory._pin_memory_loop,
- args=(self._worker_result_queue, self._data_queue,
- current_device,
- self._pin_memory_thread_done_event, self._pin_memory_device))
- pin_memory_thread.daemon = True
- pin_memory_thread.start()
-
-
- self._pin_memory_thread = pin_memory_thread
- else:
- self._data_queue = self._worker_result_queue
- # With persistent workers plus pin_memory, interpreter exit can otherwise
- # race between worker shutdown and the pin_memory thread, so register an
- # atexit hook that joins (and, if needed, terminates) each worker.
- if self._persistent_workers and self._pin_memory:
- import atexit
- for w in self._workers:
- atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
-
- _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
- _utils.signal_handling._set_SIGCHLD_handler()
- self._worker_pids_set = True
- self._reset(loader, first_iter=True)
- def _reset(self, loader, first_iter=False):
- super()._reset(loader, first_iter)
- self._send_idx = 0
- self._rcvd_idx = 0
- # _send_idx: index of the next task to be sent to a worker.
- # _rcvd_idx: index of the next task to be returned by __next__.
- self._task_info = {}
- self._tasks_outstanding = 0
- # _task_info maps an outstanding task index to (worker_id,) and, once data
- # for an out-of-order task arrives early, to (worker_id, data).
- # _tasks_outstanding counts index batches sent to workers whose results have
- # not been yielded yet; _workers_status marks which workers still accept work.
- self._workers_status = [True for i in range(self._num_workers)]
-
- self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
-
- if not first_iter:
- for idx in range(self._num_workers):
- self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
- resume_iteration_cnt = self._num_workers
- while resume_iteration_cnt > 0:
- return_idx, return_data = self._get_data()
- if isinstance(return_idx, _utils.worker._ResumeIteration):
- assert return_data is None
- resume_iteration_cnt -= 1
-
- for _ in range(self._prefetch_factor * self._num_workers):
- self._try_put_index()
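- # For example (sketch): with num_workers=4 and prefetch_factor=2, the loop
- # above enqueues 8 index batches up front, so up to 8 batches can be in
- # flight before the first one is consumed.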
- def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
- # Try once to fetch an item from self._data_queue within `timeout` seconds.
- # Returns (True, item) on success and (False, None) on timeout; raises a
- # RuntimeError if a worker died, and converts fd-exhaustion errors into a
- # more helpful message below.
- try:
- data = self._data_queue.get(timeout=timeout)
- return (True, data)
- except Exception as e:
-
-
-
- failed_workers = []
- for worker_id, w in enumerate(self._workers):
- if self._workers_status[worker_id] and not w.is_alive():
- failed_workers.append(w)
- self._mark_worker_as_unavailable(worker_id)
- if len(failed_workers) > 0:
- pids_str = ', '.join(str(w.pid) for w in failed_workers)
- raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
- if isinstance(e, queue.Empty):
- return (False, None)
- import tempfile
- import errno
- try:
- # Probe whether the process is running out of file descriptors, a common
- # cause of data-queue failures with the file_descriptor sharing strategy.
- fds_limit_margin = 10
- fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
- except OSError as e:
- if e.errno == errno.EMFILE:
- raise RuntimeError(
- "Too many open files. Communication with the"
- " workers is no longer possible. Please increase the"
- " limit using `ulimit -n` in the shell or change the"
- " sharing strategy by calling"
- " `torch.multiprocessing.set_sharing_strategy('file_system')`"
- " at the beginning of your code") from None
- raise
- def _get_data(self):
- # Fetch the next (idx, data) pair from self._data_queue, honoring the
- # user-specified timeout and detecting an unexpectedly dead pin_memory
- # thread; worker failures are surfaced inside _try_get_data.
- if self._timeout > 0:
- success, data = self._try_get_data(self._timeout)
- if success:
- return data
- else:
- raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
- elif self._pin_memory:
- while self._pin_memory_thread.is_alive():
- success, data = self._try_get_data()
- if success:
- return data
- else:
-
- raise RuntimeError('Pin memory thread exited unexpectedly')
-
-
- else:
- while True:
- success, data = self._try_get_data()
- if success:
- return data
- def _next_data(self):
- while True:
- # Advance _rcvd_idx past tasks that belong to workers which are no longer
- # active and whose data will never arrive; stop at the first task that is
- # still pending or whose data has already been received.
- while self._rcvd_idx < self._send_idx:
- info = self._task_info[self._rcvd_idx]
- worker_id = info[0]
- if len(info) == 2 or self._workers_status[worker_id]:
- break
- del self._task_info[self._rcvd_idx]
- self._rcvd_idx += 1
- else:
-
- if not self._persistent_workers:
- self._shutdown_workers()
- raise StopIteration
-
-
- if len(self._task_info[self._rcvd_idx]) == 2:
- data = self._task_info.pop(self._rcvd_idx)[1]
- return self._process_data(data)
- assert not self._shutdown and self._tasks_outstanding > 0
- idx, data = self._get_data()
- self._tasks_outstanding -= 1
- if self._dataset_kind == _DatasetKind.Iterable:
-
- if isinstance(data, _utils.worker._IterableDatasetStopIteration):
- if self._persistent_workers:
- self._workers_status[data.worker_id] = False
- else:
- self._mark_worker_as_unavailable(data.worker_id)
- self._try_put_index()
- continue
- if idx != self._rcvd_idx:
-
- self._task_info[idx] += (data,)
- else:
- del self._task_info[idx]
- return self._process_data(data)
- def _try_put_index(self):
- assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
- try:
- index = self._next_index()
- except StopIteration:
- return
- for _ in range(self._num_workers):
- worker_queue_idx = next(self._worker_queue_idx_cycle)
- if self._workers_status[worker_queue_idx]:
- break
- else:
-
- return
- self._index_queues[worker_queue_idx].put((self._send_idx, index))
- self._task_info[self._send_idx] = (worker_queue_idx,)
- self._tasks_outstanding += 1
- self._send_idx += 1
- def _process_data(self, data):
- self._rcvd_idx += 1
- self._try_put_index()
- if isinstance(data, ExceptionWrapper):
- data.reraise()
- return data
- def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
- # Mark a worker as no longer available for new tasks and send it the `None`
- # sentinel that makes its loop end (or pause, for persistent workers).
- assert self._workers_status[worker_id] or (self._persistent_workers and shutdown)
-
- q = self._index_queues[worker_id]
-
-
- q.put(None)
- # The worker is intentionally not joined here; joining and queue cleanup
- # happen later in _shutdown_workers.
- self._workers_status[worker_id] = False
- assert self._workers_done_event.is_set() == shutdown
- def _shutdown_workers(self):
- # Called when the iterator is exhausted (non-persistent workers) or garbage
- # collected; must be idempotent and safe during interpreter shutdown.
- if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
-
- return
-
-
- if not self._shutdown:
- self._shutdown = True
- try:
- # First shut down the pin_memory thread (if any): set its done event,
- # unblock it with a (None, None) sentinel, join it, then release the queue.
- if hasattr(self, '_pin_memory_thread'):
-
- self._pin_memory_thread_done_event.set()
-
-
- self._worker_result_queue.put((None, None))
- self._pin_memory_thread.join()
- self._worker_result_queue.cancel_join_thread()
- self._worker_result_queue.close()
-
- self._workers_done_event.set()
- for worker_id in range(len(self._workers)):
- # Signal every worker to exit. Iterate over len(self._workers) rather than
- # self._num_workers in case an error occurred before all workers started;
- # persistent workers must be signaled even if already marked unavailable.
- if self._persistent_workers or self._workers_status[worker_id]:
- self._mark_worker_as_unavailable(worker_id, shutdown=True)
- for w in self._workers:
- # Workers should exit after receiving the sentinel; use a timeout so a
- # wedged worker cannot block shutdown indefinitely.
- w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
- for q in self._index_queues:
- q.cancel_join_thread()
- q.close()
- finally:
- # Regardless of errors above, always unregister the worker pids from the
- # SIGCHLD handler and terminate any worker that is still alive.
- if self._worker_pids_set:
- _utils.signal_handling._remove_worker_pids(id(self))
- self._worker_pids_set = False
- for w in self._workers:
- if w.is_alive():
-
-
-
-
- w.terminate()
-
- @staticmethod
- def _clean_up_worker(w):
- try:
- w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
- finally:
- if w.is_alive():
- w.terminate()
- def __del__(self):
- self._shutdown_workers()