# sharding.py
  1. from typing import (
  2. Dict,
  3. Sized,
  4. Tuple,
  5. )
  6. from torch.utils.data.datapipes._decorator import functional_datapipe
  7. from torch.utils.data.datapipes.datapipe import IterDataPipe
  8. from enum import IntEnum
# Names exported from this module (the functional registration name
# 'sharding_filter' is attached via the decorator, not exported here).
__all__ = [
    "SHARDING_PRIORITIES",
    "ShardingFilterIterDataPipe",
]
class SHARDING_PRIORITIES(IntEnum):
    """Priority keys identifying the kind of a sharding group.

    Used as dictionary keys by ``ShardingFilterIterDataPipe.groups``; the
    filter combines registered groups in order of these integer values.
    DEFAULT may not be mixed with the other kinds (enforced in
    ``apply_sharding``).
    """
    DEFAULT = 1
    DISTRIBUTED = 2
    MULTIPROCESSING = 3
class _ShardingIterDataPipe(IterDataPipe):
    """Abstract base for DataPipes that can be sharded via ``apply_sharding``."""

    def apply_sharding(self, num_of_instances, instance_id, sharding_group):
        # Subclasses must record the sharding settings so that iteration yields
        # only the elements belonging to shard `instance_id` out of
        # `num_of_instances` (see ShardingFilterIterDataPipe for the reference
        # implementation).
        raise NotImplementedError
  20. @functional_datapipe('sharding_filter')
  21. class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
  22. r"""
  23. Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``). After ``apply_sharding`` is
  24. called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
  25. original DataPipe, where `n` equals to the number of instances.
  26. Args:
  27. source_datapipe: Iterable DataPipe that will be sharded
  28. """
  29. def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
  30. self.source_datapipe = source_datapipe
  31. self.sharding_group_filter = sharding_group_filter
  32. self.groups: Dict[int, Tuple[int, int]] = {}
  33. self.num_of_instances = 1
  34. self.instance_id = 0
  35. self._update_num_of_instances()
  36. def apply_sharding(self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT):
  37. if instance_id >= num_of_instances:
  38. raise ValueError(f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})")
  39. if sharding_group == SHARDING_PRIORITIES.DEFAULT:
  40. if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups:
  41. raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups')
  42. else:
  43. if SHARDING_PRIORITIES.DEFAULT in self.groups:
  44. raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups')
  45. self.groups[sharding_group] = (num_of_instances, instance_id)
  46. self._update_num_of_instances()
  47. def _update_num_of_instances(self):
  48. sorted_sharding_groups = []
  49. for key in sorted(self.groups.keys()):
  50. if self.sharding_group_filter is None or key == self.sharding_group_filter:
  51. sorted_sharding_groups.append(self.groups[key])
  52. sorted_sharding_groups.reverse()
  53. self.num_of_instances = 1
  54. self.instance_id = 0
  55. for group_num_of_instances, group_instance_id in sorted_sharding_groups:
  56. self.instance_id += self.num_of_instances * group_instance_id
  57. self.num_of_instances *= group_num_of_instances
  58. def __iter__(self):
  59. for i, item in enumerate(self.source_datapipe):
  60. if i % self.num_of_instances == self.instance_id:
  61. yield item
  62. def __len__(self):
  63. if isinstance(self.source_datapipe, Sized):
  64. return len(self.source_datapipe) // self.num_of_instances +\
  65. (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
  66. raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))