_internally_replaced_utils.py (1.4 KB)
  1. import importlib.machinery
  2. import os
  3. from torch.hub import _get_torch_home
  4. _HOME = os.path.join(_get_torch_home(), "datasets", "vision")
  5. _USE_SHARDED_DATASETS = False
  6. def _download_file_from_remote_location(fpath: str, url: str) -> None:
  7. pass
  8. def _is_remote_location_available() -> bool:
  9. return False
  10. try:
  11. from torch.hub import load_state_dict_from_url # noqa: 401
  12. except ImportError:
  13. from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401
  14. def _get_extension_path(lib_name):
  15. lib_dir = os.path.dirname(__file__)
  16. if os.name == "nt":
  17. # Register the main torchvision library location on the default DLL path
  18. import ctypes
  19. kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  20. with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
  21. prev_error_mode = kernel32.SetErrorMode(0x0001)
  22. if with_load_library_flags:
  23. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  24. os.add_dll_directory(lib_dir)
  25. kernel32.SetErrorMode(prev_error_mode)
  26. loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
  27. extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
  28. ext_specs = extfinder.find_spec(lib_name)
  29. if ext_specs is None:
  30. raise ImportError
  31. return ext_specs.origin