test.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import unittest
  2. from collections import defaultdict
  3. import torch
  4. import torchvision.transforms as transforms
  5. from sampler import PKSampler
  6. from torch.utils.data import DataLoader
  7. from torchvision.datasets import FakeData
  8. class Tester(unittest.TestCase):
  9. def test_pksampler(self):
  10. p, k = 16, 4
  11. # Ensure sampler does not allow p to be greater than num_classes
  12. dataset = FakeData(size=100, num_classes=10, image_size=(3, 1, 1))
  13. targets = [target.item() for _, target in dataset]
  14. self.assertRaises(AssertionError, PKSampler, targets, p, k)
  15. # Ensure p, k constraints on batch
  16. trans = transforms.Compose(
  17. [
  18. transforms.PILToTensor(),
  19. transforms.ConvertImageDtype(torch.float),
  20. ]
  21. )
  22. dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=trans)
  23. targets = [target.item() for _, target in dataset]
  24. sampler = PKSampler(targets, p, k)
  25. loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)
  26. for _, labels in loader:
  27. bins = defaultdict(int)
  28. for label in labels.tolist():
  29. bins[label] += 1
  30. # Ensure that each batch has samples from exactly p classes
  31. self.assertEqual(len(bins), p)
  32. # Ensure that there are k samples from each class
  33. for b in bins:
  34. self.assertEqual(bins[b], k)
  35. if __name__ == "__main__":
  36. unittest.main()