import sys
import os
import torch
import warnings
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

try:
    from torch._C import _cudnn
except ImportError:
    _cudnn = None  # type: ignore[assignment]

# Write:
#
#   torch.backends.cudnn.enabled = False
#
# to globally disable CuDNN/MIOpen

__cudnn_version = None

if _cudnn is not None:
    def _init():
        global __cudnn_version
        if __cudnn_version is None:
            __cudnn_version = _cudnn.getVersionInt()
            runtime_version = _cudnn.getRuntimeVersion()
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version
            # Different major versions are always incompatible.
            # Starting with cuDNN 7, minor versions are backwards-compatible.
            # Not sure about MIOpen (ROCm), so always do a strict check.
            if runtime_major != compile_major:
                cudnn_compatible = False
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor
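            # Illustrative (hypothetical version tuples): with compile_version
            # (8, 5, 0), a CUDA runtime of (8, 7, 0) is compatible (same major,
            # newer minor) while (7, 6, 5) is not (major mismatch); on MIOpen,
            # a runtime of (8, 6, 0) fails the strict minor check.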
            if not cudnn_compatible:
                if os.environ.get('PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK', '0') == '1':
                    return True
                base_error_msg = (f'cuDNN version incompatibility: '
                                  f'PyTorch was compiled against {compile_version} '
                                  f'but found runtime version {runtime_version}. '
                                  f'PyTorch already comes bundled with cuDNN. '
                                  f'One option to resolve this error is to ensure PyTorch '
                                  f'can find the bundled cuDNN. ')
                if 'LD_LIBRARY_PATH' in os.environ:
                    ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
                    if any(substring in ld_library_path for substring in ['cuda', 'cudnn']):
                        raise RuntimeError(f'{base_error_msg}'
                                           f'Looks like your LD_LIBRARY_PATH contains an incompatible version of cuDNN. '
                                           f'Please either remove it from the path or install cuDNN {compile_version}.')
                    else:
                        raise RuntimeError(f'{base_error_msg}'
                                           f'One possibility is that there is a '
                                           f'conflicting cuDNN in LD_LIBRARY_PATH.')
                else:
                    raise RuntimeError(base_error_msg)
        return True
else:
    def _init():
        return False


def version():
    """Returns the version of cuDNN."""
    if not _init():
        return None
    return __cudnn_version
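
# Example (illustrative): version() returns the raw integer reported by the
# library, e.g. roughly 8500 for cuDNN 8.5.0 (the exact encoding is up to the
# library), and None when no cuDNN/MIOpen library could be initialized.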

CUDNN_TENSOR_DTYPES = {
    torch.half,
    torch.float,
    torch.double,
}


def is_available():
    r"""Returns a bool indicating if CUDNN is currently available."""
    return torch._C.has_cudnn


def is_acceptable(tensor):
    if not torch._C._get_cudnn_enabled():
        return False
    if tensor.device.type != 'cuda' or tensor.dtype not in CUDNN_TENSOR_DTYPES:
        return False
    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system.")
        return False
    if not _init():
        warnings.warn('cuDNN/MIOpen library not found. Check your {libpath}'.format(
            libpath={
                'darwin': 'DYLD_LIBRARY_PATH',
                'win32': 'PATH'
            }.get(sys.platform, 'LD_LIBRARY_PATH')))
        return False
    return True
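
# Example (illustrative): is_acceptable() gates per-tensor cuDNN dispatch.
#
#   x = torch.randn(8, 3, 32, 32, device='cuda')            # float32 CUDA tensor
#   torch.backends.cudnn.is_acceptable(x)                   # True when cuDNN is enabled and loadable
#   torch.backends.cudnn.is_acceptable(x.to(torch.int32))   # False: dtype not in CUDNN_TENSOR_DTYPES
#   torch.backends.cudnn.is_acceptable(x.cpu())             # False: not a CUDA tensor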


def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None):
    orig_flags = (torch._C._get_cudnn_enabled(),
                  torch._C._get_cudnn_benchmark(),
                  None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
                  torch._C._get_cudnn_deterministic(),
                  torch._C._get_cudnn_allow_tf32())
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _benchmark_limit is not None and is_available():
        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    return orig_flags


@contextmanager
def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)
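
# Example (illustrative): set_flags() returns the previous values, so flags()
# can restore global state on exit. Temporarily enable autotuned kernels:
#
#   with torch.backends.cudnn.flags(enabled=True, benchmark=True):
#       out = torch.nn.functional.conv2d(x, weight)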

# The magic here is to allow us to intercept code like this:
#
#   torch.backends.<cudnn|mkldnn>.enabled = True


class CudnnModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
    benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
    benchmark_limit = None
    if is_available():
        benchmark_limit = ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit)
    allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)

# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

# Add type annotations for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
allow_tf32: bool
benchmark_limit: int
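
# Example (illustrative): after the replacement above, plain attribute access
# on this module is routed through the ContextProp descriptors, so
#
#   torch.backends.cudnn.benchmark = True
#
# calls torch._C._set_cudnn_benchmark(True) rather than rebinding an ordinary
# module attribute, and reads go through the matching getter.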