import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import pathlib
import re
import sys
import tarfile
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from torch.utils.model_zoo import tqdm

from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available

USER_AGENT = "pytorch/vision"
 
def _save_response_content(
    content: Iterator[bytes],
    destination: str,
    length: Optional[int] = None,
) -> None:
    # Stream the chunks of a response to `destination`, displaying a progress bar over `length` bytes.
    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
        for chunk in content:
            # filter out keep-alive new chunks
            if not chunk:
                continue

            fh.write(chunk)
            pbar.update(len(chunk))


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
    # Download `url` to `filename` in blocks of `chunk_size` bytes, sending the torchvision user agent.
    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
 
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS.
    # Without it, torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(fpath, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)
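
# Illustrative sketch (not part of the original module): how the MD5 helpers above are typically
# combined. The path and digest below are placeholders, not real values.
#
#     fpath = "/tmp/data.tar.gz"
#     expected_md5 = "0123456789abcdef0123456789abcdef"
#     if not check_integrity(fpath, expected_md5):
#         raise RuntimeError("File is missing or its MD5 checksum does not match.")
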
 
def _get_redirect_url(url: str, max_hops: int = 3) -> str:
    initial_url = url
    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}

    for _ in range(max_hops + 1):
        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
            if response.url == url or response.url is None:
                return url

            url = response.url
    else:
        raise RecursionError(
            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
        )


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
 
def download_url(
    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    if _is_remote_location_available():
        _download_file_from_remote_location(fpath, url)
    else:
        # expand redirect chain if needed
        url = _get_redirect_url(url, max_hops=max_redirect_hops)

        # check if file is located on Google Drive
        file_id = _get_google_drive_file_id(url)
        if file_id is not None:
            return download_file_from_google_drive(file_id, root, filename, md5)

        # download the file
        try:
            print("Downloading " + url + " to " + fpath)
            _urlretrieve(url, fpath)
        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
                _urlretrieve(url, fpath)
            else:
                raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
 
def list_dir(root: str, prefix: bool = False) -> List[str]:
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]
    return directories


def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        files = [os.path.join(root, d) for d in files]
    return files
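
# Illustrative sketch (not part of the original module): listing images in a directory.
# The directory name is a placeholder.
#
#     image_names = list_files("~/datasets/images", suffix=(".png", ".jpg"))
#     image_paths = list_files("~/datasets/images", suffix=(".png", ".jpg"), prefix=True)
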
 
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[Optional[str], Iterator[bytes]]:
    content = response.iter_content(chunk_size)
    first_chunk = None
    # filter out keep-alive new chunks
    while not first_chunk:
        first_chunk = next(content)
    content = itertools.chain([first_chunk], content)

    try:
        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
        api_response = match["api_response"] if match is not None else None
    except UnicodeDecodeError:
        api_response = None
    return api_response, content
 
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
    """Download a Google Drive file and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if check_integrity(fpath, md5):
        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
        return

    url = "https://drive.google.com/uc"
    params = dict(id=file_id, export="download")
    with requests.Session() as session:
        response = session.get(url, params=params, stream=True)

        for key, value in response.cookies.items():
            if key.startswith("download_warning"):
                token = value
                break
        else:
            api_response, content = _extract_gdrive_api_response(response)
            token = "t" if api_response == "Virus scan warning" else None

        if token is not None:
            response = session.get(url, params=dict(params, confirm=token), stream=True)
            api_response, content = _extract_gdrive_api_response(response)

        if api_response == "Quota exceeded":
            raise RuntimeError(
                f"The daily quota of the file {filename} is exceeded and it "
                f"can't be downloaded. This is a limitation of Google Drive "
                f"and can only be overcome by trying again later."
            )

        _save_response_content(content, fpath)

    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
    if os.stat(fpath).st_size < 10 * 1024:
        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
            text = fh.read()
            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                warnings.warn(
                    f"We detected some HTML elements in the downloaded file. "
                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                    f"the response:\n\n{text}"
                )

    if md5 and not check_md5(fpath, md5):
        raise RuntimeError(
            f"The MD5 checksum of the download file {fpath} does not match the one on record. "
            f"Please delete the file and try again. "
            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
        )
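
# Illustrative sketch (not part of the original module): fetching a publicly shared Drive file.
# The file id below is a placeholder, not a real document.
#
#     download_file_from_google_drive(
#         file_id="1A2B3C_placeholder_file_id",
#         root="~/datasets",
#         filename="archive.zip",
#     )
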
 
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
        tar.extractall(to_path)


_ZIP_COMPRESSION_MAP: Dict[str, int] = {
    ".bz2": zipfile.ZIP_BZIP2,
    ".xz": zipfile.ZIP_LZMA,
}


def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
    with zipfile.ZipFile(
        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
    ) as zip:
        zip.extractall(to_path)


_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}

_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
    ".bz2": bz2.open,
    ".gz": gzip.open,
    ".xz": lzma.open,
}

_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz"),
}
 
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file.

    Args:
        file (str): the filename

    Returns:
        (tuple): tuple of suffix, archive type, and compression

    Raises:
        RuntimeError: if file has no suffix or suffix is not supported
    """
    suffixes = pathlib.Path(file).suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    suffix = suffixes[-1]

    # check if the suffix is a known alias
    if suffix in _FILE_TYPE_ALIASES:
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    if suffix in _ARCHIVE_EXTRACTORS:
        return suffix, suffix, None

    # check if the suffix is a compression
    if suffix in _COMPRESSED_FILE_OPENERS:
        # check for suffix hierarchy
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]

            # check if the suffix2 is an archive type
            if suffix2 in _ARCHIVE_EXTRACTORS:
                return suffix2 + suffix, suffix2, suffix

        return suffix, None, suffix

    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
 
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path
 
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression are automatically detected from the file name. If the file is
    compressed but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)
    if remove_finished:
        os.remove(from_path)

    return to_path
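
# Illustrative sketch (not part of the original module): extracting a previously downloaded archive.
# The paths are placeholders; note that this function does not expand "~".
#
#     out_dir = extract_archive("/tmp/downloads/images.tar.gz", to_path="/tmp/data", remove_finished=False)
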
 
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """Download an archive from a url and extract it.

    Args:
        url (str): URL to download the archive from.
        download_root (str): Directory to place the downloaded archive in.
        extract_root (str, optional): Directory to extract the archive to. If None, ``download_root`` is used.
        filename (str, optional): Name to save the archive under. If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check.
        remove_finished (bool): If ``True``, remove the archive after the extraction.
    """
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f"Extracting {archive} to {extract_root}")
    extract_archive(archive, extract_root, remove_finished)
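
# Illustrative sketch (not part of the original module): the one-call download-and-extract flow.
# The URL is a placeholder.
#
#     download_and_extract_archive(
#         "https://example.com/files/images.tar.gz",
#         download_root="/tmp/downloads",
#         extract_root="/tmp/data",
#         md5=None,
#         remove_finished=True,
#     )
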
 
def iterable_to_str(iterable: Iterable) -> str:
    return "'" + "', '".join([str(item) for item in iterable]) + "'"
 
T = TypeVar("T", str, bytes)


def verify_str_arg(
    value: T,
    arg: Optional[str] = None,
    valid_values: Optional[Iterable[T]] = None,
    custom_msg: Optional[str] = None,
) -> T:
    if not isinstance(value, str):
        if arg is None:
            msg = "Expected type str, but got type {type}."
        else:
            msg = "Expected type str for argument {arg}, but got type {type}."
        msg = msg.format(type=type(value), arg=arg)
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        if custom_msg is not None:
            msg = custom_msg
        else:
            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
        raise ValueError(msg)

    return value
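
# Illustrative sketch (not part of the original module): validating a dataset ``split`` argument.
#
#     split = verify_str_arg("train", arg="split", valid_values=("train", "test"))
#     # The next call would raise:
#     # ValueError: Unknown value 'eval' for argument split. Valid values are {'train', 'test'}.
#     # verify_str_arg("eval", arg="split", valid_values=("train", "test"))
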
 
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read file in .pfm format. Might contain either 1 or 3 channels of data.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of channels to slice out of the file.
            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
    """
    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in [b"PF", b"Pf"]:
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            raise Exception("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = "<"
            scale = -scale
        else:
            endian = ">"  # big-endian

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)
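
# Illustrative sketch (not part of the original module): reading a single-channel stereo disparity
# map stored as a .pfm file. The path is a placeholder.
#
#     disparity = _read_pfm("/tmp/sample/disp0.pfm", slice_channels=1)  # shape (1, H, W), float32
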
 
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
    # Reverse the byte order of every element, e.g. to convert big-endian data to the native little-endian layout.
    return (
        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
    )
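
# Illustrative sketch (not part of the original module): flipping the byte order of an int32 tensor.
#
#     t = torch.tensor([1], dtype=torch.int32)  # bytes 01 00 00 00 on a little-endian machine
#     flipped = _flip_byte_order(t)             # bytes 00 00 00 01 -> value 16777216 (0x01000000)
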
 
 