vision.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import os
  2. from typing import Any, Callable, List, Optional, Tuple
  3. import torch.utils.data as data
  4. from ..utils import _log_api_usage_once
  class VisionDataset(data.Dataset):
      """
      Base Class For making datasets which are compatible with torchvision.
      It is necessary to override the ``__getitem__`` and ``__len__`` methods.

      Args:
          root (string): Root directory of dataset. When ``root`` is a ``str``,
              a leading ``~`` is expanded; other values are stored unchanged.
          transforms (callable, optional): A function/transforms that takes in
              an image and a label and returns the transformed versions of both.
          transform (callable, optional): A function/transform that takes in a PIL image
              and returns a transformed version. E.g, ``transforms.RandomCrop``
          target_transform (callable, optional): A function/transform that takes in the
              target and transforms it.

      Raises:
          ValueError: if ``transforms`` is given together with ``transform`` or
              ``target_transform``.

      .. note::

          :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
      """

      # Number of spaces prefixed to every body line produced by ``__repr__``.
      _repr_indent = 4

      def __init__(
          self,
          root: str,
          transforms: Optional[Callable] = None,
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
      ) -> None:
          _log_api_usage_once(self)
          if isinstance(root, str):
              root = os.path.expanduser(root)
          self.root = root

          has_transforms = transforms is not None
          has_separate_transform = transform is not None or target_transform is not None
          if has_transforms and has_separate_transform:
              raise ValueError("Only transforms or transform/target_transform can be passed as argument")

          # for backwards-compatibility
          self.transform = transform
          self.target_transform = target_transform

          # Wrap the separate callables so ``self.transforms`` always exposes the
          # joint ``(input, target) -> (input, target)`` call signature.
          if has_separate_transform:
              transforms = StandardTransform(transform, target_transform)
          self.transforms = transforms

      def __getitem__(self, index: int) -> Any:
          """
          Args:
              index (int): Index

          Returns:
              (Any): Sample and meta data, optionally transformed by the respective transforms.
          """
          raise NotImplementedError

      def __len__(self) -> int:
          raise NotImplementedError

      def __repr__(self) -> str:
          """Return a multi-line summary: class name, size, root, extra info and transforms."""
          head = "Dataset " + self.__class__.__name__
          body = [f"Number of datapoints: {self.__len__()}"]
          # getattr: a subclass that skips ``super().__init__`` may never set these.
          root = getattr(self, "root", None)
          if root is not None:
              body.append(f"Root location: {root}")
          body += self.extra_repr().splitlines()
          transforms = getattr(self, "transforms", None)
          if transforms is not None:
              # Split so every line of a multi-line repr (e.g. StandardTransform)
              # receives the indent, not only the first one.
              body += repr(transforms).splitlines()
          indent = " " * self._repr_indent
          lines = [head] + [indent + line for line in body]
          return "\n".join(lines)

      def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
          """Render ``transform``'s repr as lines, with ``head`` prefixed to the first.

          Continuation lines are padded with ``len(head)`` spaces so they align
          under the first line's content. An empty repr yields ``[head]`` rather
          than raising ``IndexError``.
          """
          # repr() (not transform.__repr__()) also works when transform is a class.
          lines = repr(transform).splitlines() or [""]
          padding = " " * len(head)
          return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

      def extra_repr(self) -> str:
          """Return extra, newline-separated lines for ``__repr__``; empty by default."""
          return ""
  class StandardTransform:
      """Apply separate ``transform`` / ``target_transform`` callables to an ``(input, target)`` pair.

      Either callable may be ``None``; the corresponding value then passes
      through unchanged.
      """

      def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
          self.transform = transform
          self.target_transform = target_transform

      def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
          """Return ``(transform(input), target_transform(target))``, skipping any callable that is ``None``."""
          if self.transform is not None:
              input = self.transform(input)
          if self.target_transform is not None:
              target = self.target_transform(target)
          return input, target

      def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
          """Render ``transform``'s repr as lines, with ``head`` prefixed to the first.

          Continuation lines are padded with ``len(head)`` spaces so they align
          under the first line's content. An empty repr yields ``[head]`` rather
          than raising ``IndexError``.
          """
          # repr() (not transform.__repr__()) also works when transform is a class.
          lines = repr(transform).splitlines() or [""]
          padding = " " * len(head)
          return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

      def __repr__(self) -> str:
          """Return the class name followed by one labelled section per non-None callable."""
          body = [self.__class__.__name__]
          if self.transform is not None:
              body += self._format_transform_repr(self.transform, "Transform: ")
          if self.target_transform is not None:
              body += self._format_transform_repr(self.target_transform, "Target transform: ")
          return "\n".join(body)