# clip_sampler.py
import math
from typing import cast, Iterator, List, Optional, Sized, Union

import torch
import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision.datasets.video_utils import VideoClips

class DistributedSampler(Sampler):
    """
    Extension of DistributedSampler, as discussed in
    https://github.com/pytorch/pytorch/issues/23430

    Example:
        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        num_replicas: 4
        shuffle: False

        when group_size = 1
            RANK    |  shard_dataset
            =========================
            rank_0  |  [0, 4, 8, 12]
            rank_1  |  [1, 5, 9, 13]
            rank_2  |  [2, 6, 10, 0]
            rank_3  |  [3, 7, 11, 1]

        when group_size = 2
            RANK    |  shard_dataset
            =========================
            rank_0  |  [0, 1, 8, 9]
            rank_1  |  [2, 3, 10, 11]
            rank_2  |  [4, 5, 12, 13]
            rank_3  |  [6, 7, 0, 1]
    """
    def __init__(
        self,
        dataset: Sized,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = False,
        group_size: int = 1,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if len(dataset) % group_size != 0:
            raise ValueError(
                "dataset length must be a multiple of group size. "
                f"dataset length: {len(dataset)}, group size: {group_size}"
            )
        self.dataset = dataset
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        dataset_group_length = len(dataset) // group_size
        self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
        self.num_samples = self.num_group_samples * group_size
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
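
    # Worked example of the bookkeeping above (illustrative, using the
    # docstring's setup: 14 items, group_size=2, num_replicas=4):
    #   dataset_group_length = 14 // 2 = 7
    #   num_group_samples    = ceil(7 / 4) = 2
    #   num_samples          = 2 * 2 = 4 indices per rank
    #   total_size           = 4 * 4 = 16, so two indices wrap around as padding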
    def __iter__(self) -> Iterator[int]:
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices: Union[torch.Tensor, List[int]]
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        total_group_size = self.total_size // self.group_size
        indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))

        # subsample: each rank takes every num_replicas-th group of indices
        indices = indices[self.rank : total_group_size : self.num_replicas, :]
        indices = torch.reshape(indices, (-1,)).tolist()
        assert len(indices) == self.num_samples

        # when the wrapped dataset is itself a sampler, map the shard
        # indices through that sampler's own output
        if isinstance(self.dataset, Sampler):
            orig_indices = list(iter(self.dataset))
            indices = [orig_indices[i] for i in indices]

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
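
# A minimal usage sketch (an assumed example, not part of the original
# module): because num_replicas and rank are passed explicitly, no process
# group is needed to reproduce the docstring's group_size=2 sharding.
#
#     dataset = list(range(14))
#     for rank in range(4):
#         sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, group_size=2)
#         print(rank, list(iter(sampler)))
#     # rank 0 -> [0, 1, 8, 9];   rank 1 -> [2, 3, 10, 11];
#     # rank 2 -> [4, 5, 12, 13]; rank 3 -> [6, 7, 0, 1]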

class UniformClipSampler(Sampler):
    """
    Sample `num_clips_per_video` clips for each video, equally spaced.
    When the number of unique clips in a video is fewer than
    `num_clips_per_video`, clips are repeated until `num_clips_per_video`
    clips are collected.

    Args:
        video_clips (VideoClips): video clips to sample from
        num_clips_per_video (int): number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.num_clips_per_video = num_clips_per_video
    def __iter__(self) -> Iterator[int]:
        idxs = []
        s = 0
        # select num_clips_per_video for each video, uniformly spaced
        for c in self.video_clips.clips:
            length = len(c)
            if length == 0:
                # corner case where video decoding fails
                continue

            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
            s += length
            idxs.append(sampled)
        return iter(cast(List[int], torch.cat(idxs).tolist()))

    def __len__(self) -> int:
        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
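
# A small sketch of the spacing rule in UniformClipSampler.__iter__ (assumed
# example values, not from the module): a video contributing 10 clips sampled
# at 4 clips per video yields evenly spaced indices, while a 2-clip video
# repeats clips to reach the requested count.
#
#     torch.linspace(0, 9, steps=4).floor().to(torch.int64)  # tensor([0, 3, 6, 9])
#     torch.linspace(0, 1, steps=4).floor().to(torch.int64)  # tensor([0, 0, 0, 1])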

class RandomClipSampler(Sampler):
    """
    Samples at most `max_clips_per_video` clips for each video randomly.

    Args:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """

    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
        if not isinstance(video_clips, VideoClips):
            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video
    def __iter__(self) -> Iterator[int]:
        idxs = []
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        idxs_ = torch.cat(idxs)
        # shuffle all clips randomly
        perm = torch.randperm(len(idxs_))
        return iter(idxs_[perm].tolist())

    def __len__(self) -> int:
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
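
# A hedged usage sketch (video_paths and dataset are placeholders, not part
# of this module): index videos with VideoClips, then let a DataLoader draw
# at most 5 random clips per video through RandomClipSampler.
#
#     from torch.utils.data import DataLoader
#     video_clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=16)
#     sampler = RandomClipSampler(video_clips, max_clips_per_video=5)
#     loader = DataLoader(dataset, batch_size=8, sampler=sampler)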