extension.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import os
  2. import sys
  3. import torch
  4. from ._internally_replaced_utils import _get_extension_path
  5. _HAS_OPS = False
  6. def _has_ops():
  7. return False
  8. try:
  9. # On Windows Python-3.8.x has `os.add_dll_directory` call,
  10. # which is called to configure dll search path.
  11. # To find cuda related dlls we need to make sure the
  12. # conda environment/bin path is configured Please take a look:
  13. # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
  14. # Please note: if some path can't be added using add_dll_directory we simply ignore this path
  15. if os.name == "nt" and sys.version_info < (3, 9):
  16. env_path = os.environ["PATH"]
  17. path_arr = env_path.split(";")
  18. for path in path_arr:
  19. if os.path.exists(path):
  20. try:
  21. os.add_dll_directory(path) # type: ignore[attr-defined]
  22. except Exception:
  23. pass
  24. lib_path = _get_extension_path("_C")
  25. torch.ops.load_library(lib_path)
  26. _HAS_OPS = True
  27. def _has_ops(): # noqa: F811
  28. return True
  29. except (ImportError, OSError):
  30. pass
  31. def _assert_has_ops():
  32. if not _has_ops():
  33. raise RuntimeError(
  34. "Couldn't load custom C++ ops. This can happen if your PyTorch and "
  35. "torchvision versions are incompatible, or if you had errors while compiling "
  36. "torchvision from source. For further information on the compatible versions, check "
  37. "https://github.com/pytorch/vision#installation for the compatibility matrix. "
  38. "Please check your PyTorch version with torch.__version__ and your torchvision "
  39. "version with torchvision.__version__ and verify if they are compatible, and if not "
  40. "please reinstall torchvision so that it matches your PyTorch install."
  41. )
  42. def _check_cuda_version():
  43. """
  44. Make sure that CUDA versions match between the pytorch install and torchvision install
  45. """
  46. if not _HAS_OPS:
  47. return -1
  48. from torch.version import cuda as torch_version_cuda
  49. _version = torch.ops.torchvision._cuda_version()
  50. if _version != -1 and torch_version_cuda is not None:
  51. tv_version = str(_version)
  52. if int(tv_version) < 10000:
  53. tv_major = int(tv_version[0])
  54. tv_minor = int(tv_version[2])
  55. else:
  56. tv_major = int(tv_version[0:2])
  57. tv_minor = int(tv_version[3])
  58. t_version = torch_version_cuda.split(".")
  59. t_major = int(t_version[0])
  60. t_minor = int(t_version[1])
  61. if t_major != tv_major:
  62. raise RuntimeError(
  63. "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
  64. f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
  65. f"CUDA Version={tv_major}.{tv_minor}. "
  66. "Please reinstall the torchvision that matches your PyTorch install."
  67. )
  68. return _version
  69. def _load_library(lib_name):
  70. lib_path = _get_extension_path(lib_name)
  71. torch.ops.load_library(lib_path)
# Run the CUDA compatibility check once, at import time, so that a PyTorch /
# torchvision pair built against different CUDA major versions fails
# immediately with a clear RuntimeError.
_check_cuda_version()