import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"


def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
    # it torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
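

# Illustrative usage sketch (not part of the original module): how the MD5
# helpers above are typically combined before (re)downloading a dataset file.
# The path and checksum below are hypothetical placeholders.
def _example_check_integrity() -> None:
    fpath = "/tmp/example-dataset.tar.gz"  # hypothetical local file
    expected_md5 = "0123456789abcdef0123456789abcdef"  # hypothetical checksum
    if check_integrity(fpath, md5=expected_md5):
        print(f"{fpath} is present and its MD5 matches; skipping download.")
    else:
        print(f"{fpath} is missing or corrupted; it would be (re)downloaded here.")

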
def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
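

# Illustrative sketch (not part of the original module): what
# `_get_google_drive_file_id` extracts from a typical sharing URL. The URL and
# file id below are made-up placeholders.
def _example_google_drive_file_id() -> None:
    url = "https://drive.google.com/file/d/HYPOTHETICAL_FILE_ID/view"  # hypothetical URL
    print(_get_google_drive_file_id(url))  # -> "HYPOTHETICAL_FILE_ID"
    print(_get_google_drive_file_id("https://example.com/archive.zip"))  # -> None (not a Drive URL)

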
def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
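

# Illustrative sketch (not part of the original module): listing class folders
# and image files with the helpers above. The directory layout is hypothetical.
def _example_list_helpers() -> None:
    root = "~/datasets/example"                            # hypothetical dataset root
    class_dirs = list_dir(root, prefix=True)               # e.g. [".../cat", ".../dog"]
    images = list_files(root, suffix=(".jpg", ".png"))     # names only, since prefix defaults to False
    print(class_dirs, images)

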
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)

    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
    if os.stat(fpath).st_size < 10 * 1024:
        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
            text = fh.read()
            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                warnings.warn(
                    f"We detected some HTML elements in the downloaded file. "
                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                    f"the response:\n\n{text}"
                )

    if md5 and not check_md5(fpath, md5):
        raise RuntimeError(
            f"The MD5 checksum of the downloaded file {fpath} does not match the one on record. "
            f"Please delete the file and try again. "
            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
        )
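

# Illustrative usage sketch (not part of the original module): downloading a
# publicly shared Drive file by id. The file id, directory, and filename are
# hypothetical placeholders.
def _example_download_from_google_drive() -> None:
    download_file_from_google_drive(
        file_id="HYPOTHETICAL_FILE_ID",
        root="~/.cache/example_datasets",
        filename="annotations.zip",
        md5=None,  # pass a real checksum here to verify the download
    )

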
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}

_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}

_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path
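

# Illustrative usage sketch (not part of the original module): decompressing a
# single gzip-compressed file with the private helper above. The path is a
# hypothetical placeholder.
def _example_decompress() -> None:
    out_path = _decompress("/tmp/labels.csv.gz")  # hypothetical file; writes /tmp/labels.csv
    print(out_path)

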
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is
    compressed but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path
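

# Illustrative usage sketch (not part of the original module): extracting a
# compressed tarball into a target directory. The paths are hypothetical
# placeholders.
def _example_extract_archive() -> None:
    extracted_dir = extract_archive("/tmp/example-dataset.tar.gz", to_path="/tmp/example-dataset")
    print(extracted_dir)  # -> "/tmp/example-dataset"

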
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)
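

# Illustrative usage sketch (not part of the original module): the one-call
# download-then-extract flow that dataset classes typically use. The URL,
# directory, and checksum are hypothetical placeholders.
def _example_download_and_extract() -> None:
    download_and_extract_archive(
        url="https://example.com/data/example-dataset.tar.gz",  # hypothetical URL
        download_root="~/.cache/example_datasets",
        md5="0123456789abcdef0123456789abcdef",                 # hypothetical checksum
        remove_finished=True,                                    # delete the archive after extraction
    )

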
def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"


T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
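

# Illustrative usage sketch (not part of the original module): validating a
# string argument against a fixed set of choices, as dataset constructors do
# with their `split` parameter. The values are arbitrary examples.
def _example_verify_str_arg() -> None:
    split = verify_str_arg("train", arg="split", valid_values=("train", "val", "test"))
    print(split)  # -> "train"; an unknown value would raise ValueError listing the valid choices

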
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """

    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)
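

# Illustrative usage sketch (not part of the original module): reading a
# 2-channel optical-flow map and a 1-channel disparity map from .pfm files.
# The paths are hypothetical placeholders.
def _example_read_pfm() -> None:
    flow = _read_pfm("/tmp/flow.pfm", slice_channels=2)            # shape (2, H, W)
    disparity = _read_pfm("/tmp/disparity.pfm", slice_channels=1)  # shape (1, H, W)
    print(flow.shape, disparity.shape)

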
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
    return (
        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
    )
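

# Illustrative sketch (not part of the original module): `_flip_byte_order`
# swaps the endianness of a tensor by reinterpreting it as bytes and reversing
# each element's bytes. The tensor below is an arbitrary example.
def _example_flip_byte_order() -> None:
    t = torch.tensor([1, 256], dtype=torch.int16)  # 1 -> bytes 01 00, 256 -> bytes 00 01 (little-endian)
    flipped = _flip_byte_order(t)
    print(flipped)  # -> tensor([256, 1], dtype=torch.int16) after the byte swap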