sampler.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import random
  2. from collections import defaultdict
  3. import torch
  4. from torch.utils.data.sampler import Sampler
  5. def create_groups(groups, k):
  6. """Bins sample indices with respect to groups, remove bins with less than k samples
  7. Args:
  8. groups (list[int]): where ith index stores ith sample's group id
  9. Returns:
  10. defaultdict[list]: Bins of sample indices, binned by group_idx
  11. """
  12. group_samples = defaultdict(list)
  13. for sample_idx, group_idx in enumerate(groups):
  14. group_samples[group_idx].append(sample_idx)
  15. keys_to_remove = []
  16. for key in group_samples:
  17. if len(group_samples[key]) < k:
  18. keys_to_remove.append(key)
  19. continue
  20. for key in keys_to_remove:
  21. group_samples.pop(key)
  22. return group_samples
  23. class PKSampler(Sampler):
  24. """
  25. Randomly samples from a dataset while ensuring that each batch (of size p * k)
  26. includes samples from exactly p labels, with k samples for each label.
  27. Args:
  28. groups (list[int]): List where the ith entry is the group_id/label of the ith sample in the dataset.
  29. p (int): Number of labels/groups to be sampled from in a batch
  30. k (int): Number of samples for each label/group in a batch
  31. """
  32. def __init__(self, groups, p, k):
  33. self.p = p
  34. self.k = k
  35. self.groups = create_groups(groups, self.k)
  36. # Ensures there are enough classes to sample from
  37. if len(self.groups) < p:
  38. raise ValueError("There are not enough classes to sample from")
  39. def __iter__(self):
  40. # Shuffle samples within groups
  41. for key in self.groups:
  42. random.shuffle(self.groups[key])
  43. # Keep track of the number of samples left for each group
  44. group_samples_remaining = {}
  45. for key in self.groups:
  46. group_samples_remaining[key] = len(self.groups[key])
  47. while len(group_samples_remaining) > self.p:
  48. # Select p groups at random from valid/remaining groups
  49. group_ids = list(group_samples_remaining.keys())
  50. selected_group_idxs = torch.multinomial(torch.ones(len(group_ids)), self.p).tolist()
  51. for i in selected_group_idxs:
  52. group_id = group_ids[i]
  53. group = self.groups[group_id]
  54. for _ in range(self.k):
  55. # No need to pick samples at random since group samples are shuffled
  56. sample_idx = len(group) - group_samples_remaining[group_id]
  57. yield group[sample_idx]
  58. group_samples_remaining[group_id] -= 1
  59. # Don't sample from group if it has less than k samples remaining
  60. if group_samples_remaining[group_id] < self.k:
  61. group_samples_remaining.pop(group_id)