__init__.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import sys
  2. import torch
  3. import contextlib
  4. from enum import IntEnum
  5. from typing import Union
  6. __all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager",
  7. "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "SDPBackend", "enable_flash_sdp",
  8. "flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled",
  9. "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"]
  10. def is_built():
  11. r"""Returns whether PyTorch is built with CUDA support. Note that this
  12. doesn't necessarily mean CUDA is available; just that if this PyTorch
  13. binary were run a machine with working CUDA drivers and devices, we
  14. would be able to use it."""
  15. return torch._C.has_cuda
  16. class cuFFTPlanCacheAttrContextProp:
  17. # Like regular ContextProp, but uses the `.device_index` attribute from the
  18. # calling object as the first argument to the getter and setter.
  19. def __init__(self, getter, setter):
  20. self.getter = getter
  21. self.setter = setter
  22. def __get__(self, obj, objtype):
  23. return self.getter(obj.device_index)
  24. def __set__(self, obj, val):
  25. if isinstance(self.setter, str):
  26. raise RuntimeError(self.setter)
  27. self.setter(obj.device_index, val)
  28. class cuFFTPlanCache:
  29. r"""
  30. Represents a specific plan cache for a specific `device_index`. The
  31. attributes `size` and `max_size`, and method `clear`, can fetch and/ or
  32. change properties of the C++ cuFFT plan cache.
  33. """
  34. def __init__(self, device_index):
  35. self.device_index = device_index
  36. size = cuFFTPlanCacheAttrContextProp(
  37. torch._cufft_get_plan_cache_size,
  38. '.size is a read-only property showing the number of plans currently in the '
  39. 'cache. To change the cache capacity, set cufft_plan_cache.max_size.')
  40. max_size = cuFFTPlanCacheAttrContextProp(torch._cufft_get_plan_cache_max_size,
  41. torch._cufft_set_plan_cache_max_size)
  42. def clear(self):
  43. return torch._cufft_clear_plan_cache(self.device_index)
  44. class cuFFTPlanCacheManager:
  45. r"""
  46. Represents all cuFFT plan caches. When indexed with a device object/index,
  47. this object returns the `cuFFTPlanCache` corresponding to that device.
  48. Finally, this object, when used directly as a `cuFFTPlanCache` object (e.g.,
  49. setting the `.max_size`) attribute, the current device's cuFFT plan cache is
  50. used.
  51. """
  52. __initialized = False
  53. def __init__(self):
  54. self.caches = []
  55. self.__initialized = True
  56. def __getitem__(self, device):
  57. index = torch.cuda._utils._get_device_index(device)
  58. if index < 0 or index >= torch.cuda.device_count():
  59. raise RuntimeError(
  60. ("cufft_plan_cache: expected 0 <= device index < {}, but got "
  61. "device with index {}").format(torch.cuda.device_count(), index))
  62. if len(self.caches) == 0:
  63. self.caches.extend(cuFFTPlanCache(index) for index in range(torch.cuda.device_count()))
  64. return self.caches[index]
  65. def __getattr__(self, name):
  66. return getattr(self[torch.cuda.current_device()], name)
  67. def __setattr__(self, name, value):
  68. if self.__initialized:
  69. return setattr(self[torch.cuda.current_device()], name, value)
  70. else:
  71. return super().__setattr__(name, value)
  72. class cuBLASModule:
  73. def __getattr__(self, name):
  74. if name == "allow_tf32":
  75. return torch._C._get_cublas_allow_tf32()
  76. elif name == "allow_fp16_reduced_precision_reduction":
  77. return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
  78. elif name == "allow_bf16_reduced_precision_reduction":
  79. return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
  80. raise AssertionError("Unknown attribute " + name)
  81. def __setattr__(self, name, value):
  82. if name == "allow_tf32":
  83. return torch._C._set_cublas_allow_tf32(value)
  84. elif name == "allow_fp16_reduced_precision_reduction":
  85. return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
  86. elif name == "allow_bf16_reduced_precision_reduction":
  87. return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
  88. raise AssertionError("Unknown attribute " + name)
  89. _LinalgBackends = {
  90. 'default': torch._C._LinalgBackend.Default,
  91. 'cusolver': torch._C._LinalgBackend.Cusolver,
  92. 'magma': torch._C._LinalgBackend.Magma,
  93. }
  94. _LinalgBackends_str = ', '.join(_LinalgBackends.keys())
  95. def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] = None) -> torch._C._LinalgBackend:
  96. r'''
  97. .. warning:: This flag is experimental and subject to change.
  98. When PyTorch runs a CUDA linear algebra operation it often uses the cuSOLVER or MAGMA libraries,
  99. and if both are available it decides which to use with a heuristic.
  100. This flag (a :class:`str`) allows overriding those heuristics.
  101. * If `"cusolver"` is set then cuSOLVER will be used wherever possible.
  102. * If `"magma"` is set then MAGMA will be used wherever possible.
  103. * If `"default"` (the default) is set then heuristics will be used to pick between
  104. cuSOLVER and MAGMA if both are available.
  105. * When no input is given, this function returns the currently preferred library.
  106. Note: When a library is preferred other libraries may still be used if the preferred library
  107. doesn't implement the operation(s) called.
  108. This flag may achieve better performance if PyTorch's heuristic library selection is incorrect
  109. for your application's inputs.
  110. Currently supported linalg operators:
  111. * :func:`torch.linalg.inv`
  112. * :func:`torch.linalg.inv_ex`
  113. * :func:`torch.linalg.cholesky`
  114. * :func:`torch.linalg.cholesky_ex`
  115. * :func:`torch.cholesky_solve`
  116. * :func:`torch.cholesky_inverse`
  117. * :func:`torch.linalg.lu_factor`
  118. * :func:`torch.linalg.lu`
  119. * :func:`torch.linalg.lu_solve`
  120. * :func:`torch.linalg.qr`
  121. * :func:`torch.linalg.eigh`
  122. * :func:`torch.linalg.eighvals`
  123. * :func:`torch.linalg.svd`
  124. * :func:`torch.linalg.svdvals`
  125. '''
  126. if backend is None:
  127. pass
  128. elif isinstance(backend, str):
  129. if backend not in _LinalgBackends:
  130. raise RuntimeError("Unknown input value. "
  131. f"Choose from: {_LinalgBackends_str}.")
  132. torch._C._set_linalg_preferred_backend(_LinalgBackends[backend])
  133. elif isinstance(backend, torch._C._LinalgBackend):
  134. torch._C._set_linalg_preferred_backend(backend)
  135. else:
  136. raise RuntimeError("Unknown input value type.")
  137. return torch._C._get_linalg_preferred_backend()
  138. class SDPBackend(IntEnum):
  139. r"""Enum class for the scaled dot product attention backends.
  140. .. warning:: This class is in beta and subject to change.
  141. This class needs to stay aligned with the enum defined in:
  142. pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h
  143. """
  144. ERROR = -1
  145. MATH = 0
  146. FLASH_ATTENTION = 1
  147. EFFICIENT_ATTENTION = 2
  148. def flash_sdp_enabled():
  149. r"""
  150. .. warning:: This flag is beta and subject to change.
  151. Returns whether flash scaled dot product attention is enabled or not.
  152. """
  153. return torch._C._get_flash_sdp_enabled()
  154. def enable_flash_sdp(enabled: bool):
  155. r"""
  156. .. warning:: This flag is beta and subject to change.
  157. Enables or disables flash scaled dot product attention.
  158. """
  159. torch._C._set_sdp_use_flash(enabled)
  160. def mem_efficient_sdp_enabled():
  161. r"""
  162. .. warning:: This flag is beta and subject to change.
  163. Returns whether memory efficient scaled dot product attention is enabled or not.
  164. """
  165. return torch._C._get_mem_efficient_sdp_enabled()
  166. def enable_mem_efficient_sdp(enabled: bool):
  167. r"""
  168. .. warning:: This flag is beta and subject to change.
  169. Enables or disables memory efficient scaled dot product attention.
  170. """
  171. torch._C._set_sdp_use_mem_efficient(enabled)
  172. def math_sdp_enabled():
  173. r"""
  174. .. warning:: This flag is beta and subject to change.
  175. Returns whether math scaled dot product attention is enabled or not.
  176. """
  177. return torch._C._get_math_sdp_enabled()
  178. def enable_math_sdp(enabled: bool):
  179. r"""
  180. .. warning:: This flag is beta and subject to change.
  181. Enables or disables math scaled dot product attention.
  182. """
  183. torch._C._set_sdp_use_math(enabled)
  184. @contextlib.contextmanager
  185. def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True):
  186. r"""
  187. .. warning:: This flag is beta and subject to change.
  188. This context manager can be used to temporarily enable or disable any of the three backends for scaled dot product attention.
  189. Upon exiting the context manager, the previous state of the flags will be restored.
  190. """
  191. previous_flash: bool = flash_sdp_enabled()
  192. previous_mem_efficient: bool = mem_efficient_sdp_enabled()
  193. previous_math: bool = math_sdp_enabled()
  194. try:
  195. enable_flash_sdp(enable_flash)
  196. enable_mem_efficient_sdp(enable_mem_efficient)
  197. enable_math_sdp(enable_math)
  198. yield{}
  199. except RuntimeError as err:
  200. raise err
  201. finally:
  202. enable_flash_sdp(previous_flash)
  203. enable_mem_efficient_sdp(previous_mem_efficient)
  204. enable_math_sdp(previous_math)
  205. cufft_plan_cache = cuFFTPlanCacheManager()
  206. matmul = cuBLASModule()