- from collections import defaultdict
- from itertools import chain
- import pickle
- from typing import (
- Any,
- Callable,
- Dict,
- List,
- no_type_check,
- Sequence,
- Tuple,
- )
- import torch
- import torch.nn as nn
- from torch.utils.hooks import RemovableHandle
- from torch.utils._python_dispatch import TorchDispatchMode
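- # conversion factor from bytes (as reported by the CUDA memory APIs) to megabytes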
- BYTES_PER_MB = 1024 * 1024.0
- class MemoryProfileDispatchMode(TorchDispatchMode):
- """
- Run in ``TorchDispatchMode`` to get memory stats at operator level.
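- The mode is entered by ``MemoryTracker.start_monitor()`` and exited by ``MemoryTracker.stop()``.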
- """
- def __init__(self, memory_tracker) -> None:
- self.memory_tracker = memory_tracker
- def __torch_dispatch__(self, func, types, args=..., kwargs=None):
- rs = func(*args, **(kwargs or {}))  # guard against kwargs being passed as None
- if func == torch.ops.aten.detach.default:
- return rs
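- # build a unique name for this call, "<module>.<forward|backward>.<op>_<count>", where
- # the per-operator count distinguishes repeated invocations of the same operator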
- func_name: str = (
- self.memory_tracker._cur_module_name
- + "."
- + func.__name__
- + "_"
- + str(self.memory_tracker._operator_names[func.__name__])
- )
- self.memory_tracker._operator_names[func.__name__] = (
- self.memory_tracker._operator_names[func.__name__] + 1
- )
- self.memory_tracker._record_memory_stats(func_name)
- return rs
- class MemoryTracker:
- """
- Collect and plot the memory stats, including ``memories_allocated``, ``memories_active``
- and ``memories_reserved``, at operator level.
- It also prints a summary of the top 20 operators that allocate the most memory.
- Example usage:
- >>> # xdoctest: +SKIP(failing)
- >>> net.cuda()
- >>> input = input.cuda()
- >>> mem_tracker = MemoryTracker()
- >>> mem_tracker.start_monitor(net)
- >>> net.zero_grad(True)
- >>> loss = net(input)
- >>> if isinstance(loss, dict):
- ...     loss = loss['out']
- >>> loss.sum().backward()
- >>> net.zero_grad(set_to_none=True)
- >>> mem_tracker.stop()
- >>> mem_tracker.summary()
- >>> mem_tracker.show_traces()
- """
- def __init__(self) -> None:
- torch._C._log_api_usage_once("torch.distributed.memory_tracker")
- self._hooks: List[RemovableHandle] = []
- self._operator_names: Dict[str, int] = defaultdict(int)
- self.memories_allocated: Dict[int, Tuple[str, float]] = defaultdict()
- self.memories_active: Dict[int, Tuple[str, float]] = defaultdict()
- self.memories_reserved: Dict[int, Tuple[str, float]] = defaultdict()
- self._markers: Dict[str, int] = defaultdict(int)
- self._cur_module_name: str = ""
- self._op_index: int = 0
- self._num_cuda_retries: int = 0
- @no_type_check
- def start_monitor(self, root_module: nn.Module) -> None:
- """
- Register module hooks and enter ``MemoryProfileDispatchMode`` so that
- operator-level memory stats can be tracked during module runtime.
- """
- self._clear_state()
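- # mark the root module so that the forward hooks insert the "fw_start" and
- # "fw_bw_boundary" markers only at the root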
- root_module.__setattr__("_memory_tracker_is_root", True)
- for name, m in root_module.named_modules():
- if m is not root_module:
- m.__setattr__("_memory_tracker_is_root", False)
- # fused_proxy_group does not support hooks
- if ".fused_proxy_grouped_embedding_bag" in name:
- continue
- # hook ordering relative to other hooks added by users is not managed, so
- # the memory stats tracked here may not be completely accurate.
- h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name))
- h2 = m.register_forward_hook(self._create_post_forward_hook(name))
- # the backward hook does not work well with jagged tensors and the root cause is
- # not clear; remove it for now as it does not capture much important info.
- # h3 = m.register_backward_hook(self._create_backward_hook(name))
- self._hooks.extend([h1, h2])
- torch.cuda.empty_cache()
- assert getattr(self, "profile_mode", None) is None
- self.profile_mode = MemoryProfileDispatchMode(self)
- self.profile_mode.__enter__()
- @no_type_check
- def stop(self) -> None:
- """
- Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop
- tracking memory stats at operator level.
- Also collect some aggregated stats while the memory tracker is enabled,
- such as the cuda ``num_alloc_retries``.
- """
- self._num_cuda_retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
- for h in self._hooks:
- h.remove()
- self._hooks.clear()
- assert getattr(self, "profile_mode", None) is not None
- self.profile_mode.__exit__(None, None, None)
- self.profile_mode = None
- @no_type_check
- def summary(self, top: int = 20) -> None:
- """
- Print the top operators that allocate the most memory. The number of
- operators printed is configurable via ``top``.
- """
- op_diff: Dict[str, float] = defaultdict(float)
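- # attribute to each operator call the change in allocated memory relative to the
- # previously recorded call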
- _, previous_allocated_memory = self.memories_allocated[0]
- for i in range(1, self._op_index):
- op_name, current_allocated_memory = self.memories_allocated[i]
- op_diff[op_name] = current_allocated_memory - previous_allocated_memory
- previous_allocated_memory = current_allocated_memory
- print("------------------------------------------------")
- print(f"The number of cuda retries are: {self._num_cuda_retries}")
- print(f"Top {top} ops that generates memory are:")
- for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[
- :top
- ]:
- print(f"{k}: {v}MB")
- print("------------------------------------------------")
- @no_type_check
- def show_traces(self, path: str = "") -> None:
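- """
- Plot the traces of ``memories_allocated``, ``memories_active`` and ``memories_reserved``,
- first combined in one figure and then individually.
- Recorded markers such as "fw_start" and "fw_bw_boundary" are drawn as vertical lines.
- If ``path`` is non-empty, previously saved stats are loaded from it before plotting.
- """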
- import matplotlib.pyplot as plt
- def _plot_figure(x, y_values, labels):
- min_val = min(list(chain(*y_values))) * 0.999
- max_val = max(list(chain(*y_values))) * 1.001
- plt.figure()
- for y, label in zip(y_values, labels):
- plt.plot(x, y, label=label)
- plt.xlabel("# Operator Calls")
- plt.ylabel("Memory (MB)")
- plt.legend()
- for marker_name, marker in self._markers.items():
- if marker_name == "fw_bw_boundary":
- plt.plot(
- [marker, marker],
- [min_val, max_val],
- "r",
- lw=2,
- label=marker_name,
- )
- else:
- plt.plot(
- [marker, marker],
- [min_val, max_val],
- "k-",
- lw=2,
- label=marker_name,
- )
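- # if a path is given, plot stats previously saved with ``save_stats`` instead of
- # the ones collected in memory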
- if path != "":
- self.load(path)
- y_1 = [mb for (name, mb) in self.memories_allocated.values()]
- y_2 = [mb for (name, mb) in self.memories_active.values()]
- y_3 = [mb for (name, mb) in self.memories_reserved.values()]
- x = list(range(len(y_1)))
- # Split into separate figures as well, since there can be a big difference between
- # "reserved_memory" and "allocated_memory" or "active_memory".
- _plot_figure(
- x,
- [list(y_1), list(y_2), list(y_3)],
- ["allocated_memory", "active_memory", "reserved_memory"],
- )
- _plot_figure(x, [list(y_1)], ["allocated_memory"])
- _plot_figure(x, [list(y_2)], ["active_memory"])
- _plot_figure(x, [list(y_3)], ["reserved_memory"])
- def save_stats(self, path: str) -> None:
- """
- Save the stats with pickle during runtime so that users can plot the traces
- elsewhere, e.g. in a notebook.
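- Example usage (illustrative; the file name below is an arbitrary choice):
- >>> # xdoctest: +SKIP(illustrative)
- >>> mem_tracker.save_stats("memory_stats.pickle")
- >>> # later, e.g. in a notebook:
- >>> new_tracker = MemoryTracker()
- >>> new_tracker.load("memory_stats.pickle")
- >>> new_tracker.show_traces()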
- """
- stats = {
- "memories_allocated": self.memories_allocated,
- "memories_active": self.memories_active,
- "memories_reserved": self.memories_reserved,
- "markers": self._markers,
- "num_alloc_retries": self._num_cuda_retries,
- }
- with open(path, "wb") as f:
- pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)
- def load(self, path: str) -> None:
- """
- Load the pickled memory stats to plot the traces or print the summary.
- """
- with open(path, "rb") as f:
- stats = pickle.load(f)
- self.memories_allocated = stats["memories_allocated"]
- self.memories_active = stats["memories_active"]
- self.memories_reserved = stats["memories_reserved"]
- self._markers = stats["markers"]
- self._num_cuda_retries = stats["num_alloc_retries"]
- def _create_pre_forward_hook(self, name: str) -> Callable:
- """
- The pre_forward_hook sets the current module name (suffixed with ".forward") used to
- prefix operator names, and inserts the marker "fw_start" when the forward pass begins.
- """
- def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
- self._cur_module_name = f"{name}.forward"
- if (
- hasattr(module, "_memory_tracker_is_root")
- and module._memory_tracker_is_root
- ):
- self._add_marker("fw_start")
- return _pre_forward_hook
- def _create_post_forward_hook(self, name: str) -> Callable:
- """
- The post_forward_hook inserts the marker "fw_bw_boundary" at the boundary
- between the forward pass and the backward pass.
- """
- def _post_forward_hook(
- module: nn.Module,
- inputs: Sequence[torch.Tensor],
- outputs: Sequence[torch.Tensor],
- ) -> None:
- if (
- hasattr(module, "_memory_tracker_is_root")
- and module._memory_tracker_is_root
- ):
- self._add_marker("fw_bw_boundary")
- return _post_forward_hook
- def _create_backward_hook(self, name: str) -> Callable:
- """
- The backward_hook sets the current module name (suffixed with ".backward") used to prefix operator names.
- """
- def _backward_hook(
- module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor
- ) -> None:
- self._cur_module_name = f"{name}.backward"
- return _backward_hook
- @no_type_check
- def _record_memory_stats(self, fn_name: str) -> None:
- """
- Record the currently allocated, active and reserved memory.
- The memory stats dicts are indexed with ``self._op_index``.
- """
- memory_allocated: float = torch.cuda.memory_allocated() / BYTES_PER_MB
- memory_reserved: float = torch.cuda.memory_reserved() / BYTES_PER_MB
- memory_active: float = (
- torch.cuda.memory_stats().get("active_bytes.all.current", 0) / BYTES_PER_MB
- )
- self.memories_allocated[self._op_index] = (fn_name, memory_allocated)
- self.memories_reserved[self._op_index] = (fn_name, memory_reserved)
- self.memories_active[self._op_index] = (fn_name, memory_active)
- self._op_index += 1
- def _add_marker(self, marker_name: str) -> None:
- """
- Set the marker's x-axis value to the number of operator calls recorded so far.
- """
- marker_val = len(self.memories_allocated.values())
- self._markers[marker_name] = marker_val
- def _clear_state(self) -> None:
- """
- Clear states when start_monitor() is called.
- """
- self._operator_names.clear()
- self.memories_allocated.clear()
- self.memories_active.clear()
- self.memories_reserved.clear()
- self._markers.clear()
- self._cur_module_name = ""
- self._op_index = 0
- self._num_cuda_retries = 0