from collections import defaultdict
from itertools import chain
import pickle
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Sequence,
    Tuple,
)

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from torch.utils._python_dispatch import TorchDispatchMode

BYTES_PER_MB = 1024 * 1024.0


class MemoryProfileDispatchMode(TorchDispatchMode):
    """
    Run in ``TorchDispatchMode`` to collect memory stats at the operator level.
    """

    def __init__(self, memory_tracker) -> None:
        self.memory_tracker = memory_tracker

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        # ``kwargs`` may be passed in as ``None``; normalize before unpacking.
        kwargs = kwargs if kwargs is not None else {}
        rs = func(*args, **kwargs)
        # Skip ``detach`` calls; they add noise without allocating memory.
        if func == torch.ops.aten.detach.default:
            return rs
        func_name: str = (
            self.memory_tracker._cur_module_name
            + "."
            + func.__name__
            + "_"
            + str(self.memory_tracker._operator_names[func.__name__])
        )
        self.memory_tracker._operator_names[func.__name__] += 1
        self.memory_tracker._record_memory_stats(func_name)
        return rs


class MemoryTracker:
    """
    Collect and plot memory stats at the operator level: ``memories_allocated``,
    ``memories_active`` and ``memories_reserved``. Also print a summary of the
    top operators that allocate the most memory.

    Example usage:

        >>> # xdoctest: +SKIP(failing)
        >>> net.cuda()
        >>> input = input.cuda()

        >>> mem_tracker = MemoryTracker()
        >>> mem_tracker.start_monitor(net)

        >>> net.zero_grad(True)
        >>> loss = net(input)
        >>> if isinstance(loss, dict):
        ...     loss = loss['out']
        >>> loss.sum().backward()
        >>> net.zero_grad(set_to_none=True)

        >>> mem_tracker.stop()
        >>> mem_tracker.summary()
        >>> mem_tracker.show_traces()
    """

    def __init__(self) -> None:
        torch._C._log_api_usage_once("torch.distributed.memory_tracker")
        self._hooks: List[RemovableHandle] = []
        self._operator_names: Dict[str, int] = defaultdict(int)
        # Each dict maps an operator index to an ``(operator_name, MB)`` tuple.
        self.memories_allocated: Dict[int, Tuple[str, float]] = {}
        self.memories_active: Dict[int, Tuple[str, float]] = {}
        self.memories_reserved: Dict[int, Tuple[str, float]] = {}
        self._markers: Dict[str, int] = defaultdict(int)
        self._cur_module_name: str = ""
        self._op_index: int = 0
        self._num_cuda_retries: int = 0

    @no_type_check
    def start_monitor(self, root_module: nn.Module) -> None:
        """
        Register module hooks and enter ``MemoryProfileDispatchMode`` so that
        operator-level memory stats can be tracked during the module's runtime.
        """
        self._clear_state()
        root_module.__setattr__("_memory_tracker_is_root", True)
        for name, m in root_module.named_modules():
            if m is not root_module:
                m.__setattr__("_memory_tracker_is_root", False)
            # fused_proxy_group does not support hooks
            if ".fused_proxy_grouped_embedding_bag" in name:
                continue
            # Ordering relative to other user-added hooks is not managed, so
            # the memory stats tracked here may not be completely accurate.
            h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
            h2 = m.register_forward_hook(self._create_post_forward_hook(name))
            # The backward hook does not work well with jagged tensors; the
            # root cause is unclear, so it is disabled for now. It did not
            # capture much important info anyway.
            # h3 = m.register_backward_hook(self._create_backward_hook(name))
            self._hooks.extend([h1, h2])
        torch.cuda.empty_cache()
        assert getattr(self, "profile_mode", None) is None
        self.profile_mode = MemoryProfileDispatchMode(self)
        self.profile_mode.__enter__()

    @no_type_check
    def stop(self) -> None:
        """
        Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop
        tracking memory stats at the operator level.

        Also collect some aggregated stats recorded while the tracker was
        enabled, such as the CUDA ``num_alloc_retries``.
        """
        self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)

        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        assert getattr(self, "profile_mode", None) is not None
        self.profile_mode.__exit__(None, None, None)
        self.profile_mode = None
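    # Note on the recorded data consumed by ``summary()`` below: each entry
    # produced by ``MemoryProfileDispatchMode`` maps an operator index to an
    # ``(operator_name, memory_in_MB)`` tuple, where the name encodes the
    # module path, pass direction, ATen op, and per-op call count, e.g.
    # (illustrative) ``("layer1.forward.addmm.default_0", 512.25)``.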
    @no_type_check
    def summary(self, top: int = 20) -> None:
        """
        Print the operators that allocate the most memory; the number of
        operators shown is configured by ``top``.
        """
        # Guard against calling summary() before any stats were recorded.
        if not self.memories_allocated:
            print("No memory stats collected; run start_monitor()/stop() first.")
            return

        op_diff: Dict[str, float] = defaultdict(float)
        op_name, previous_allocated_memory = self.memories_allocated[0]
        for i in range(1, self._op_index):
            op_name, current_allocated_memory = self.memories_allocated[i]
            op_diff[op_name] = current_allocated_memory - previous_allocated_memory
            previous_allocated_memory = current_allocated_memory

        print("------------------------------------------------")
        print(f"The number of CUDA retries is: {self._num_cuda_retries}")
        print(f"Top {top} ops that allocate the most memory are:")
        for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[
            :top
        ]:
            print(f"{k}: {v:.2f} MB")
        print("------------------------------------------------")

    @no_type_check
    def show_traces(self, path: str = "") -> None:
        """
        Plot the memory traces; if ``path`` is given, load previously saved
        stats from that pickle file first.
        """
        import matplotlib.pyplot as plt

        def _plot_figure(x, y_values, labels):
            min_val = min(list(chain(*y_values))) * 0.999
            max_val = max(list(chain(*y_values))) * 1.001
            plt.figure()
            for y, label in zip(y_values, labels):
                plt.plot(x, y, label=label)
            plt.xlabel("# Operator Calls")
            plt.ylabel("Memory (MB)")
            plt.legend()
            for marker_name, marker in self._markers.items():
                if marker_name == "fw_bw_boundary":
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "r",
                        lw=2,
                        label=marker_name,
                    )
                else:
                    plt.plot(
                        [marker, marker],
                        [min_val, max_val],
                        "k-",
                        lw=2,
                        label=marker_name,
                    )

        if path != "":
            self.load(path)

        y_1 = [mb for (name, mb) in self.memories_allocated.values()]
        y_2 = [mb for (name, mb) in self.memories_active.values()]
        y_3 = [mb for (name, mb) in self.memories_reserved.values()]
        x = list(range(len(y_1)))
        # Split figures when there is a big difference between
        # "reserved_memory" and "allocated_memory" or "active_memory".
        _plot_figure(
            x,
            [list(y_1), list(y_2), list(y_3)],
            ["allocated_memory", "active_memory", "reserved_memory"],
        )
        _plot_figure(x, [list(y_1)], ["allocated_memory"])
        _plot_figure(x, [list(y_2)], ["active_memory"])
        _plot_figure(x, [list(y_3)], ["reserved_memory"])

    def save_stats(self, path: str) -> None:
        """
        Save the stats with pickle at runtime, so that the traces can be
        plotted elsewhere, e.g. in a notebook.
        """
        stats = {
            "memories_allocated": self.memories_allocated,
            "memories_active": self.memories_active,
            "memories_reserved": self.memories_reserved,
            "markers": self._markers,
            "num_alloc_retries": self._num_cuda_retries,
        }

        with open(path, "wb") as f:
            pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

    def load(self, path: str) -> None:
        """
        Load pickled memory stats to plot the traces or print the summary.
        """
        with open(path, "rb") as f:
            stats = pickle.load(f)

        self.memories_allocated = stats["memories_allocated"]
        self.memories_active = stats["memories_active"]
        self.memories_reserved = stats["memories_reserved"]
        self._markers = stats["markers"]
        self._num_cuda_retries = stats["num_alloc_retries"]
        # Restore the operator index so that summary() works after loading.
        self._op_index = len(self.memories_allocated)
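    # Illustrative save/load workflow (a sketch; the path below is
    # hypothetical, not part of the API):
    #
    #   tracker.save_stats("/tmp/mem_stats.pkl")   # in the training job
    #
    #   viewer = MemoryTracker()                   # later, e.g. in a notebook
    #   viewer.show_traces("/tmp/mem_stats.pkl")   # calls load() internally
    #   viewer.summary()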
""" def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: self._cur_module_name = f"{name}.forward" if ( hasattr(module, "_memory_tracker_is_root") and module._memory_tracker_is_root ): self._add_marker("fw_start") return _pre_forward_hook def _create_post_forward_hook(self, name: str) -> Callable: """ The post_forward_hook inserts the marker 'fw_bw_boundary' at the boundary of forward pass and backward pass. """ def _post_forward_hook( module: nn.Module, inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], ) -> None: if ( hasattr(module, "_memory_tracker_is_root") and module._memory_tracker_is_root ): self._add_marker("fw_bw_boundary") return _post_forward_hook def _create_backward_hook(self, name: str) -> Callable: """ The backward_hook inserts the current module name with backward prefix for the operator name. """ def _backward_hook( module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor ) -> None: self._cur_module_name = f"{name}.backward" return _backward_hook @no_type_check def _record_memory_stats(self, fn_name: str) -> None: """ Record current memory allocated, current memory active and current memory reserved. The memory stats dict is indexed with ``self._op_index``. """ memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB memory_active: float = ( torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB ) self.memories_allocated[self._op_index] = (fn_name, memory_allocated) self.memories_reserved[self._op_index] = (fn_name, memory_reserved) self.memories_active[self._op_index] = (fn_name, memory_active) self._op_index += 1 def _add_marker(self, marker_name: str) -> None: """ Set the marker's x-axis value. """ marker_val = len(self.memories_allocated.values()) self._markers[marker_name] = marker_val def _clear_state(self) -> None: """ Clear states when start_monitor() is called. """ self._operator_names.clear() self.memories_allocated.clear() self.memories_active.clear() self.memories_reserved.clear() self._markers.clear() self._cur_module_name = "" self._op_index = 0 self._num_cuda_retries = 0