from typing import (
    Dict,
    Sized,
    Tuple,
)

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe

from enum import IntEnum

__all__ = [
    "SHARDING_PRIORITIES",
    "ShardingFilterIterDataPipe",
]


# Identifies which sharding group a call to ``apply_sharding`` belongs to
# (e.g. distributed ranks vs. DataLoader worker processes).
class SHARDING_PRIORITIES(IntEnum):
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3


# Interface for DataPipes that support sharding; concrete pipes implement
# ``apply_sharding``.
class _ShardingIterDataPipe(IterDataPipe):
    def apply_sharding(self, num_of_instances, instance_id, sharding_group):
        raise NotImplementedError


@functional_datapipe('sharding_filter')
class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
    r"""
    Wrapper that allows a DataPipe to be sharded (functional name: ``sharding_filter``).

    After ``apply_sharding`` is called, each instance of the DataPipe (on different workers)
    will have every `n`-th element of the original DataPipe, where `n` equals the number of
    instances. A brief usage sketch appears at the end of this listing.

    Args:
        source_datapipe: Iterable DataPipe that will be sharded
        sharding_group_filter: optional ``SHARDING_PRIORITIES`` value; when set, only
            ``apply_sharding`` calls for that group contribute to the sharding
    """

    def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
        self.source_datapipe = source_datapipe
        self.sharding_group_filter = sharding_group_filter
        # Maps each sharding group to its (num_of_instances, instance_id) pair.
        self.groups: Dict[int, Tuple[int, int]] = {}
        self.num_of_instances = 1
        self.instance_id = 0
        self._update_num_of_instances()

    def apply_sharding(self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT):
        if instance_id >= num_of_instances:
            raise ValueError(f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})")
        # The DEFAULT group is exclusive: it cannot be combined with
        # DISTRIBUTED/MULTIPROCESSING groups on the same pipe.
        if sharding_group == SHARDING_PRIORITIES.DEFAULT:
            if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups:
                raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups')
        else:
            if SHARDING_PRIORITIES.DEFAULT in self.groups:
                raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups')
        self.groups[sharding_group] = (num_of_instances, instance_id)
        self._update_num_of_instances()
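
    # Illustrative note (not from the original source): the DEFAULT group is
    # mutually exclusive with the others, so on a pipe already sharded with
    # the DEFAULT group
    #
    #     dp.apply_sharding(2, 0)  # DEFAULT group
    #     dp.apply_sharding(4, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED)
    #
    # the second call raises "ShardingFilter cannot mix DEFAULT and non DEFAULT groups".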

    def _update_num_of_instances(self):
        # Collect (num_of_instances, instance_id) for every registered group,
        # respecting an optional sharding_group_filter.
        sorted_sharding_groups = []
        for key in sorted(self.groups.keys()):
            if self.sharding_group_filter is None or key == self.sharding_group_filter:
                sorted_sharding_groups.append(self.groups[key])

        sorted_sharding_groups.reverse()

        # Fold all groups into a single flat shard: groups with a higher priority
        # value (processed first after the reverse) vary fastest.
        self.num_of_instances = 1
        self.instance_id = 0

        for group_num_of_instances, group_instance_id in sorted_sharding_groups:
            self.instance_id += self.num_of_instances * group_instance_id
            self.num_of_instances *= group_num_of_instances
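
    # Worked example (illustrative, not from the original source): with
    # ``groups == {DISTRIBUTED: (2, 1), MULTIPROCESSING: (3, 2)}`` and no filter,
    # the reversed loop processes MULTIPROCESSING first, then DISTRIBUTED:
    #
    #     instance_id = 0 + 1 * 2 = 2,   num_of_instances = 1 * 3 = 3
    #     instance_id = 2 + 3 * 1 = 5,   num_of_instances = 3 * 2 = 6
    #
    # i.e. worker 2 of 3 on rank 1 of 2 keeps every 6th element, starting at index 5.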

    def __iter__(self):
        # Keep every num_of_instances-th element, starting at instance_id.
        for i, item in enumerate(self.source_datapipe):
            if i % self.num_of_instances == self.instance_id:
                yield item

    def __len__(self):
        if isinstance(self.source_datapipe, Sized):
            # Instances whose instance_id falls below the remainder receive one
            # extra element.
            return len(self.source_datapipe) // self.num_of_instances + \
                (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
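

# Usage sketch (not part of the original module): a minimal illustration of the
# behavior described above, assuming ``IterableWrapper`` from
# ``torch.utils.data.datapipes.iter`` is available as the source pipe.
#
#     from torch.utils.data.datapipes.iter import IterableWrapper
#
#     dp = IterableWrapper(range(10)).sharding_filter()
#     dp.apply_sharding(num_of_instances=2, instance_id=0)
#     list(dp)   # -> [0, 2, 4, 6, 8]
#     len(dp)    # -> 5
#
# The sibling instance created with ``instance_id=1`` would yield [1, 3, 5, 7, 9].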