memory_tracker.py

from collections import defaultdict
from itertools import chain
import pickle
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Sequence,
    Tuple,
)

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from torch.utils._python_dispatch import TorchDispatchMode

BYTES_PER_MB = 1024 * 1024.0


class MemoryProfileDispatchMode(TorchDispatchMode):
    """
    Run in ``TorchDispatchMode`` to get memory stats at operator level.
    """

    def __init__(self, memory_tracker) -> None:
        self.memory_tracker = memory_tracker

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        rs = func(*args, **(kwargs or {}))
        # Skip ``aten.detach``: it only creates a view and does not change the
        # CUDA allocator stats, so recording it would just add noise.
        if func == torch.ops.aten.detach.default:
            return rs
        func_name: str = (
            self.memory_tracker._cur_module_name
            + "."
            + func.__name__
            + "_"
            + str(self.memory_tracker._operator_names[func.__name__])
        )
        self.memory_tracker._operator_names[func.__name__] = (
            self.memory_tracker._operator_names[func.__name__] + 1
        )
        self.memory_tracker._record_memory_stats(func_name)
        return rs
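
# The operator name recorded for each dispatch combines the active module name set
# by the forward/backward hooks, the op's ``__name__``, and a per-op call counter,
# e.g. something like ``layer1.forward.addmm.default_0`` for the first ``addmm``
# dispatched inside ``layer1``'s forward pass (the exact aten name depends on the
# operator overload; the example name here is hypothetical).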


class MemoryTracker:
    """
    Collect and plot memory stats at operator level, including ``memories_allocated``,
    ``memories_active`` and ``memories_reserved``.

    It also prints a summary of the top 20 operators that allocate the most memory.

    Example usage:

        >>> # xdoctest: +SKIP(failing)
        >>> net.cuda()
        >>> input = input.cuda()
        >>> mem_tracker = MemoryTracker()
        >>> mem_tracker.start_monitor(net)
        >>> net.zero_grad(True)
        >>> loss = net(input)
        >>> if isinstance(loss, dict):
        >>>     loss = loss['out']
        >>> loss.sum().backward()
        >>> net.zero_grad(set_to_none=True)
        >>> mem_tracker.stop()
        >>> mem_tracker.summary()
        >>> mem_tracker.show_traces()
    """

    def __init__(self) -> None:
        torch._C._log_api_usage_once("torch.distributed.memory_tracker")
        self._hooks: List[RemovableHandle] = []
        self._operator_names: Dict[str, int] = defaultdict(int)
        # Each stats dict maps an operator index to an (operator_name, memory_in_MB) pair.
        self.memories_allocated: Dict[int, Tuple[str, float]] = defaultdict()
        self.memories_active: Dict[int, Tuple[str, float]] = defaultdict()
        self.memories_reserved: Dict[int, Tuple[str, float]] = defaultdict()
        self._markers: Dict[str, int] = defaultdict(int)
        self._cur_module_name: str = ""
        self._op_index: int = 0
        self._num_cuda_retries: int = 0

    @no_type_check
    def start_monitor(self, root_module: nn.Module) -> None:
        """
        Register module hooks and enter ``MemoryProfileDispatchMode`` so that
        operator-level memory stats can be tracked during module runtime.
        """
        self._clear_state()
        root_module.__setattr__("_memory_tracker_is_root", True)
        for name, m in root_module.named_modules():
            if m is not root_module:
                m.__setattr__("_memory_tracker_is_root", False)
            # fused_proxy_group does not support hooks
            if ".fused_proxy_grouped_embedding_bag" in name:
                continue
            # Hook ordering relative to other hooks added by users is not managed, so
            # the memory stats tracked here may not be completely accurate.
            h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
            h2 = m.register_forward_hook(self._create_post_forward_hook(name))
            # The backward hook does not work well with jagged tensors; the root cause
            # is not clear, so it is removed for now since it does not capture much
            # important information.
            # h3 = m.register_backward_hook(self._create_backward_hook(name))
            self._hooks.extend([h1, h2])
        torch.cuda.empty_cache()
        assert getattr(self, "profile_mode", None) is None
        self.profile_mode = MemoryProfileDispatchMode(self)
        self.profile_mode.__enter__()

    @no_type_check
    def stop(self) -> None:
        """
        Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking
        memory stats at operator level.

        Also collects aggregated stats gathered while the tracker was enabled, such
        as the CUDA allocator's ``num_alloc_retries``.
        """
        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)

        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        assert getattr(self, "profile_mode", None) is not None
        self.profile_mode.__exit__(None, None, None)
        self.profile_mode = None

    @no_type_check
    def summary(self, top: int = 20) -> None:
        """
        Print out the top operators that allocate the most memory.

        The number of top operators shown can be configured via ``top``.
        """
        op_diff: Dict[str, float] = defaultdict(float)
        _, previous_allocated_memory = self.memories_allocated[0]
        for i in range(1, self._op_index):
            op_name, current_allocated_memory = self.memories_allocated[i]
            op_diff[op_name] = current_allocated_memory - previous_allocated_memory
            previous_allocated_memory = current_allocated_memory

        print("------------------------------------------------")
        print(f"The number of cuda retries is: {self._num_cuda_retries}")
        print(f"Top {top} ops that generate memory are:")
        for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[
            :top
        ]:
            print(f"{k}: {v}MB")
        print("------------------------------------------------")
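
    # Illustrative ``summary()`` output (op names and values are hypothetical):
    #
    #   ------------------------------------------------
    #   The number of cuda retries is: 0
    #   Top 20 ops that generate memory are:
    #   layer1.forward.addmm.default_0: 128.0MB
    #   ...
    #   ------------------------------------------------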

    @no_type_check
    def show_traces(self, path: str = "") -> None:
        """
        Plot the traces of allocated, active and reserved memory (in MB) over operator
        calls. If ``path`` is non-empty, the stats are first loaded from that pickle file.
        """
        import matplotlib.pyplot as plt

        def _plot_figure(x, y_values, labels):
            min_val = min(list(chain(*y_values))) * 0.999
            max_val = max(list(chain(*y_values))) * 1.001
            plt.figure()
            for y, label in zip(y_values, labels):
                plt.plot(x, y, label=label)
            plt.xlabel("# Operator Calls")
            plt.ylabel("Memory (MB)")
            plt.legend()
            for marker_name, marker in self._markers.items():
                if marker_name == "fw_bw_boundary":
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "r",
                        lw=2,
                        label=marker_name,
                    )
                else:
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "k-",
                        lw=2,
                        label=marker_name,
                    )

        if path != "":
            self.load(path)

        y_1 = [mb for (_, mb) in self.memories_allocated.values()]
        y_2 = [mb for (_, mb) in self.memories_active.values()]
        y_3 = [mb for (_, mb) in self.memories_reserved.values()]
        x = list(range(len(y_1)))
        # Split into separate figures, since there can be a big difference between
        # "reserved_memory" and "allocated_memory" or "active_memory".
        _plot_figure(
            x,
            [list(y_1), list(y_2), list(y_3)],
            ["allocated_memory", "active_memory", "reserved_memory"],
        )
        _plot_figure(x, [list(y_1)], ["allocated_memory"])
        _plot_figure(x, [list(y_2)], ["active_memory"])
        _plot_figure(x, [list(y_3)], ["reserved_memory"])

    def save_stats(self, path: str) -> None:
        """
        Save the stats using pickle during runtime, so that the traces can be plotted
        elsewhere, e.g. in a notebook.
        """
        stats = {
            "memories_allocated": self.memories_allocated,
            "memories_active": self.memories_active,
            "memories_reserved": self.memories_reserved,
            "markers": self._markers,
            "num_alloc_retries": self._num_cuda_retries,
        }

        with open(path, "wb") as f:
            pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

    def load(self, path: str) -> None:
        """
        Load pickled memory stats to plot the traces or print the summary.
        """
        with open(path, "rb") as f:
            stats = pickle.load(f)

        self.memories_allocated = stats["memories_allocated"]
        self.memories_active = stats["memories_active"]
        self.memories_reserved = stats["memories_reserved"]
        self._markers = stats["markers"]
        self._num_cuda_retries = stats["num_alloc_retries"]
        # Restore the operator index so that summary() also works on loaded stats.
        self._op_index = len(self.memories_allocated)
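
    # Typical offline workflow (a sketch; ``"memory_stats.pkl"`` is an arbitrary path):
    #
    #   tracker.save_stats("memory_stats.pkl")   # on the training host, after stop()
    #
    #   viewer = MemoryTracker()                 # later, e.g. in a notebook
    #   viewer.load("memory_stats.pkl")
    #   viewer.summary()
    #   viewer.show_traces()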

    def _create_pre_forward_hook(self, name: str) -> Callable:
        """
        The pre-forward hook prefixes the operator name with the current module name
        plus ``.forward``, and inserts the "fw_start" marker when the forward pass of
        the root module begins.
        """

        def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
            self._cur_module_name = f"{name}.forward"
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_start")

        return _pre_forward_hook

    def _create_post_forward_hook(self, name: str) -> Callable:
        """
        The post-forward hook inserts the "fw_bw_boundary" marker at the boundary
        between the forward and backward passes of the root module.
        """

        def _post_forward_hook(
            module: nn.Module,
            inputs: Sequence[torch.Tensor],
            outputs: Sequence[torch.Tensor],
        ) -> None:
            if (
                hasattr(module, "_memory_tracker_is_root")
                and module._memory_tracker_is_root
            ):
                self._add_marker("fw_bw_boundary")

        return _post_forward_hook

    def _create_backward_hook(self, name: str) -> Callable:
        """
        The backward hook prefixes the operator name with the current module name
        plus ``.backward``.
        """

        def _backward_hook(
            module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor
        ) -> None:
            self._cur_module_name = f"{name}.backward"

        return _backward_hook

    @no_type_check
    def _record_memory_stats(self, fn_name: str) -> None:
        """
        Record the current memory allocated, memory active and memory reserved.

        The memory stats dicts are indexed by ``self._op_index``.
        """
        memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
        memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
        memory_active: float = (
            torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
        )
        self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
        self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
        self.memories_active[self._op_index] = (fn_name, memory_active)
        self._op_index += 1
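
    # For example, after the 13th dispatched op an entry might look like the
    # following (op name and value are hypothetical):
    #   self.memories_allocated[12] == ("layer1.forward.addmm.default_3", 512.25)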

    def _add_marker(self, marker_name: str) -> None:
        """
        Set the marker's x-axis value.
        """
        marker_val = len(self.memories_allocated.values())
        self._markers[marker_name] = marker_val

    def _clear_state(self) -> None:
        """
        Clear the states when ``start_monitor()`` is called.
        """
        self._operator_names.clear()
        self.memories_allocated.clear()
        self.memories_active.clear()
        self.memories_reserved.clear()
        self._markers.clear()
        self._cur_module_name = ""
        self._op_index = 0
        self._num_cuda_retries = 0
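

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original API): profile a toy
    # two-layer model on GPU, then print and plot the collected stats. The model,
    # tensor shapes, and the ``memory_stats.pkl`` path are arbitrary choices for
    # illustration; a CUDA device and matplotlib are assumed to be available.
    if torch.cuda.is_available():
        net = nn.Sequential(
            nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)
        ).cuda()
        inp = torch.randn(64, 1024, device="cuda")

        tracker = MemoryTracker()
        tracker.start_monitor(net)

        net.zero_grad(set_to_none=True)
        net(inp).sum().backward()
        net.zero_grad(set_to_none=True)

        tracker.stop()
        tracker.summary(top=10)
        tracker.save_stats("memory_stats.pkl")
        tracker.show_traces()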