memory.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. import collections
  2. import contextlib
  3. import ctypes
  4. import warnings
  5. from typing import Any, Dict, Union, Tuple
  6. import torch
  7. from . import is_initialized, _get_device_index, _lazy_init
  8. from ._utils import _dummy_type
  9. from ._memory_viz import segments as _segments, memory as _memory
  10. from torch.types import Device
  11. from torch import _C
  12. __all__ = ["caching_allocator_alloc", "caching_allocator_delete", "set_per_process_memory_fraction",
  13. "empty_cache", "memory_stats", "memory_stats_as_nested_dict", "reset_accumulated_memory_stats",
  14. "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached",
  15. "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved",
  16. "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "list_gpu_processes",
  17. "mem_get_info", "get_allocator_backend", "CUDAPluggableAllocator", "change_current_allocator"]
  18. if not hasattr(torch._C, '_cuda_CUDAAllocator'):
  19. # Define dummy base classes
  20. torch._C.__dict__['_cuda_CUDAAllocator'] = _dummy_type('_cuda_CUDAAllocator')
  21. def _host_allocator():
  22. _lazy_init()
  23. return torch._C._cuda_cudaHostAllocator()
  24. @contextlib.contextmanager
  25. def _free_mutex():
  26. torch._C._cuda_lock_mutex()
  27. try:
  28. yield
  29. finally:
  30. torch._C._cuda_unlock_mutex()
  31. def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None):
  32. r"""Performs a memory allocation using the CUDA memory allocator.
  33. Memory is allocated for a given device and a stream, this
  34. function is intended to be used for interoperability with other
  35. frameworks. Allocated memory is released through
  36. :func:`~torch.cuda.caching_allocator_delete`.
  37. Args:
  38. size (int): number of bytes to be allocated.
  39. device (torch.device or int, optional): selected device. If it is
  40. ``None`` the default CUDA device is used.
  41. stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then
  42. the default stream for the selected device is used.
  43. .. note::
  44. See :ref:`cuda-memory-management` for more details about GPU memory
  45. management.
  46. """
  47. if device is None:
  48. device = torch.cuda.current_device()
  49. device = _get_device_index(device)
  50. if stream is None:
  51. stream = torch.cuda.current_stream(device)
  52. if isinstance(stream, torch.cuda.streams.Stream):
  53. stream = stream.cuda_stream
  54. if not isinstance(stream, int):
  55. raise TypeError('Invalid type for stream argument, must be '
  56. '`torch.cuda.Stream` or `int` representing a pointer '
  57. 'to a existing stream')
  58. with torch.cuda.device(device):
  59. return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
  60. def caching_allocator_delete(mem_ptr):
  61. r"""Deletes memory allocated using the CUDA memory allocator.
  62. Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`.
  63. is freed here. The associated device and stream are tracked inside
  64. the allocator.
  65. Args:
  66. mem_ptr (int): memory address to be freed by the allocator.
  67. .. note::
  68. See :ref:`cuda-memory-management` for more details about GPU memory
  69. management.
  70. """
  71. torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
  72. def set_per_process_memory_fraction(fraction, device: Union[Device, int] = None) -> None:
  73. r"""Set memory fraction for a process.
  74. The fraction is used to limit an caching allocator to allocated memory on a CUDA device.
  75. The allowed value equals the total visible memory multiplied fraction.
  76. If trying to allocate more than the allowed value in a process, will raise an out of
  77. memory error in allocator.
  78. Args:
  79. fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
  80. device (torch.device or int, optional): selected device. If it is
  81. ``None`` the default CUDA device is used.
  82. .. note::
  83. In general, the total available free memory is less than the total capacity.
  84. """
  85. _lazy_init()
  86. if device is None:
  87. device = torch.cuda.current_device()
  88. device = _get_device_index(device)
  89. if not isinstance(fraction, float):
  90. raise TypeError('Invalid type for fraction argument, must be `float`')
  91. if fraction < 0 or fraction > 1:
  92. raise ValueError('Invalid fraction value: {}. '
  93. 'Allowed range: 0~1'.format(fraction))
  94. torch._C._cuda_setMemoryFraction(fraction, device)
  95. def empty_cache() -> None:
  96. r"""Releases all unoccupied cached memory currently held by the caching
  97. allocator so that those can be used in other GPU application and visible in
  98. `nvidia-smi`.
  99. .. note::
  100. :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
  101. memory available for PyTorch. However, it may help reduce fragmentation
  102. of GPU memory in certain cases. See :ref:`cuda-memory-management` for
  103. more details about GPU memory management.
  104. """
  105. if is_initialized():
  106. torch._C._cuda_emptyCache()
  107. def memory_stats(device: Union[Device, int] = None) -> Dict[str, Any]:
  108. r"""Returns a dictionary of CUDA memory allocator statistics for a
  109. given device.
  110. The return value of this function is a dictionary of statistics, each of
  111. which is a non-negative integer.
  112. Core statistics:
  113. - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  114. number of allocation requests received by the memory allocator.
  115. - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  116. amount of allocated memory.
  117. - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  118. number of reserved segments from ``cudaMalloc()``.
  119. - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  120. amount of reserved memory.
  121. - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  122. number of active memory blocks.
  123. - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  124. amount of active memory.
  125. - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  126. number of inactive, non-releasable memory blocks.
  127. - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  128. amount of inactive, non-releasable memory.
  129. For these core statistics, values are broken down as follows.
  130. Pool type:
  131. - ``all``: combined statistics across all memory pools.
  132. - ``large_pool``: statistics for the large allocation pool
  133. (as of October 2019, for size >= 1MB allocations).
  134. - ``small_pool``: statistics for the small allocation pool
  135. (as of October 2019, for size < 1MB allocations).
  136. Metric type:
  137. - ``current``: current value of this metric.
  138. - ``peak``: maximum value of this metric.
  139. - ``allocated``: historical total increase in this metric.
  140. - ``freed``: historical total decrease in this metric.
  141. In addition to the core statistics, we also provide some simple event
  142. counters:
  143. - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
  144. result in a cache flush and retry.
  145. - ``"num_ooms"``: number of out-of-memory errors thrown.
  146. The caching allocator can be configured via ENV to not split blocks larger than a
  147. defined size (see Memory Management section of the Cuda Semantics documentation).
  148. This helps avoid memory fragmentation but may have a performance
  149. penalty. Additional outputs to assist with tuning and evaluating impact:
  150. - ``"max_split_size"``: blocks above this size will not be split.
  151. - ``"oversize_allocations.{current,peak,allocated,freed}"``:
  152. number of over-size allocation requests received by the memory allocator.
  153. - ``"oversize_segments.{current,peak,allocated,freed}"``:
  154. number of over-size reserved segments from ``cudaMalloc()``.
  155. The caching allocator can be configured via ENV to round memory allocations in order
  156. to reduce fragmentation. Sometimes the overhead from rounding can be higher than
  157. the fragmentation it helps reduce. The following stat can be used to check if
  158. rounding adds too much overhed:
  159. - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  160. memory requested by client code, compare this with allocated_bytes to check if
  161. allocation rounding adds too much overhead.
  162. Args:
  163. device (torch.device or int, optional): selected device. Returns
  164. statistics for the current device, given by :func:`~torch.cuda.current_device`,
  165. if :attr:`device` is ``None`` (default).
  166. .. note::
  167. See :ref:`cuda-memory-management` for more details about GPU memory
  168. management.
  169. .. note::
  170. With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
  171. meaningful, and are always reported as zero.
  172. """
  173. result = []
  174. def _recurse_add_to_result(prefix, obj):
  175. if isinstance(obj, dict):
  176. if len(prefix) > 0:
  177. prefix += "."
  178. for k, v in obj.items():
  179. _recurse_add_to_result(prefix + k, v)
  180. else:
  181. result.append((prefix, obj))
  182. stats = memory_stats_as_nested_dict(device=device)
  183. _recurse_add_to_result("", stats)
  184. result.sort()
  185. return collections.OrderedDict(result)
  186. def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
  187. r"""Returns the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
  188. if not is_initialized():
  189. return {}
  190. device = _get_device_index(device, optional=True)
  191. return torch._C._cuda_memoryStats(device)
  192. def reset_accumulated_memory_stats(device: Union[Device, int] = None) -> None:
  193. r"""Resets the "accumulated" (historical) stats tracked by the CUDA memory allocator.
  194. See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
  195. the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
  196. `"num_alloc_retries"` and `"num_ooms"`.
  197. Args:
  198. device (torch.device or int, optional): selected device. Returns
  199. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  200. if :attr:`device` is ``None`` (default).
  201. .. note::
  202. See :ref:`cuda-memory-management` for more details about GPU memory
  203. management.
  204. """
  205. device = _get_device_index(device, optional=True)
  206. return torch._C._cuda_resetAccumulatedMemoryStats(device)
  207. def reset_peak_memory_stats(device: Union[Device, int] = None) -> None:
  208. r"""Resets the "peak" stats tracked by the CUDA memory allocator.
  209. See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
  210. `"peak"` key in each individual stat dict.
  211. Args:
  212. device (torch.device or int, optional): selected device. Returns
  213. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  214. if :attr:`device` is ``None`` (default).
  215. .. note::
  216. See :ref:`cuda-memory-management` for more details about GPU memory
  217. management.
  218. """
  219. device = _get_device_index(device, optional=True)
  220. return torch._C._cuda_resetPeakMemoryStats(device)
  221. def reset_max_memory_allocated(device: Union[Device, int] = None) -> None:
  222. r"""Resets the starting point in tracking maximum GPU memory occupied by
  223. tensors for a given device.
  224. See :func:`~torch.cuda.max_memory_allocated` for details.
  225. Args:
  226. device (torch.device or int, optional): selected device. Returns
  227. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  228. if :attr:`device` is ``None`` (default).
  229. .. warning::
  230. This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
  231. /all/ peak memory stats.
  232. .. note::
  233. See :ref:`cuda-memory-management` for more details about GPU memory
  234. management.
  235. """
  236. warnings.warn(
  237. "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
  238. "which resets /all/ peak memory stats.",
  239. FutureWarning)
  240. return reset_peak_memory_stats(device=device)
  241. def reset_max_memory_cached(device: Union[Device, int] = None) -> None:
  242. r"""Resets the starting point in tracking maximum GPU memory managed by the
  243. caching allocator for a given device.
  244. See :func:`~torch.cuda.max_memory_cached` for details.
  245. Args:
  246. device (torch.device or int, optional): selected device. Returns
  247. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  248. if :attr:`device` is ``None`` (default).
  249. .. warning::
  250. This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
  251. /all/ peak memory stats.
  252. .. note::
  253. See :ref:`cuda-memory-management` for more details about GPU memory
  254. management.
  255. """
  256. warnings.warn(
  257. "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
  258. "which resets /all/ peak memory stats.",
  259. FutureWarning)
  260. return reset_peak_memory_stats(device=device)
  261. def memory_allocated(device: Union[Device, int] = None) -> int:
  262. r"""Returns the current GPU memory occupied by tensors in bytes for a given
  263. device.
  264. Args:
  265. device (torch.device or int, optional): selected device. Returns
  266. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  267. if :attr:`device` is ``None`` (default).
  268. .. note::
  269. This is likely less than the amount shown in `nvidia-smi` since some
  270. unused memory can be held by the caching allocator and some context
  271. needs to be created on GPU. See :ref:`cuda-memory-management` for more
  272. details about GPU memory management.
  273. """
  274. return memory_stats(device=device).get("allocated_bytes.all.current", 0)
  275. def max_memory_allocated(device: Union[Device, int] = None) -> int:
  276. r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
  277. device.
  278. By default, this returns the peak allocated memory since the beginning of
  279. this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
  280. reset the starting point in tracking this metric. For example, these two
  281. functions can measure the peak allocated memory usage of each iteration in a
  282. training loop.
  283. Args:
  284. device (torch.device or int, optional): selected device. Returns
  285. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  286. if :attr:`device` is ``None`` (default).
  287. .. note::
  288. See :ref:`cuda-memory-management` for more details about GPU memory
  289. management.
  290. """
  291. return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
  292. def memory_reserved(device: Union[Device, int] = None) -> int:
  293. r"""Returns the current GPU memory managed by the caching allocator in bytes
  294. for a given device.
  295. Args:
  296. device (torch.device or int, optional): selected device. Returns
  297. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  298. if :attr:`device` is ``None`` (default).
  299. .. note::
  300. See :ref:`cuda-memory-management` for more details about GPU memory
  301. management.
  302. """
  303. return memory_stats(device=device).get("reserved_bytes.all.current", 0)
  304. def max_memory_reserved(device: Union[Device, int] = None) -> int:
  305. r"""Returns the maximum GPU memory managed by the caching allocator in bytes
  306. for a given device.
  307. By default, this returns the peak cached memory since the beginning of this
  308. program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
  309. the starting point in tracking this metric. For example, these two functions
  310. can measure the peak cached memory amount of each iteration in a training
  311. loop.
  312. Args:
  313. device (torch.device or int, optional): selected device. Returns
  314. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  315. if :attr:`device` is ``None`` (default).
  316. .. note::
  317. See :ref:`cuda-memory-management` for more details about GPU memory
  318. management.
  319. """
  320. return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
  321. def memory_cached(device: Union[Device, int] = None) -> int:
  322. r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
  323. warnings.warn(
  324. "torch.cuda.memory_cached has been renamed to torch.cuda.memory_reserved",
  325. FutureWarning)
  326. return memory_reserved(device=device)
  327. def max_memory_cached(device: Union[Device, int] = None) -> int:
  328. r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
  329. warnings.warn(
  330. "torch.cuda.max_memory_cached has been renamed to torch.cuda.max_memory_reserved",
  331. FutureWarning)
  332. return max_memory_reserved(device=device)
  333. def memory_snapshot():
  334. r"""Returns a snapshot of the CUDA memory allocator state across all devices.
  335. Interpreting the output of this function requires familiarity with the
  336. memory allocator internals.
  337. .. note::
  338. See :ref:`cuda-memory-management` for more details about GPU memory
  339. management.
  340. """
  341. return torch._C._cuda_memorySnapshot()['segments']
  342. def memory_summary(device: Union[Device, int] = None, abbreviated: bool = False) -> str:
  343. r"""Returns a human-readable printout of the current memory allocator
  344. statistics for a given device.
  345. This can be useful to display periodically during training, or when
  346. handling out-of-memory exceptions.
  347. Args:
  348. device (torch.device or int, optional): selected device. Returns
  349. printout for the current device, given by :func:`~torch.cuda.current_device`,
  350. if :attr:`device` is ``None`` (default).
  351. abbreviated (bool, optional): whether to return an abbreviated summary
  352. (default: False).
  353. .. note::
  354. See :ref:`cuda-memory-management` for more details about GPU memory
  355. management.
  356. """
  357. device = _get_device_index(device, optional=True)
  358. stats = memory_stats(device=device)
  359. def _format_size(sz, pref_sz):
  360. prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"]
  361. prefix = prefixes[0]
  362. for new_prefix in prefixes[1:]:
  363. if pref_sz < 768 * 1024:
  364. break
  365. prefix = new_prefix
  366. sz //= 1024
  367. pref_sz /= 1024
  368. return "{:6d} {}".format(sz, prefix)
  369. def _format_count(cnt, pref_cnt):
  370. prefixes = [" ", "K", "M"]
  371. prefix = prefixes[0]
  372. for new_prefix in prefixes[1:]:
  373. if pref_cnt < 750 * 1000:
  374. break
  375. prefix = new_prefix
  376. cnt //= 1000
  377. pref_cnt /= 1000
  378. return "{:7d} {} ".format(cnt, prefix)
  379. metrics_to_display = [
  380. ("allocated_bytes", "Allocated memory", _format_size),
  381. ("active_bytes", "Active memory", _format_size),
  382. ("requested_bytes", "Requested memory", _format_size),
  383. ("reserved_bytes", "GPU reserved memory", _format_size),
  384. ("inactive_split_bytes", "Non-releasable memory", _format_size),
  385. ("allocation", "Allocations", _format_count),
  386. ("active", "Active allocs", _format_count),
  387. ("segment", "GPU reserved segments", _format_count),
  388. ("inactive_split", "Non-releasable allocs", _format_count),
  389. ]
  390. lines = []
  391. lines.append("=" * 75)
  392. lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
  393. lines.append("-" * 75)
  394. lines.append(" {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} ")
  395. lines.append("=" * 75)
  396. lines.append(" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed ")
  397. for metric_key, metric_name, formatter in metrics_to_display:
  398. lines.append("-" * 75)
  399. submetrics = [("all", metric_name)]
  400. if not abbreviated:
  401. submetrics.append(("large_pool", " from large pool"))
  402. submetrics.append(("small_pool", " from small pool"))
  403. current_prefval, peak_prefval, allocated_prefval, freed_prefval = None, None, None, None
  404. for submetric_key, submetric_name in submetrics:
  405. prefix = metric_key + "." + submetric_key + "."
  406. current = stats[prefix + "current"]
  407. peak = stats[prefix + "peak"]
  408. allocated = stats[prefix + "allocated"]
  409. freed = stats[prefix + "freed"]
  410. if current_prefval is None:
  411. current_prefval = current
  412. peak_prefval = peak
  413. allocated_prefval = allocated
  414. freed_prefval = freed
  415. lines.append(" {:<21} | {} | {} | {} | {} ".format(
  416. submetric_name,
  417. formatter(current, current_prefval),
  418. formatter(peak, peak_prefval),
  419. formatter(allocated, allocated_prefval),
  420. formatter(freed, freed_prefval)),
  421. )
  422. metrics_to_display = [
  423. ("oversize_allocations", "Oversize allocations", _format_count),
  424. ("oversize_segments", "Oversize GPU segments", _format_count),
  425. ]
  426. for metric_key, metric_name, formatter in metrics_to_display:
  427. lines.append("-" * 75)
  428. prefix = metric_key + "."
  429. current = stats[prefix + "current"]
  430. peak = stats[prefix + "peak"]
  431. allocated = stats[prefix + "allocated"]
  432. freed = stats[prefix + "freed"]
  433. lines.append(" {:<21} | {} | {} | {} | {} ".format(
  434. metric_name,
  435. formatter(current, current),
  436. formatter(peak, peak),
  437. formatter(allocated, allocated),
  438. formatter(freed, freed)),
  439. )
  440. lines.append("=" * 75)
  441. fmt_dict = {"_": "", "device": device}
  442. for k, v in stats.items():
  443. fmt_dict[k.replace(".", "-")] = v
  444. return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
  445. def list_gpu_processes(device: Union[Device, int] = None) -> str:
  446. r"""Returns a human-readable printout of the running processes
  447. and their GPU memory use for a given device.
  448. This can be useful to display periodically during training, or when
  449. handling out-of-memory exceptions.
  450. Args:
  451. device (torch.device or int, optional): selected device. Returns
  452. printout for the current device, given by :func:`~torch.cuda.current_device`,
  453. if :attr:`device` is ``None`` (default).
  454. """
  455. try:
  456. import pynvml # type: ignore[import]
  457. except ModuleNotFoundError:
  458. return("pynvml module not found, please install pynvml")
  459. from pynvml import NVMLError_DriverNotLoaded
  460. try:
  461. pynvml.nvmlInit()
  462. except NVMLError_DriverNotLoaded:
  463. return ("cuda driver can't be loaded, is cuda enabled?")
  464. device = _get_device_index(device, optional=True)
  465. handle = pynvml.nvmlDeviceGetHandleByIndex(device)
  466. procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
  467. lines = []
  468. lines.append(f"GPU:{device}")
  469. if len(procs) == 0:
  470. lines.append("no processes are running")
  471. for p in procs:
  472. mem = p.usedGpuMemory / (1024 * 1024)
  473. lines.append(f"process {p.pid:>10d} uses {mem:>12.3f} MB GPU memory")
  474. return "\n".join(lines)
  475. def mem_get_info(device: Union[Device, int] = None) -> Tuple[int, int]:
  476. r"""Returns the global free and total GPU memory occupied for a given
  477. device using cudaMemGetInfo.
  478. Args:
  479. device (torch.device or int, optional): selected device. Returns
  480. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  481. if :attr:`device` is ``None`` (default).
  482. .. note::
  483. See :ref:`cuda-memory-management` for more
  484. details about GPU memory management.
  485. """
  486. if device is None:
  487. device = torch.cuda.current_device()
  488. device = _get_device_index(device)
  489. return torch.cuda.cudart().cudaMemGetInfo(device)
  490. def _record_memory_history(enabled: bool, record_context=True,
  491. trace_alloc_max_entries=1,
  492. trace_alloc_record_context=False, device: Union[Device, int] = None,
  493. _enable_expensive_cpp=False):
  494. """Enables recording of Python stack traces to be associated with memory
  495. allocations, so you can tell what allocated any piece of memory in
  496. :func:`torch.memory_snapshot`.
  497. The Python trace collection is fast (2us per trace), so you may consider
  498. enabling this on production jobs if you anticipate ever having to debug
  499. memory issues.
  500. .. warning:
  501. The :attr:`_enable_expensive_cpp` arguments lets you enable also
  502. collecting C++ stack traces. This collection is VERY SLOW and should
  503. only be used if you are debugging framework problems on a minified
  504. example. In principle, it should be possible to implement fast C++
  505. stack trace collection; file an issue with us if you need it.
  506. """
  507. with torch.cuda.device(device):
  508. _C._cuda_recordMemoryHistory(enabled, record_context, _enable_expensive_cpp,
  509. trace_alloc_max_entries, trace_alloc_record_context)
  510. def _snapshot(device: Union[Device, int] = None):
  511. with torch.cuda.device(device):
  512. return _C._cuda_memorySnapshot()
  513. def _save_segment_usage(filename='output.svg', snapshot=None):
  514. if snapshot is None:
  515. snapshot = _snapshot()
  516. with open(filename, 'w') as f:
  517. f.write(_segments(snapshot))
  518. def _save_memory_usage(filename='output.svg', snapshot=None):
  519. if snapshot is None:
  520. snapshot = _snapshot()
  521. with open(filename, 'w') as f:
  522. f.write(_memory(snapshot))
  523. def _set_allocator_settings(env: str):
  524. return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
  525. def get_allocator_backend() -> str:
  526. r"""Returns a string describing the active allocator backend as set by
  527. ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
  528. ``native`` (PyTorch's native caching allocator) and `cudaMallocAsync``
  529. (CUDA's built-in asynchronous allocator).
  530. .. note::
  531. See :ref:`cuda-memory-management` for details on choosing the allocator backend.
  532. """
  533. return torch._C._cuda_getAllocatorBackend()
  534. class _CUDAAllocator:
  535. r"""Wrapper over internal CUDA memory allocators.
  536. """
  537. def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
  538. self._allocator = allocator
  539. def allocator(self):
  540. return self._allocator
  541. class CUDAPluggableAllocator(_CUDAAllocator):
  542. r"""CUDA memory allocator loaded from a so file.
  543. Memory allocators are compiled in .so files and loaded dynamically using ctypes.
  544. To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator`
  545. function.
  546. Args:
  547. path_to_so_file(str): Path in the filesystem to the `.so` file containing
  548. the allocator functions
  549. alloc_fn_name(str): Name of the function to perform the memory allocation
  550. in the so file. The signature must be:
  551. void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
  552. free_fn_name(str): Name of the function to perform the memory release
  553. in the so file. The signature must be:
  554. void free_fn_name(void* ptr, size_t size, cudaStream_t stream);
  555. .. warning::
  556. This is currently supported only in unix OSs
  557. .. note::
  558. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  559. """
  560. def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
  561. allocator = ctypes.CDLL(path_to_so_file)
  562. alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
  563. free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
  564. assert alloc_fn is not None
  565. assert free_fn is not None
  566. self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)
  567. def change_current_allocator(allocator: _CUDAAllocator) -> None:
  568. r"""Changes the currently used memory allocator to be the one provided.
  569. If the current allocator has already been used/initialized, this function will error.
  570. Args:
  571. allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.
  572. .. note::
  573. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  574. """
  575. torch._C._cuda_changeCurrentAllocator(allocator.allocator())
  576. def _get_current_allocator() -> _CUDAAllocator:
  577. r"""Returns the allocator being currently used.
  578. .. note::
  579. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  580. """
  581. return _CUDAAllocator(torch._C._cuda_getAllocator())