# test_datasets_samplers.py (3.7 KB)

  1. import pytest
  2. import torch
  3. from common_utils import assert_equal, get_list_of_videos
  4. from torchvision import io
  5. from torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
  6. from torchvision.datasets.video_utils import VideoClips
  7. @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
  8. class TestDatasetsSamplers:
  9. def test_random_clip_sampler(self, tmpdir):
  10. video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
  11. video_clips = VideoClips(video_list, 5, 5)
  12. sampler = RandomClipSampler(video_clips, 3)
  13. assert len(sampler) == 3 * 3
  14. indices = torch.tensor(list(iter(sampler)))
  15. videos = torch.div(indices, 5, rounding_mode="floor")
  16. v_idxs, count = torch.unique(videos, return_counts=True)
  17. assert_equal(v_idxs, torch.tensor([0, 1, 2]))
  18. assert_equal(count, torch.tensor([3, 3, 3]))
  19. def test_random_clip_sampler_unequal(self, tmpdir):
  20. video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
  21. video_clips = VideoClips(video_list, 5, 5)
  22. sampler = RandomClipSampler(video_clips, 3)
  23. assert len(sampler) == 2 + 3 + 3
  24. indices = list(iter(sampler))
  25. assert 0 in indices
  26. assert 1 in indices
  27. # remove elements of the first video, to simplify testing
  28. indices.remove(0)
  29. indices.remove(1)
  30. indices = torch.tensor(indices) - 2
  31. videos = torch.div(indices, 5, rounding_mode="floor")
  32. v_idxs, count = torch.unique(videos, return_counts=True)
  33. assert_equal(v_idxs, torch.tensor([0, 1]))
  34. assert_equal(count, torch.tensor([3, 3]))
  35. def test_uniform_clip_sampler(self, tmpdir):
  36. video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
  37. video_clips = VideoClips(video_list, 5, 5)
  38. sampler = UniformClipSampler(video_clips, 3)
  39. assert len(sampler) == 3 * 3
  40. indices = torch.tensor(list(iter(sampler)))
  41. videos = torch.div(indices, 5, rounding_mode="floor")
  42. v_idxs, count = torch.unique(videos, return_counts=True)
  43. assert_equal(v_idxs, torch.tensor([0, 1, 2]))
  44. assert_equal(count, torch.tensor([3, 3, 3]))
  45. assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
  46. def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
  47. video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
  48. video_clips = VideoClips(video_list, 5, 5)
  49. sampler = UniformClipSampler(video_clips, 3)
  50. assert len(sampler) == 3 * 3
  51. indices = torch.tensor(list(iter(sampler)))
  52. assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
  53. def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
  54. video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
  55. video_clips = VideoClips(video_list, 5, 5)
  56. clip_sampler = UniformClipSampler(video_clips, 3)
  57. distributed_sampler_rank0 = DistributedSampler(
  58. clip_sampler,
  59. num_replicas=2,
  60. rank=0,
  61. group_size=3,
  62. )
  63. indices = torch.tensor(list(iter(distributed_sampler_rank0)))
  64. assert len(distributed_sampler_rank0) == 6
  65. assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
  66. distributed_sampler_rank1 = DistributedSampler(
  67. clip_sampler,
  68. num_replicas=2,
  69. rank=1,
  70. group_size=3,
  71. )
  72. indices = torch.tensor(list(iter(distributed_sampler_rank1)))
  73. assert len(distributed_sampler_rank1) == 6
  74. assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
  75. if __name__ == "__main__":
  76. pytest.main([__file__])