import importlib.machinery
import os

from torch.hub import _get_torch_home

# Default root directory for vision datasets: <torch hub home>/datasets/vision.
_HOME = os.path.join(_get_torch_home(), "datasets", "vision")

# Disabled in this module. NOTE(review): presumably read by dataset code to opt
# into sharded datasets — TODO confirm against callers.
_USE_SHARDED_DATASETS = False


def _download_file_from_remote_location(fpath: str, url: str) -> None:
    """Do nothing.

    Placeholder hook: this implementation performs no download and writes
    nothing to ``fpath``; ``url`` is ignored.
    """


def _is_remote_location_available() -> bool:
    """Return ``False``: no remote location is available in this implementation."""
    return False


# Expose a state-dict loader under a single name, preferring ``torch.hub`` and
# falling back to ``torch.utils.model_zoo.load_url`` if that import fails.
try:
    from torch.hub import load_state_dict_from_url  # noqa: F401
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: F401


def _get_extension_path(lib_name: str) -> str:
    """Return the filesystem path of a compiled extension located next to this module.

    On Windows, this module's directory is first registered on the DLL search
    path so that DLLs living alongside it can be found when the extension is
    later loaded.

    Args:
        lib_name: Module name of the extension, without any platform suffix.

    Returns:
        The ``origin`` (file path) of the extension-module spec found for
        ``lib_name`` in this module's directory.

    Raises:
        ImportError: If no file with one of
            ``importlib.machinery.EXTENSION_SUFFIXES`` matches ``lib_name`` in
            this module's directory.
    """
    lib_dir = os.path.dirname(__file__)

    if os.name == "nt":
        # Register the main library location on the default DLL path.
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        # 0x0001 == SEM_FAILCRITICALERRORS: don't show the system
        # critical-error dialog while the directory is being registered.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            os.add_dll_directory(lib_dir)
        finally:
            # The error mode is process-wide; restore it even if registration fails.
            kernel32.SetErrorMode(prev_error_mode)

    # Search only this module's directory, and only for native extension files.
    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
    finder = importlib.machinery.FileFinder(lib_dir, loader_details)
    spec = finder.find_spec(lib_name)
    if spec is None:
        raise ImportError(
            f"Could not find extension module {lib_name!r} in {lib_dir!r}",
            name=lib_name,
            path=lib_dir,
        )
    return spec.origin