12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import unittest
- from collections import defaultdict
- import torch
- import torchvision.transforms as transforms
- from sampler import PKSampler
- from torch.utils.data import DataLoader
- from torchvision.datasets import FakeData
- class Tester(unittest.TestCase):
- def test_pksampler(self):
- p, k = 16, 4
- # Ensure sampler does not allow p to be greater than num_classes
- dataset = FakeData(size=100, num_classes=10, image_size=(3, 1, 1))
- targets = [target.item() for _, target in dataset]
- self.assertRaises(AssertionError, PKSampler, targets, p, k)
- # Ensure p, k constraints on batch
- trans = transforms.Compose(
- [
- transforms.PILToTensor(),
- transforms.ConvertImageDtype(torch.float),
- ]
- )
- dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=trans)
- targets = [target.item() for _, target in dataset]
- sampler = PKSampler(targets, p, k)
- loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)
- for _, labels in loader:
- bins = defaultdict(int)
- for label in labels.tolist():
- bins[label] += 1
- # Ensure that each batch has samples from exactly p classes
- self.assertEqual(len(bins), p)
- # Ensure that there are k samples from each class
- for b in bins:
- self.assertEqual(bins[b], k)
- if __name__ == "__main__":
- unittest.main()
|