image_list.py 783 B

12345678910111213141516171819202122232425
  1. from typing import List, Tuple
  2. import torch
  3. from torch import Tensor
  4. class ImageList:
  5. """
  6. Structure that holds a list of images (of possibly
  7. varying sizes) as a single tensor.
  8. This works by padding the images to the same size,
  9. and storing in a field the original sizes of each image
  10. Args:
  11. tensors (tensor): Tensor containing images.
  12. image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
  13. """
  14. def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
  15. self.tensors = tensors
  16. self.image_sizes = image_sizes
  17. def to(self, device: torch.device) -> "ImageList":
  18. cast_tensor = self.tensors.to(device)
  19. return ImageList(cast_tensor, self.image_sizes)