kitti.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import csv
  2. import os
  3. from typing import Any, Callable, List, Optional, Tuple
  4. from PIL import Image
  5. from .utils import download_and_extract_archive
  6. from .vision import VisionDataset
  7. class Kitti(VisionDataset):
  8. """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.
  9. It corresponds to the "left color images of object" dataset, for object detection.
  10. Args:
  11. root (string): Root directory where images are downloaded to.
  12. Expects the following folder structure if download=False:
  13. .. code::
  14. <root>
  15. └── Kitti
  16. └─ raw
  17. ├── training
  18. | ├── image_2
  19. | └── label_2
  20. └── testing
  21. └── image_2
  22. train (bool, optional): Use ``train`` split if true, else ``test`` split.
  23. Defaults to ``train``.
  24. transform (callable, optional): A function/transform that takes in a PIL image
  25. and returns a transformed version. E.g, ``transforms.PILToTensor``
  26. target_transform (callable, optional): A function/transform that takes in the
  27. target and transforms it.
  28. transforms (callable, optional): A function/transform that takes input sample
  29. and its target as entry and returns a transformed version.
  30. download (bool, optional): If true, downloads the dataset from the internet and
  31. puts it in root directory. If dataset is already downloaded, it is not
  32. downloaded again.
  33. """
  34. data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
  35. resources = [
  36. "data_object_image_2.zip",
  37. "data_object_label_2.zip",
  38. ]
  39. image_dir_name = "image_2"
  40. labels_dir_name = "label_2"
  41. def __init__(
  42. self,
  43. root: str,
  44. train: bool = True,
  45. transform: Optional[Callable] = None,
  46. target_transform: Optional[Callable] = None,
  47. transforms: Optional[Callable] = None,
  48. download: bool = False,
  49. ):
  50. super().__init__(
  51. root,
  52. transform=transform,
  53. target_transform=target_transform,
  54. transforms=transforms,
  55. )
  56. self.images = []
  57. self.targets = []
  58. self.root = root
  59. self.train = train
  60. self._location = "training" if self.train else "testing"
  61. if download:
  62. self.download()
  63. if not self._check_exists():
  64. raise RuntimeError("Dataset not found. You may use download=True to download it.")
  65. image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
  66. if self.train:
  67. labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
  68. for img_file in os.listdir(image_dir):
  69. self.images.append(os.path.join(image_dir, img_file))
  70. if self.train:
  71. self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
  72. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  73. """Get item at a given index.
  74. Args:
  75. index (int): Index
  76. Returns:
  77. tuple: (image, target), where
  78. target is a list of dictionaries with the following keys:
  79. - type: str
  80. - truncated: float
  81. - occluded: int
  82. - alpha: float
  83. - bbox: float[4]
  84. - dimensions: float[3]
  85. - locations: float[3]
  86. - rotation_y: float
  87. """
  88. image = Image.open(self.images[index])
  89. target = self._parse_target(index) if self.train else None
  90. if self.transforms:
  91. image, target = self.transforms(image, target)
  92. return image, target
  93. def _parse_target(self, index: int) -> List:
  94. target = []
  95. with open(self.targets[index]) as inp:
  96. content = csv.reader(inp, delimiter=" ")
  97. for line in content:
  98. target.append(
  99. {
  100. "type": line[0],
  101. "truncated": float(line[1]),
  102. "occluded": int(line[2]),
  103. "alpha": float(line[3]),
  104. "bbox": [float(x) for x in line[4:8]],
  105. "dimensions": [float(x) for x in line[8:11]],
  106. "location": [float(x) for x in line[11:14]],
  107. "rotation_y": float(line[14]),
  108. }
  109. )
  110. return target
  111. def __len__(self) -> int:
  112. return len(self.images)
  113. @property
  114. def _raw_folder(self) -> str:
  115. return os.path.join(self.root, self.__class__.__name__, "raw")
  116. def _check_exists(self) -> bool:
  117. """Check if the data directory exists."""
  118. folders = [self.image_dir_name]
  119. if self.train:
  120. folders.append(self.labels_dir_name)
  121. return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
  122. def download(self) -> None:
  123. """Download the KITTI data if it doesn't exist already."""
  124. if self._check_exists():
  125. return
  126. os.makedirs(self._raw_folder, exist_ok=True)
  127. # download files
  128. for fname in self.resources:
  129. download_and_extract_archive(
  130. url=f"{self.data_url}{fname}",
  131. download_root=self._raw_folder,
  132. filename=fname,
  133. )