_memory_profiler.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
  1. import collections
  2. import dataclasses
  3. import enum
  4. import itertools as it
  5. import logging
  6. from typing import (
  7. Any,
  8. cast,
  9. DefaultDict,
  10. Dict,
  11. Iterator,
  12. List,
  13. Optional,
  14. Set,
  15. Tuple,
  16. Union,
  17. )
  18. import torch
  19. from torch._C import FunctionSchema
  20. from torch._C._autograd import _ProfilerResult
  21. from torch._C._profiler import (
  22. _EventType,
  23. _ExtraFields_Allocation,
  24. _ExtraFields_TorchOp,
  25. _ProfilerEvent,
  26. _TensorMetadata,
  27. RecordScope,
  28. )
  29. from torch._utils import _element_size
  30. from torch.profiler import _utils
  31. TensorAndID = Tuple["TensorKey", int]
  32. log = logging.getLogger(__name__)
  33. class Category(enum.Enum):
  34. INPUT = enum.auto()
  35. TEMPORARY = enum.auto()
  36. ACTIVATION = enum.auto()
  37. GRADIENT = enum.auto()
  38. AUTOGRAD_DETAIL = enum.auto()
  39. PARAMETER = enum.auto()
  40. OPTIMIZER_STATE = enum.auto()
  41. class Action(enum.Enum):
  42. PREEXISTING = enum.auto()
  43. CREATE = enum.auto()
  44. INCREMENT_VERSION = enum.auto()
  45. DESTROY = enum.auto()
  46. @dataclasses.dataclass
  47. class _Storage:
  48. """Bundle storage pointer and id.
  49. All profiling logic should use `allocation_id`, however it is useful to
  50. print storage pointers for debugging and unit tests sometimes look up
  51. values using the storage data pointer of a live Tensor."""
  52. ptr: int
  53. allocation_id: int
  54. def __repr__(self) -> str:
  55. return f"{hex(self.ptr):>18} ({self.allocation_id})"
  56. def __eq__(self, other: Any) -> bool:
  57. return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
  58. def __hash__(self) -> int:
  59. return hash(self.allocation_id)
  60. @dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
  61. class TensorKey:
  62. """Hashable identifier for a storage which has been asigned an ID.
  63. A detailed description of Tensor IDs and why they are needed is given in
  64. `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
  65. summarize, multiple Storage buffers can map to the same logical Tensor.
  66. This dataclass is used to refer to a concrete in-memory StorageImpl of
  67. a Tensor.
  68. """
  69. id: int
  70. storage: _Storage
  71. device: torch.device
  72. def __repr__(self) -> str:
  73. return f"id={self.id}: {repr(self.storage):<24} ({self.device})"
  74. def __lt__(self, other: "TensorKey") -> bool:
  75. return self._as_sortable < other._as_sortable
  76. @staticmethod
  77. def _make(
  78. tensor_id: Optional[int],
  79. storage_ptr: Optional[int],
  80. allocation_id: Optional[int],
  81. device: torch.device,
  82. ) -> Optional["TensorKey"]:
  83. if (
  84. tensor_id is not None
  85. and storage_ptr is not None
  86. and allocation_id is not None
  87. ):
  88. return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device)
  89. return None
  90. @classmethod
  91. def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
  92. return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)
  93. @classmethod
  94. def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
  95. if t is not None:
  96. return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
  97. return None
  98. @property
  99. def _as_sortable(self) -> Tuple[int, int, str, int]:
  100. return self.id, self.storage.allocation_id, self.device.type, self.device.index
  101. def _extract_parameters_and_gradients(
  102. node: _ProfilerEvent,
  103. ) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
  104. children = node.children
  105. # AccumulateGrad is used in the Autograd engine to handle gradient updates.
  106. # There are two possible cases:
  107. # 1) This is a newly created gradient Tensor. In that case there is nothing
  108. # to accumulate, so autograd simply detaches the Tensor.
  109. #
  110. # 2) There is a preexisting gradient Tensor and we need to add the newly
  111. # computed update. This is done with an in-place add (aten::add_) op.
  112. # (The underscore suffix denotes "in-place".)
  113. if (
  114. node.typed[0] == _EventType.TorchOp
  115. and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
  116. # TODO(robieta): Move away from load bearing names
  117. and node.name == "torch::autograd::AccumulateGrad"
  118. and children
  119. and children[0].typed[0] == _EventType.TorchOp
  120. and children[0].name in ("aten::detach", "aten::add_")
  121. and children[0].typed[1].inputs
  122. and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
  123. ):
  124. yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0])
  125. # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
  126. # NOTE: The values captured by the python tracer are cached; they can be
  127. # used to build up labels but do not imply that a Tensor was live at
  128. # a particular time.
  129. elif node.typed[0] == _EventType.PyCall:
  130. typed_fields = node.typed[1]
  131. assert typed_fields.module is None or typed_fields.optimizer is None
  132. if typed_fields.module is not None:
  133. for _, p, p_grad in typed_fields.module.parameters:
  134. yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)
  135. if typed_fields.optimizer is not None:
  136. for p, p_grad, _ in typed_fields.optimizer.parameters:
  137. yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)
  138. def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
  139. for p, p_grad in _extract_parameters_and_gradients(node):
  140. if p is not None:
  141. yield p
  142. def extract_gradients(
  143. node: _ProfilerEvent,
  144. ) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
  145. for p, p_grad in _extract_parameters_and_gradients(node):
  146. if p_grad is not None:
  147. yield p, p_grad
  148. def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
  149. scopes = []
  150. while event:
  151. if event.typed[0] == _EventType.TorchOp:
  152. scopes.append(event.typed[1].scope)
  153. event = event.parent
  154. return tuple(scopes)
  155. class SchemaMatcher:
  156. """Lookup operator schema based on profiled name.
  157. When profiling we record the operator's name but not the schema. However
  158. some analysis requires that information. Fortunately we can look up
  159. registered schema from the recorded name. We do not, however, record the
  160. overload and so we must compare the profiled arguments with all overloads
  161. to determine viable matches.
  162. Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
  163. this code will be obsolete.
  164. """
  165. @classmethod
  166. def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
  167. """Determine which inputs may have mutated based on function schema.
  168. Note that we don't need to resolve down to a single schema to perform
  169. this analysis. An input is mutable if it is mutable in any overload. In
  170. practice, however, it is overwhelmingly common to match a single
  171. overload. If we cannot find any valid schema then we must be
  172. conservative and assume all inputs are mutable.
  173. """
  174. mutable: Optional[List[bool]] = None
  175. for schema in cls.match_schemas(t):
  176. mutable = mutable or [False for _ in schema.arguments]
  177. for i, arg in enumerate(schema.arguments):
  178. mutable[i] |= getattr(arg.alias_info, "is_write", False)
  179. return tuple(mutable or (None for _ in t.inputs))
  180. @classmethod
  181. def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
  182. signature = tuple(
  183. # Tensor
  184. TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
  185. #
  186. # TensorList
  187. else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
  188. #
  189. # Scalar and uncaptured inputs.
  190. else i
  191. for i in t.inputs
  192. )
  193. def matches(schema) -> bool:
  194. return len(schema.arguments) == len(signature) and all(
  195. cls._types_match(observed, schema_arg.type)
  196. for observed, schema_arg in zip(signature, schema.arguments)
  197. )
  198. return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))
  199. @classmethod
  200. def _types_match(cls, observed, schema_type) -> bool:
  201. if isinstance(schema_type, torch._C.OptionalType):
  202. schema_type = schema_type.getElementType()
  203. return observed is None or cls._types_match(observed, schema_type)
  204. if isinstance(schema_type, torch._C.AnyType):
  205. return True
  206. if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
  207. return isinstance(observed, list) and all(
  208. isinstance(i, TensorKey) for i in observed
  209. )
  210. type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
  211. (torch._C.TensorType, TensorKey),
  212. (torch._C.NoneType, type(None)),
  213. (torch._C.BoolType, bool),
  214. (torch._C.IntType, int),
  215. (torch._C.FloatType, float),
  216. (torch._C.ComplexType, complex),
  217. (torch._C.NumberType, (bool, int, float, complex)),
  218. )
  219. for jit_type, py_types in type_map:
  220. if isinstance(schema_type, jit_type):
  221. return isinstance(observed, py_types)
  222. # Profiler only records a subset of possible argument types. If we
  223. # reach this point then the schema must call for a type that profiler
  224. # does not record. Thus, the schema can only be a match if `observed`
  225. # is also None.
  226. return observed is None
  227. @staticmethod
  228. def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
  229. # TODO(robieta):
  230. # _jit_get_schemas_for_operator is quite expensive. (~100us / call)
  231. # Consider adding `functools.lru_cache` if that becomes an issue.
  232. try:
  233. # Schema lookup will throw if `name` is malformed. (For example,
  234. # schemas must be namespaced and schema lookup will fail if name
  235. # does not include "::".) We simply catch the exception and return
  236. # `None` to denote that `name` cannot be an operator name.
  237. #
  238. # Note that record_function annotations also go through this path,
  239. # so it is expected that some names will not correspond to PyTorch
  240. # operators.
  241. return tuple(torch._C._jit_get_schemas_for_operator(name))
  242. except RuntimeError:
  243. return None
  244. class OpTree:
  245. def __init__(self, result: _ProfilerResult) -> None:
  246. self._root_nodes = result.experimental_event_tree()
  247. self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns))
  248. def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
  249. yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs)
  250. @property
  251. def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
  252. return self._sorted_nodes
  253. class SizeMap:
  254. def __init__(self, op_tree: OpTree) -> None:
  255. self._values: Dict[TensorKey, int] = {}
  256. for node in op_tree.sorted_nodes:
  257. if node.typed[0] == _EventType.TorchOp:
  258. for t in self._flat_tensor_inputs(node.typed[1]):
  259. self._update_values(t)
  260. elif node.typed[0] == _EventType.PyCall:
  261. typed_fields = node.typed[1]
  262. assert typed_fields.module is None or typed_fields.optimizer is None
  263. if typed_fields.module is not None:
  264. for _, p, p_grad in typed_fields.module.parameters:
  265. self._update_values(p)
  266. self._update_values(p_grad)
  267. if typed_fields.optimizer is not None:
  268. for p, p_grad, state in typed_fields.optimizer.parameters:
  269. self._update_values(p)
  270. self._update_values(p_grad)
  271. for _, t in state:
  272. self._update_values(t)
  273. allocations: Dict[TensorKey, int] = {}
  274. for node in op_tree.sorted_nodes:
  275. if node.typed[0] == _EventType.Allocation:
  276. alloc_fields = node.typed[1]
  277. key = TensorKey.from_allocation(alloc_fields)
  278. if key:
  279. new_size = abs(alloc_fields.alloc_size)
  280. prior_size = allocations.setdefault(key, new_size)
  281. # It is possible to resize Storage in PyTorch, however we
  282. # key on data pointer so most resizes will be treated as a
  283. # change in storage. The one corner case that cannot be
  284. # handled is `realloc` which successfully resizes the
  285. # storage. At time of writing this is not done anywhere in
  286. # the core PyTorch codebase.
  287. if prior_size != new_size:
  288. delta = f"{prior_size} vs. {new_size}"
  289. log.warning(f"Mismatch between allocation and free: {delta}")
  290. self._values.update(allocations)
  291. def _update_values(self, t: Optional[_TensorMetadata]) -> None:
  292. key = TensorKey.from_tensor(t)
  293. if key is not None and t is not None and t.layout == torch.strided:
  294. # Scalars are represented as zero dim Tensors
  295. n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))
  296. num_bytes = n * _element_size(t.dtype)
  297. assert num_bytes >= 0, f"{num_bytes}"
  298. self._values[key] = max(self._values.get(key, 0), num_bytes)
  299. @staticmethod
  300. def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
  301. for i in op.inputs:
  302. if isinstance(i, _TensorMetadata):
  303. yield i
  304. elif isinstance(i, list):
  305. for t in i:
  306. yield t
  307. def __getitem__(self, key: TensorKey):
  308. return self._values[key]
  309. @dataclasses.dataclass()
  310. class DataFlowEdge:
  311. input_version: Optional[int] = None
  312. mutated: Optional[bool] = False
  313. @property
  314. def is_allocation(self) -> bool:
  315. return self.input_version is None
  316. @property
  317. def is_deletion(self) -> bool:
  318. return self.mutated is None
  319. class DataFlowNode:
  320. def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
  321. self._event = event
  322. self._graph = graph
  323. self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()
  324. for key, edge in self._edges.items():
  325. if edge.mutated and not edge.is_allocation:
  326. self._graph.bump(key)
  327. # Make sure the version bumping behavior matches what we expect.
  328. versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
  329. assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"
  330. def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
  331. subtree = tuple(_utils.traverse_dfs([self._event]))
  332. # Start by populating edges from op inputs and outputs.
  333. mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
  334. for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
  335. for op_input, mutable in zip(
  336. op.inputs, SchemaMatcher.inputs_are_mutable(op)
  337. ):
  338. # Tensor
  339. if isinstance(op_input, _TensorMetadata):
  340. key = TensorKey.from_tensor(op_input)
  341. mutable_by_key.setdefault(key, set()).add(mutable)
  342. # TensorList
  343. elif isinstance(op_input, list):
  344. for op_input_i in op_input:
  345. key = TensorKey.from_tensor(op_input_i)
  346. mutable_by_key.setdefault(key, set()).add(mutable)
  347. edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
  348. edges = collections.defaultdict(DataFlowEdge)
  349. for key, mutable_set in mutable_by_key.items():
  350. if key is not None:
  351. edges[key].input_version = self._graph.lookup(key) if key else -1
  352. # We consider an op to be mutated if we encounter a schema where it
  353. # is a mutable argument OR if it is ambiguous. (We never explicitly
  354. # see it in any schema.)
  355. mutated = (True in mutable_set) or (tuple(mutable_set) == (None,))
  356. edges[key].mutated = mutated
  357. # Then handle deletions. Note that deleting a Tensor implicitly adds
  358. # it as an input edge.
  359. for i in subtree:
  360. if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0:
  361. key = TensorKey.from_allocation(i.typed[1])
  362. edge = edges[key]
  363. assert key is None or edge.mutated is not None, f"Double delete: {key}"
  364. edge.mutated = None
  365. edge.input_version = self._graph.lookup(key) if key else -1
  366. # And finally handle allocations. This step must be last, because the
  367. # previous two steps optimistically add input edges.
  368. for i in subtree:
  369. if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0:
  370. edges[TensorKey.from_allocation(i.typed[1])].input_version = None
  371. # We don't need to sort the inputs, but it makes debugging and unit tests nicer.
  372. return dict(sorted((k, v) for k, v in edges.items() if k is not None))
  373. @property
  374. def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]:
  375. return {
  376. # MyPy can't see through `is_allocation` to know that
  377. # `v.input_version` is not None.
  378. k: (bool(v.mutated), cast(int, v.input_version))
  379. for k, v in self._edges.items()
  380. if not v.is_allocation
  381. }
  382. @property
  383. def outputs(self) -> Dict[TensorKey, int]:
  384. return {
  385. k: 0 if v.input_version is None else v.input_version + 1
  386. for k, v in self._edges.items()
  387. if (v.is_allocation and not v.is_deletion) or v.mutated
  388. }
  389. @property
  390. def intermediates(self) -> Tuple[TensorKey, ...]:
  391. return tuple(
  392. k for k, v in self._edges.items() if v.is_allocation and v.is_deletion
  393. )
  394. @property
  395. def start_time(self) -> int:
  396. return self._event.start_time_ns
  397. class DataFlowGraph:
  398. def __init__(self, op_tree: OpTree) -> None:
  399. self._op_tree = op_tree
  400. self._leaf_events = self._extract_leaf_events(op_tree)
  401. self._active_version: Dict[TensorKey, Optional[int]] = {}
  402. self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
  403. self._flow_nodes.sort(key=lambda x: x.start_time)
  404. self.validate()
  405. @property
  406. def flow_nodes(self) -> Tuple[DataFlowNode, ...]:
  407. return tuple(self._flow_nodes)
  408. def validate(self):
  409. # Check that each (Tensor, version) pair has a unique creation node
  410. outputs: Set[Tuple[TensorKey, int]] = set()
  411. for node in self.flow_nodes:
  412. node_outputs = set(node.outputs.items())
  413. duplicates = outputs & node_outputs
  414. assert not duplicates, f"{node._event.name} {node._edges} {duplicates}"
  415. outputs |= node_outputs
  416. # And check that `self._nodes` forms a valid topologically sorted DAG.
  417. tensor_versions: Dict[TensorKey, int] = {}
  418. for node in self.flow_nodes:
  419. for key, (_, version) in node.inputs.items():
  420. expected = tensor_versions.get(key, 0)
  421. assert expected == version, (expected, version)
  422. for key, version in node.outputs.items():
  423. prior_version = tensor_versions.get(key, version)
  424. assert version >= prior_version, (version, prior_version)
  425. tensor_versions[key] = version
  426. @property
  427. def leaf_events(self) -> Tuple[_ProfilerEvent, ...]:
  428. return self._leaf_events
  429. @staticmethod
  430. def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]:
  431. """Partially traverse the op tree and extract top level ops.
  432. Consider the following code:
  433. ```
  434. with record_function("My annotation"):
  435. x.zero_()
  436. y.zero_()
  437. ```
  438. The op tree (assuming no Autograd) will look like:
  439. <Python context>
  440. TorchOp: "My annotation"
  441. TorchOp: zero_
  442. TorchOp: fill_
  443. TorchOp: zero_
  444. TorchOp: fill_
  445. The recursive structure of operator calls makes data flow unwieldy.
  446. In order to simplify analysis we would like to select the highest level
  447. ops to represent in the graph. In this case those are the `zero_` ops;
  448. the fact that `fill_` is called is an implementation detail. We also
  449. do not want to group everything under "My annotation" as this could
  450. create overly coarse bundles and lose critical semantics.
  451. To address this issue we walk over the graph and select the topmost
  452. torch ops ** which match at least one operator schema **. These form
  453. the leaves of the first pass through the op tree. (As well as any
  454. allocations or frees which do are not part of a kernel.) These events
  455. form the logical nodes in our data flow graph.
  456. """
  457. leaf_events: List[_ProfilerEvent] = []
  458. def leaf_op(e: _ProfilerEvent) -> bool:
  459. return e.typed[0] == _EventType.TorchOp and (
  460. e.typed[1].scope == RecordScope.BACKWARD_FUNCTION
  461. or bool(SchemaMatcher.match_schemas(e.typed[1]))
  462. )
  463. def children_fn(e: _ProfilerEvent):
  464. if leaf_op(e) or e.tag == _EventType.Allocation:
  465. leaf_events.append(e)
  466. return []
  467. return e.children
  468. for _ in op_tree.dfs(children_fn=children_fn):
  469. pass
  470. return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns))
  471. def lookup(self, key: TensorKey) -> int:
  472. version = self._active_version.setdefault(key, 0)
  473. assert version is not None
  474. return version
  475. def bump(self, key: TensorKey) -> None:
  476. prior_version = self._active_version.get(key, None)
  477. assert prior_version is not None
  478. self._active_version[key] = prior_version + 1
  479. def delete(self, key: TensorKey) -> None:
  480. assert self._active_version.setdefault(key, 0) is not None
  481. self._active_version[key] = None
  482. @dataclasses.dataclass
  483. class CategoryElement:
  484. by_id: Optional[Category] = None
  485. by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
  486. by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)
  487. # Used by unit tests to check internals. (And consequently by
  488. # MemoryProfile.lookup) This should not be used in any other capacity.
  489. _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)
  490. @dataclasses.dataclass
  491. class CategoryDict:
  492. _values: DefaultDict[int, CategoryElement] = dataclasses.field(
  493. default_factory=lambda: collections.defaultdict(CategoryElement)
  494. )
  495. def set_by_id(self, key: TensorKey, category: Category) -> None:
  496. self._values[key.id].by_id = category
  497. self._values[key.id]._by_id_keyset.add(key)
  498. def set_by_key(self, key: TensorKey, category: Category) -> None:
  499. self._values[key.id].by_key[key] = category
  500. def set_by_version(self, key: TensorKey, version: int, category: Category) -> None:
  501. self._values[key.id].by_version[(key, version)] = category
  502. def setdefault_by_version(
  503. self, key: TensorKey, version: int, category: Category
  504. ) -> None:
  505. self._values[key.id].by_version.setdefault((key, version), category)
  506. def get(self, key: TensorKey, version: int) -> Optional[Category]:
  507. element = self._values[key.id]
  508. return (
  509. element.by_id
  510. or element.by_key.get(key, None)
  511. or element.by_version.get((key, version), None)
  512. )
  513. class MemoryProfile:
  514. def __init__(self, result: _ProfilerResult) -> None:
  515. self._op_tree = OpTree(result)
  516. self._data_flow_graph = DataFlowGraph(self._op_tree)
  517. self._size_map = SizeMap(self._op_tree)
  518. self._categories = CategoryDict()
  519. self._set_gradients_and_temporaries()
  520. self._set_parameters_using_python_tracer()
  521. self._set_inputs()
  522. self._set_parameters_using_data_flow()
  523. self._set_activations()
  524. self._set_optimizer_state()
  525. self._set_autograd_detail()
  526. @property
  527. def timeline(self) -> Tuple[Tuple[int, Action, TensorAndID, int], ...]:
  528. t0 = min(event.start_time_ns for event in self._op_tree.dfs())
  529. allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
  530. for event in self._op_tree.dfs():
  531. if event.typed[0] == _EventType.Allocation:
  532. alloc_fields = event.typed[1]
  533. key = TensorKey.from_allocation(alloc_fields)
  534. if key is not None:
  535. is_allocation = alloc_fields.alloc_size > 0
  536. allocation_times[(key, is_allocation)] = event.start_time_ns - t0
  537. snapshot = self._category_snapshot()
  538. last_version = {key: version for key, version in sorted(snapshot.keys())}
  539. events: List[Tuple[int, Action, TensorAndID]] = [
  540. (-1, Action.PREEXISTING, (key, version))
  541. for key, version in snapshot.keys()
  542. if (key, True) not in allocation_times and version == 0
  543. ]
  544. for node in self._data_flow_graph.flow_nodes:
  545. for key, edge in node._edges.items():
  546. if edge.is_allocation:
  547. t = allocation_times[(key, True)]
  548. events.append((t, Action.CREATE, (key, 0)))
  549. elif edge.mutated:
  550. t = node._event.start_time_ns - t0
  551. version = edge.input_version
  552. assert version is not None
  553. events.append((t, Action.INCREMENT_VERSION, (key, version)))
  554. if edge.is_deletion:
  555. t = allocation_times[(key, False)]
  556. events.append((t, Action.DESTROY, (key, last_version[key])))
  557. events.sort(key=lambda x: (x[0], x[1].value))
  558. return tuple(
  559. (time, action, (key, version), self._size_map[key])
  560. for time, action, (key, version) in events
  561. )
  562. def _is_gradient(self, *args, **kwargs) -> bool:
  563. return self._categories.get(*args, **kwargs) == Category.GRADIENT
  564. def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
  565. all_tensor_versions: Set[TensorAndID] = set()
  566. for node in self._data_flow_graph.flow_nodes:
  567. all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
  568. all_tensor_versions.update(((key, 0) for key in node.intermediates))
  569. all_tensor_versions.update(node.outputs.items())
  570. for i in self._categories._values.values():
  571. all_tensor_versions.update(((key, 0) for key in i._by_id_keyset))
  572. return {
  573. (key, version): self._categories.get(key, version)
  574. for key, version in sorted(all_tensor_versions)
  575. }
  576. def _any_version_depends_on_gradient(self) -> Set[int]:
  577. """Extract IDs of Tensors which depend or will depend on a gradient.
  578. Note that this weakened definition of "depends" requires us to loop
  579. over the data flow graph multiple times because it allows dependency
  580. information to flow backward through edges and removes the guarantee
  581. that nodes are topologically sorted. (Or indeed, even that a valid
  582. topological order exists.) Put another way, we have converted an
  583. acyclic data flow graph into a cyclic graph and we are attempting to
  584. partition cycles involving a gradient from the rest of the graph.
  585. """
  586. depends_on_gradient: Set[int] = set()
  587. while True:
  588. start_size = len(depends_on_gradient)
  589. for node in self._data_flow_graph.flow_nodes:
  590. ids = tuple(
  591. key.id
  592. for key, (_, version) in node.inputs.items()
  593. if self._categories.get(key, version)
  594. in (Category.GRADIENT, Category.PARAMETER)
  595. or key.id in depends_on_gradient
  596. )
  597. if ids:
  598. depends_on_gradient.update(ids)
  599. depends_on_gradient.update(key.id for key in node.outputs)
  600. # We are guaranteed to exit because there is a finite set of
  601. # TensorAndID pairs. In practice we do not expect to loop more than
  602. # three times: once to identify the core parameter update loop,
  603. # once to fold the first step into that loop, and a third time
  604. # where no new elements are added.
  605. if len(depends_on_gradient) == start_size:
  606. return depends_on_gradient
  607. def _set_gradients_and_temporaries(self) -> None:
  608. """Mark Tensors which are unambiguous and simple to reason about."""
  609. # Gradients are straightforward to detect. We directly check the
  610. # `.grad` property in the Python tracer, and we can detect any new
  611. # gradient Tensors from `AccumulateGrad` ops.
  612. for event in self._op_tree.dfs():
  613. for _, p_grad in extract_gradients(event):
  614. self._categories.set_by_id(p_grad, Category.GRADIENT)
  615. # Similarly, temporary Tensors are easy to identify and are useful to
  616. # flag since they can make memory use "spikier" than one would
  617. # otherwise expect.
  618. for node in self._data_flow_graph.flow_nodes:
  619. for i in node.intermediates:
  620. self._categories.set_by_key(i, Category.TEMPORARY)
  621. def _set_parameters_using_python_tracer(self) -> None:
  622. for event in self._op_tree.dfs():
  623. for p in extract_parameters(event):
  624. if p is not None:
  625. self._categories.set_by_id(p, Category.PARAMETER)
  def _set_inputs(self) -> None:
      """Mark inputs based on which Tensors are updated using gradients.

      Telling inputs apart from activations takes more work than the other
      categories. In a training loop most Tensors depend on at least one
      gradient: parameters depend on them through their updates, and
      activations and optimizer state depend on them transitively through
      the parameters. Crucially, this approach does not need to know which
      Tensors are parameters. Walking the data flow graph yields the set of
      values that depend on a gradient, and the inputs are drawn from the
      complement of that set.

      The catch is that a parameter is usually first seen in the forward pass
      of the first step. The data flow graph shows that v1 of such a Tensor
      depends on a gradient (provided an optimizer step was profiled), but v0
      does not. To handle this, "depends on a gradient" is relaxed to "any
      version of this Tensor depends on a gradient". That in turn tightens
      the input criteria enough to exclude the activations in the forward
      pass of the first step.
      """
      # This analysis relies on at least one training step (or parameters
      # from the python tracer) to partition the graph. Without one, there is
      # no way to tell which Tensors are inputs and which belong to the model.
      depends_on_gradient = self._any_version_depends_on_gradient()

      # Only annotate Tensors that actually contribute to the model
      # computation. Sweep the graph back-to-front and collect every
      # (key, version) that appears on a node together with a gradient, a
      # parameter, or something already collected.
      gradient_like = (Category.GRADIENT, Category.PARAMETER)
      produces_gradient: Set[TensorAndID] = set()
      for node in reversed(self._data_flow_graph.flow_nodes):
          touched = {(key, version) for key, (_, version) in node.inputs.items()}
          touched.update(node.outputs.items())
          linked = any(
              self._categories.get(*pair) in gradient_like
              or pair in produces_gradient
              for pair in touched
          )
          if linked:
              produces_gradient.update(touched)

      # Tensors created in the backward pass are generally Autograd
      # implementation details rather than proper inputs, so drop them.
      input_candidates = set(produces_gradient)
      for node in self._data_flow_graph.flow_nodes:
          if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
              input_candidates.difference_update(node.outputs.items())

      for key, version in input_candidates:
          if key.id in depends_on_gradient:
              continue
          self._categories.setdefault_by_version(key, version, Category.INPUT)
  def _set_parameters_using_data_flow(self) -> None:
      """Deduce which Tensors are parameters from the data flow graph.

      Consider one step of SGD with momentum (nesterov=False), where `d_p` is
      the gradient of `param` and `buf` is the momentum buffer:
      ```
      buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
      d_p = buf
      param.add_(d_p, alpha=-lr)
      ```
      Both `param` and `buf` consume a gradient and are updated in place.
      The python tracer inspects calls to `nn.Module.forward` and
      `optim.Optimizer.step` to extract parameters and optimizer state
      respectively (the latter including parameters), so this is usually a
      non-issue. As a fallback, this method relies on properties that
      distinguish parameters from other model state:

      1. Parameters are used directly in the forward pass. The forward pass
         has not been identified at this point, but enough of it can be
         deduced. Some mutable state, such as batch norm moving averages,
         also contributes to the forward pass; optimizer state does not.
      2. A parameter is, by definition, used to compute at least one
         gradient and depends on at least one gradient.
      """
      snapshot = self._category_snapshot()

      # Candidate parameters come from forward-pass data flow. Nodes known to
      # be in the backward pass are filtered out. That does not guarantee the
      # remaining nodes belong to the forward pass, so these are only
      # candidates.
      fwd_tensors: Set[TensorAndID] = {
          pair for pair, category in snapshot.items() if category == Category.INPUT
      }
      candidates: Set[TensorAndID] = set()
      for node in self._data_flow_graph.flow_nodes:
          consumed = {(key, version) for key, (_, version) in node.inputs.items()}
          # Skip nodes in the backward pass.
          if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
              continue
          if any(self._is_gradient(*pair) for pair in consumed):
              continue
          if any(self._is_gradient(*pair) for pair in node.outputs.items()):
              continue
          # Only consider nodes that consume an already-known forward Tensor.
          if not fwd_tensors.intersection(consumed):
              continue
          fwd_tensors.update(node.outputs.items())
          candidates |= consumed.difference(fwd_tensors)

      # Each parameter must eventually contribute to the value of a gradient.
      # Sweep back-to-front and mark the inputs of any node whose outputs are
      # a gradient or are already marked.
      feeds_gradient: Set[TensorAndID] = set()
      for node in reversed(self._data_flow_graph.flow_nodes):
          if any(
              self._is_gradient(*pair) or pair in feeds_gradient
              for pair in node.outputs.items()
          ):
              feeds_gradient.update(
                  (key, version) for key, (_, version) in node.inputs.items()
              )
      candidates &= feeds_gradient

      # ...and it must also depend on a gradient.
      parameter_ids = {key.id for key, _ in candidates}
      parameter_ids &= self._any_version_depends_on_gradient()
      for key, _ in snapshot:
          if key.id in parameter_ids:
              self._categories.set_by_id(key, Category.PARAMETER)
  def _set_activations(self) -> None:
      """Flood the data flow graph to identify activations.

      Nodes are visited in `flow_nodes` order. A node's outputs are labeled
      ACTIVATION (via `setdefault_by_version`) when all of the following hold:
        * at least one of its inputs is an INPUT or an ACTIVATION;
        * every other input is a PARAMETER or a TEMPORARY (an uncategorized
          input blocks the flood);
        * the node is not in the backward pass.
      """
      seed = {Category.INPUT, Category.ACTIVATION}
      permitted = seed | {Category.PARAMETER, Category.TEMPORARY}
      for node in self._data_flow_graph.flow_nodes:
          consumed = {(key, version) for key, (_, version) in node.inputs.items()}
          seen = {self._categories.get(*pair) for pair in consumed}
          if not seen & seed:
              continue
          if seen - permitted:
              continue
          # Stop filling when we reach the backward pass.
          if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
              continue
          for key, version in node.outputs.items():
              self._categories.setdefault_by_version(
                  key, version, Category.ACTIVATION
              )
  def _set_optimizer_state(self) -> None:
      """Label the state Tensors of every traced optimizer as OPTIMIZER_STATE.

      For each Python-call event in the op tree that carries optimizer
      information, every tensor in the state entries of
      `optimizer.parameters` is converted with `TensorKey.from_tensor` and
      labeled via `set_by_id`. Tensors for which no key can be built (the
      helper returns None) are skipped.
      """
      for event in self._op_tree.dfs():
          if event.typed[0] != _EventType.PyCall:
              continue
          optimizer = event.typed[1].optimizer
          if not optimizer:
              continue
          # Each entry of `parameters` is a 3-tuple whose last element is an
          # iterable of pairs; only the second item of each pair is used.
          state_lists = (state for _, _, state in optimizer.parameters)
          for _, tensor in it.chain.from_iterable(state_lists):
              key = TensorKey.from_tensor(tensor)
              if key is not None:
                  self._categories.set_by_id(key, Category.OPTIMIZER_STATE)
  756. def _set_autograd_detail(self):
  757. prior = {None, Category.AUTOGRAD_DETAIL}
  758. for node in self._data_flow_graph.flow_nodes:
  759. if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
  760. for key, version in node.outputs.items():
  761. if version == 0 or self._categories.get(key, version - 1) in prior:
  762. self._categories.setdefault_by_version(
  763. key, version, Category.AUTOGRAD_DETAIL
  764. )