moving_mnist.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import os.path
  2. from typing import Callable, Optional
  3. import numpy as np
  4. import torch
  5. from torchvision.datasets.utils import download_url, verify_str_arg
  6. from torchvision.datasets.vision import VisionDataset
  7. class MovingMNIST(VisionDataset):
  8. """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
  9. Args:
  10. root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
  11. split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
  12. If ``split=None``, the full data is returned.
  13. split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
  14. frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
  15. is returned. If ``split=None``, this parameter is ignored and the all frames data is returned.
  16. transform (callable, optional): A function/transform that takes in an torch Tensor
  17. and returns a transformed version. E.g, ``transforms.RandomCrop``
  18. download (bool, optional): If true, downloads the dataset from the internet and
  19. puts it in root directory. If dataset is already downloaded, it is not
  20. downloaded again.
  21. """
  22. _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
  23. def __init__(
  24. self,
  25. root: str,
  26. split: Optional[str] = None,
  27. split_ratio: int = 10,
  28. download: bool = False,
  29. transform: Optional[Callable] = None,
  30. ) -> None:
  31. super().__init__(root, transform=transform)
  32. self._base_folder = os.path.join(self.root, self.__class__.__name__)
  33. self._filename = self._URL.split("/")[-1]
  34. if split is not None:
  35. verify_str_arg(split, "split", ("train", "test"))
  36. self.split = split
  37. if not isinstance(split_ratio, int):
  38. raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
  39. elif not (1 <= split_ratio <= 19):
  40. raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
  41. self.split_ratio = split_ratio
  42. if download:
  43. self.download()
  44. if not self._check_exists():
  45. raise RuntimeError("Dataset not found. You can use download=True to download it.")
  46. data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
  47. if self.split == "train":
  48. data = data[: self.split_ratio]
  49. elif self.split == "test":
  50. data = data[self.split_ratio :]
  51. self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
  52. def __getitem__(self, idx: int) -> torch.Tensor:
  53. """
  54. Args:
  55. index (int): Index
  56. Returns:
  57. torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
  58. """
  59. data = self.data[idx]
  60. if self.transform is not None:
  61. data = self.transform(data)
  62. return data
  63. def __len__(self) -> int:
  64. return len(self.data)
  65. def _check_exists(self) -> bool:
  66. return os.path.exists(os.path.join(self._base_folder, self._filename))
  67. def download(self) -> None:
  68. if self._check_exists():
  69. return
  70. download_url(
  71. url=self._URL,
  72. root=self._base_folder,
  73. filename=self._filename,
  74. md5="be083ec986bfe91a449d63653c411eb2",
  75. )