common_cuda.py

  1. r"""This file is allowed to initialize CUDA context when imported."""
  2. import functools
  3. import torch
  4. import torch.cuda
  5. from torch.testing._internal.common_utils import TEST_NUMBA, IS_WINDOWS, TEST_WITH_ROCM
  6. import inspect
  7. import contextlib
  8. TEST_CUDA = torch.cuda.is_available()
  9. TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
  10. CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
  11. # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
  12. TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
  13. TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0
  14. SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)
  15. SM60OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0)
  16. SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
  17. PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
  18. TEST_MAGMA = TEST_CUDA
  19. if TEST_CUDA:
  20. torch.ones(1).cuda() # has_magma shows up after cuda is initialized
  21. TEST_MAGMA = torch.cuda.has_magma
  22. if TEST_NUMBA:
  23. import numba.cuda
  24. TEST_NUMBA_CUDA = numba.cuda.is_available()
  25. else:
  26. TEST_NUMBA_CUDA = False
  27. # Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
  28. # RNG have been initialized.
  29. __cuda_ctx_rng_initialized = False
  30. # after this call, CUDA context and RNG must have been initialized on each GPU
  31. def initialize_cuda_context_rng():
  32. global __cuda_ctx_rng_initialized
  33. assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
  34. if not __cuda_ctx_rng_initialized:
  35. # initialize cuda context and rng for memory tests
  36. for i in range(torch.cuda.device_count()):
  37. torch.randn(1, device="cuda:{}".format(i))
  38. __cuda_ctx_rng_initialized = True


# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
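
# A minimal usage sketch (hypothetical; `a` and `b` stand for arbitrary float32
# tensors): code wrapped in tf32_off() runs matmuls and convolutions at full
# FP32 precision, regardless of the global allow_tf32 setting.
#
#     with tf32_off():
#         c = torch.matmul(a, b)  # true FP32 result, even on Ampere or newer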


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision


# This is a wrapper that runs a test twice: once with allow_tf32=True and once
# with allow_tf32=False. When running with allow_tf32=True, it uses the reduced
# precision specified by the argument. For example:
#     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#     @tf32_on_and_off(0.005)
#     def test_matmul(self, device, dtype):
#         a = ...; b = ...;
#         c = torch.matmul(a, b)
#         self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA,
# with a CUDA >= 11 build and an Ampere or newer architecture, the matmul is run
# both with TF32 mode on and with TF32 mode off; when TF32 mode is on, the
# assertEqual uses the reduced precision to compare values.
#
# This decorator can be used for functions with or without device/dtype, such as:
#     @tf32_on_and_off(0.005)
#     def test_my_op(self)
#
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device)
#
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, device, dtype)
#
#     @tf32_on_and_off(0.005)
#     def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an
# Ampere device; if device is specified, it checks whether the device is CUDA;
# if dtype is specified, it checks whether the dtype is float32 or complex64.
# TF32 and FP32 differ only when all three checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)
        return wrapped
    return wrapper


# This is a wrapper that wraps a test to run it with TF32 turned off.
# It is designed to be used when a test exercises matmul or convolutions
# but the purpose of the test is not to test matmul or convolutions themselves.
# Disabling TF32 forces torch.float tensors to always be computed at full precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
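
# A minimal usage sketch (hypothetical test name): the decorator is applied to a
# test method whose matmuls/convolutions should run at full FP32 precision.
#
#     @with_tf32_off
#     def test_conv_output_shape(self, device):
#         ...  # numerical results here should not depend on TF32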


def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))


def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))


def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))
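
# A usage sketch (hypothetical): callers compare these version tuples
# lexicographically, which works because both sides are tuples of ints, e.g.
#
#     if _get_torch_cuda_version() >= (11, 6):
#         ...  # take a CUDA 11.6+ code path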


def _check_cusparse_generic_available():
    version = _get_torch_cuda_version()
    min_supported_version = (10, 1)
    if IS_WINDOWS:
        min_supported_version = (11, 0)
    return version >= min_supported_version


def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
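
# A usage sketch (hypothetical test, not part of this module): these flags are
# typically consumed as skip conditions in test files, e.g.
#
#     @unittest.skipIf(not TEST_CUSPARSE_GENERIC, "cuSPARSE generic API not available")
#     def test_sparse_mm(self, device, dtype):
#         ...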