1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- import importlib.machinery
- import os
- from torch.hub import _get_torch_home
# Directory under the torch hub home where vision datasets are kept.
_HOME = os.path.join(
    _get_torch_home(),
    "datasets",
    "vision",
)

# Sharded dataset support is switched off in this build.
_USE_SHARDED_DATASETS = False
def _download_file_from_remote_location(fpath: str, url: str) -> None:
    """Do nothing; this implementation is a no-op placeholder.

    Args:
        fpath: Destination path (ignored).
        url: Source URL (ignored).
    """
    return None
def _is_remote_location_available() -> bool:
    """Report whether a remote location can be used; always ``False`` here."""
    return False
- try:
- from torch.hub import load_state_dict_from_url # noqa: 401
- except ImportError:
- from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401
def _get_extension_path(lib_name):
    """Return the path of the compiled extension ``lib_name`` next to this file.

    The directory containing this module is searched for a file whose name is
    ``lib_name`` plus one of the interpreter's extension-module suffixes
    (``importlib.machinery.EXTENSION_SUFFIXES``).

    Args:
        lib_name: Name of the extension module, without any file suffix.

    Returns:
        The ``origin`` of the module spec that was found, i.e. the location of
        the extension file.

    Raises:
        ImportError: If no matching extension file exists in this directory.
    """
    lib_dir = os.path.dirname(__file__)

    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # 0x0001 is SEM_FAILCRITICALERRORS; remember the previous mode so it
        # can be put back afterwards.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.AddDllDirectory.restype = ctypes.c_void_p
            os.add_dll_directory(lib_dir)
        finally:
            # Restore the process error mode even if registering the
            # directory raised, so the change never leaks to the caller.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )
    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        # Same exception type as before, but say what was missing and where.
        raise ImportError(
            f"Could not find extension module {lib_name!r} in {lib_dir!r}",
            name=lib_name,
        )

    return ext_specs.origin
|