# _memory_viz.py
  1. import pickle
  2. import sys
  3. import os
  4. import io
  5. import subprocess
  6. import json
  7. from functools import lru_cache
  8. from typing import List, Tuple
  9. cache = lru_cache(None)
  10. __all__ = ["format_flamegraph", "segments", "memory", "compare"]
  11. def _frame_fmt(f, full_filename=False):
  12. i = f['line']
  13. fname = f['filename']
  14. if not full_filename:
  15. fname = fname.split('/')[-1]
  16. func = f['name']
  17. return f'{fname}:{i}:{func}'
  18. def format_flamegraph(flamegraph_lines, flamegraph_script=None):
  19. if flamegraph_script is None:
  20. flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
  21. if not os.path.exists(flamegraph_script):
  22. import urllib.request
  23. print(f"Downloading flamegraph.pl to: {flamegraph_script}")
  24. urllib.request.urlretrieve(
  25. 'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
  26. subprocess.run(['chmod', '+x', flamegraph_script])
  27. args = [flamegraph_script, '--countname', 'bytes']
  28. p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
  29. assert p.stdin is not None
  30. assert p.stdout is not None
  31. p.stdin.write(flamegraph_lines)
  32. p.stdin.close()
  33. result = p.stdout.read()
  34. p.stdout.close()
  35. p.wait()
  36. assert p.wait() == 0
  37. return result
  38. def _write_blocks(f, prefix, blocks):
  39. for b in blocks:
  40. if 'history' not in b:
  41. f.write(f'{prefix};{b["state"]} {b["size"]}\n')
  42. continue
  43. accounted_for_size = 0
  44. for h in b['history']:
  45. sz = h['real_size']
  46. accounted_for_size += sz
  47. if 'frames' in h:
  48. frames = h['frames']
  49. if frames:
  50. frame_s = ';'.join([_frame_fmt(f) for f in reversed(frames)])
  51. else:
  52. frame_s = "<non-python>"
  53. f.write(f'{prefix};{b["state"]};{frame_s} {sz}\n')
  54. else:
  55. f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
  56. gaps = b['size'] - accounted_for_size
  57. if gaps:
  58. f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')
  59. def segments(snapshot, format_flamegraph=format_flamegraph):
  60. f = io.StringIO()
  61. for seg in snapshot['segments']:
  62. prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
  63. _write_blocks(f, prefix, seg['blocks'])
  64. return format_flamegraph(f.getvalue())
  65. def memory(snapshot, format_flamegraph=format_flamegraph):
  66. f = io.StringIO()
  67. for seg in snapshot['segments']:
  68. prefix = f'stream_{seg["stream"]}'
  69. _write_blocks(f, prefix, seg['blocks'])
  70. return format_flamegraph(f.getvalue())
  71. def compare(before, after, format_flamegraph=format_flamegraph):
  72. def _seg_key(seg):
  73. return (seg['address'], seg['total_size'])
  74. def _seg_info(seg):
  75. return f'stream_{seg["stream"]};seg_{seg["address"]}'
  76. f = io.StringIO()
  77. before_segs = {_seg_key(seg) for seg in before}
  78. after_segs = {_seg_key(seg) for seg in after}
  79. print(f'only_before = {[a for a,_ in (before_segs - after_segs)]}')
  80. print(f'only_after = {[a for a,_ in (after_segs - before_segs)]}')
  81. for seg in before:
  82. if _seg_key(seg) not in after_segs:
  83. _write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks'])
  84. for seg in after:
  85. if _seg_key(seg) not in before_segs:
  86. _write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks'])
  87. return format_flamegraph(f.getvalue())
  88. def _format_size(num):
  89. # https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
  90. for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
  91. if abs(num) < 1024.0:
  92. return f"{num:3.1f}{unit}B"
  93. num /= 1024.0
  94. return f"{num:.1f}YiB"
  95. class Bytes:
  96. def __init__(self, value):
  97. self.value = value
  98. def __add__(self, rhs):
  99. return Bytes(self.value + rhs)
  100. def __repr__(self):
  101. return _format_size(self.value)
  102. def calc_active(seg):
  103. return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')
  104. def _report_free(free_external, free_internal):
  105. total = free_external + free_internal
  106. pct = (free_internal / total) * 100
  107. suffix = f' ({pct:.1f}% internal)'
  108. return f'{Bytes(total)}{suffix}'
  109. PAGE_SIZE = 1024 * 1024 * 20
  110. legend = f"""\
  111. Legend:
  112. [a ] - a segment in the allocator
  113. ^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
  114. a-z: pages filled with a single block's content
  115. ' ': page is completely free
  116. *: page if completely full with multiple blocks
  117. 0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
  118. (X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
  119. """
  120. def segsum(data):
  121. """" Visually reports how the allocator has filled its segments. This printout can help debug fragmentation issues
  122. since free fragments will appear as gaps in this printout. The amount of free space is reported for each segment.
  123. We distinguish between internal free memory which occurs because the allocator rounds the allocation size, and
  124. external free memory, which are the gaps between allocations in a segment.
  125. Args:
  126. data: snapshot dictionary created from _snapshot()
  127. """
  128. segments = []
  129. out = io.StringIO()
  130. out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
  131. total_reserved = 0
  132. total_allocated = 0
  133. free_external = 0
  134. free_internal = 0
  135. for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
  136. total_reserved += seg['total_size']
  137. seg_free_external = 0
  138. seg_free_internal = 0
  139. seg_allocated = 0
  140. all_ranges = []
  141. boffset = 0
  142. for b in seg['blocks']:
  143. active = b['state'] == 'active_allocated'
  144. if 'history' in b:
  145. # use the more accureate real_size to account for internal fragmenetation if we have it
  146. for h in b['history']:
  147. if active:
  148. all_ranges.append((h['addr'] - seg['address'], h['real_size'], active))
  149. seg_allocated += h['real_size']
  150. assert len(b['history']) == 1
  151. seg_free_internal += b['size'] - h['real_size']
  152. else:
  153. if active:
  154. all_ranges.append((boffset, b['size'], True))
  155. seg_allocated += b['size']
  156. if not active:
  157. seg_free_external += b['size']
  158. boffset += b['size']
  159. total_allocated += seg_allocated
  160. free_external += seg_free_external
  161. free_internal += seg_free_internal
  162. nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
  163. occupied = [' ' for _ in range(nseg)]
  164. frac = [0.0 for _ in range(nseg)]
  165. active_size = 0
  166. for i, (start_, size, active) in enumerate(all_ranges):
  167. active_size += size
  168. finish_ = (start_ + size)
  169. start = start_ // PAGE_SIZE
  170. finish = (finish_ - 1) // PAGE_SIZE + 1
  171. m = chr((ord('a' if active else 'A') + (i % 26)))
  172. for j in range(start, finish):
  173. s = max(start_, j * PAGE_SIZE)
  174. e = min(finish_, (j + 1) * PAGE_SIZE)
  175. frac[j] += (e - s) / PAGE_SIZE
  176. if occupied[j] != ' ':
  177. occupied[j] = '0123456789*'[int(frac[j] * 10)]
  178. else:
  179. occupied[j] = m
  180. stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}'
  181. body = ''.join(occupied)
  182. assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
  183. stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
  184. if seg['total_size'] >= PAGE_SIZE:
  185. out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
  186. f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
  187. out.write(f'segments: {len(data["segments"])}\n')
  188. out.write(f'total_reserved: {Bytes(total_reserved)}\n')
  189. out.write(f'total_allocated: {Bytes(total_allocated)}\n')
  190. internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else ''
  191. out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
  192. out.write(legend)
  193. assert free_internal + free_external + total_allocated == total_reserved
  194. return out.getvalue()
  195. def trace(data):
  196. out = io.StringIO()
  197. def format(entries):
  198. segment_intervals : list = []
  199. segment_addr_to_name = {}
  200. allocation_addr_to_name = {}
  201. free_names : list = []
  202. next_name = 0
  203. def _name():
  204. nonlocal next_name
  205. if free_names:
  206. return free_names.pop()
  207. r, m = next_name // 26, next_name % 26
  208. next_name += 1
  209. return f'{chr(ord("a") + m)}{"" if r == 0 else r}'
  210. def find_segment(addr):
  211. for name, saddr, size in segment_intervals:
  212. if addr >= saddr and addr < saddr + size:
  213. return name, saddr
  214. for i, seg in enumerate(data['segments']):
  215. saddr = seg['address']
  216. size = seg['allocated_size']
  217. if addr >= saddr and addr < saddr + size:
  218. return f'seg_{i}', saddr
  219. return None, None
  220. count = 0
  221. out.write(f'{len(entries)} entries\n')
  222. total_reserved = 0
  223. for seg in data['segments']:
  224. total_reserved += seg['total_size']
  225. for count, e in enumerate(entries):
  226. if e['action'] == 'alloc':
  227. addr, size = e['addr'], e['size']
  228. n = _name()
  229. seg_name, seg_addr = find_segment(addr)
  230. if seg_name is None:
  231. seg_name = "MEM"
  232. offset = addr
  233. else:
  234. offset = addr - seg_addr
  235. out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
  236. allocation_addr_to_name[addr] = (n, size, count)
  237. count += size
  238. elif e['action'] == 'free_requested':
  239. addr, size = e['addr'], e['size']
  240. name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
  241. out.write(f'del {name} # {Bytes(size)}\n')
  242. elif e['action'] == 'free_completed':
  243. addr, size = e['addr'], e['size']
  244. count -= size
  245. name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
  246. out.write(f'# free completed for {name} {Bytes(size)}\n')
  247. if name in allocation_addr_to_name:
  248. free_names.append(name)
  249. del allocation_addr_to_name[name]
  250. elif e['action'] == 'segment_alloc':
  251. addr, size = e['addr'], e['size']
  252. name = _name()
  253. out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
  254. segment_intervals.append((name, addr, size))
  255. segment_addr_to_name[addr] = name
  256. elif e['action'] == 'segment_free':
  257. addr, size = e['addr'], e['size']
  258. name = segment_addr_to_name.get(addr, addr)
  259. out.write(f'cudaFree({name}) # {Bytes(size)}\n')
  260. if name in segment_addr_to_name:
  261. free_names.append(name)
  262. del segment_addr_to_name[name]
  263. elif e['action'] == 'oom':
  264. size = e['size']
  265. free = e['device_free']
  266. out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
  267. else:
  268. out.write(f'{e}\n')
  269. out.write(f"TOTAL MEM: {Bytes(count)}")
  270. for i, d in enumerate(data['device_traces']):
  271. if d:
  272. out.write(f'Device {i} ----------------\n')
  273. format(d)
  274. return out.getvalue()
  275. class PlotWriter:
  276. def __init__(self):
  277. string_table: List[str] = []
  278. suffix_table: List[Tuple[int, int]] = []
  279. elements = []
  280. actions: List[int] = []
  281. initially_allocated: List[int] = []
  282. @cache
  283. def intern_str(s):
  284. string_table.append(s)
  285. return len(string_table) - 1
  286. @cache
  287. def intern_suffix(sid, restid):
  288. suffix_table.append((sid, restid))
  289. return len(suffix_table) - 1
  290. def intern_stack(frames):
  291. sids = [intern_str(f) for f in frames]
  292. next_id = None
  293. for sid in reversed(sids):
  294. next_id = intern_suffix(sid, next_id)
  295. return next_id
  296. def add_element(size, lines):
  297. elements.append({'size': size, 'info': intern_stack(lines)})
  298. return len(elements) - 1
  299. def to_html():
  300. r = {
  301. 'actions': actions,
  302. 'elements': elements,
  303. 'suffix_table': suffix_table,
  304. 'string_table': string_table,
  305. 'initially_allocated': list(reversed(initially_allocated)),
  306. }
  307. plot_data = json.dumps(r)
  308. return _memory_over_time_template.replace('$PLOT_DATA', plot_data)
  309. self.add_element = add_element
  310. self.allocate = actions.append
  311. self.free = actions.append
  312. self.initially_allocated = initially_allocated.append
  313. self.to_html = to_html
  314. def trace_plot(data, device=None, plot_segments=False):
  315. w = PlotWriter()
  316. addr_to_alloc = {}
  317. if device is None:
  318. for i, t in enumerate(data['device_traces']):
  319. if len(t) > 0:
  320. if device is not None:
  321. raise ValueError(f'Both device {device} and {i} have traces, use --device to specify which trace.')
  322. device = i
  323. if device is None:
  324. raise ValueError('No trace information was recorded.')
  325. trace = data['device_traces'][device]
  326. if plot_segments:
  327. alloc = 'segment_alloc'
  328. free = 'segment_free'
  329. else:
  330. alloc = 'alloc'
  331. free = 'free_completed'
  332. def add_element(size, frames, extra=()):
  333. frames = [f"{_format_size(size)} allocation", *extra, *(_frame_fmt(f, full_filename=True) for f in frames)]
  334. return w.add_element(size, frames)
  335. for i, e in enumerate(trace):
  336. if e['action'] == alloc:
  337. elemid = add_element(e['size'], e['frames'])
  338. addr_to_alloc[e['addr']] = elemid
  339. w.allocate(elemid)
  340. elif e['action'] == free:
  341. idx = addr_to_alloc.pop(e['addr'], None)
  342. if idx is None:
  343. idx = add_element(e['size'], e['frames'], extra=('alloc not recorded, stack trace for free:',))
  344. w.initially_allocated(idx)
  345. w.free(idx)
  346. return w.to_html()
  347. def profile_plot(memory_profile, device=None):
  348. import torch
  349. from torch.profiler._memory_profiler import Action, TensorKey
  350. from torch._C._profiler import _EventType
  351. if device is None:
  352. if torch.cuda.is_available():
  353. device = torch.device('cuda', torch.cuda.current_device())
  354. else:
  355. device = torch.device('cpu')
  356. w = PlotWriter()
  357. allocation_stacks = {}
  358. for event in memory_profile._op_tree.sorted_nodes:
  359. if event.tag == _EventType.Allocation:
  360. parent = event.parent
  361. python_parents = []
  362. while parent:
  363. if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
  364. python_parents.append(parent)
  365. parent = parent.parent
  366. key = TensorKey.from_allocation(event.extra_fields)
  367. # Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
  368. # key will be None. I should add some way to identify these, I just haven't yet.
  369. if key and event.extra_fields.alloc_size > 0:
  370. allocation_stacks[key] = python_parents
  371. def add_element(size, tensor_key, version):
  372. category = memory_profile._categories.get(tensor_key, version)
  373. if category is None:
  374. category = 'unknown'
  375. else:
  376. category = category.name.lower()
  377. stack = allocation_stacks.get(tensor_key, ())
  378. return w.add_element(size, [f"{_format_size(size)} allocation ({category})", *(p.name for p in stack)])
  379. kv_to_elem = {}
  380. for time, action, (tensor_key, version), size in memory_profile.timeline:
  381. if tensor_key.device != device:
  382. continue
  383. if action == Action.CREATE:
  384. kv_to_elem[(tensor_key, version)] = elemid = add_element(size, tensor_key, version)
  385. w.allocate(elemid)
  386. elif action == Action.DESTROY:
  387. w.free(kv_to_elem.pop((tensor_key, version)))
  388. elif action == Action.INCREMENT_VERSION:
  389. w.free(kv_to_elem.pop((tensor_key, version)))
  390. kv_to_elem[(tensor_key, version + 1)] = elemid = add_element(size, tensor_key, version + 1)
  391. w.allocate(elemid)
  392. elif action == Action.PREEXISTING:
  393. kv_to_elem[(tensor_key, version)] = elemid = add_element(size, tensor_key, version)
  394. w.initially_allocated(elemid)
  395. return w.to_html()
  396. # note: this template should eventually move to its own file,
  397. # however, we first need to package _memory_viz.py so that it can be
  398. # pip-installed separately from pytorch so it is easy to run e.g.
  399. # on a laptop with downloaded snapshots. Currently this is
  400. # accomplished by downloading _memory_viz.py so the template
  401. # needs to be included
  402. _memory_over_time_template = r"""
  403. <!DOCTYPE html>
  404. <html>
  405. <head></head>
  406. <body>
  407. <script type="module">
  408. import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7.7.0/+esm";
  409. import {schemeTableau10} from "https://cdn.skypack.dev/d3-scale-chromatic@3";
  410. import {axisLeft} from "https://cdn.skypack.dev/d3-axis@3";
  411. import {scaleLinear} from "https://cdn.skypack.dev/d3-scale@4";
  412. import {zoom, zoomIdentity} from "https://cdn.skypack.dev/d3-zoom@3";
  413. import {brushX} from "https://cdn.skypack.dev/d3-brush@3";
  414. let alloc_data = $PLOT_DATA
  415. function process_alloc_data(fraction_of_memory_reported=1) {
  416. let current = []
  417. let current_data = []
  418. let data = []
  419. let max_size = 0
  420. let total_mem = 0
  421. let timestep = 0
  422. let max_at_time = []
  423. function advance(n, max) {
  424. timestep += n
  425. for (let i = 0; i < n; i++) {
  426. max_at_time.push(max)
  427. }
  428. }
  429. let mini_points = []
  430. let sizes = alloc_data.elements.map(x => x.size).sort((x, y) => y - x)
  431. let total_size = sizes.reduce((x, y) => x + y)
  432. const memory_threshold = fraction_of_memory_reported * total_size
  433. let total_seen = 0
  434. let memory_threshold_size = 0
  435. for (const [i, size] of sizes.entries()) {
  436. total_seen += size
  437. if (total_seen > memory_threshold) {
  438. memory_threshold_size = size
  439. break
  440. }
  441. }
  442. function add_allocation(elem) {
  443. let size = alloc_data.elements[elem].size
  444. current.push(elem)
  445. let e = {elem: elem, timesteps: [timestep], offsets: [total_mem], size: alloc_data.elements[elem].size}
  446. current_data.push(e)
  447. data.push(e)
  448. total_mem += size
  449. }
  450. for (const elem of alloc_data.initially_allocated) {
  451. add_allocation(elem)
  452. }
  453. for (const action of alloc_data.actions) {
  454. const elem = action
  455. const idx = current.findIndex(x => x === elem)
  456. const size = alloc_data.elements[elem].size
  457. if (size < memory_threshold_size) {
  458. continue
  459. }
  460. // first time we see an action we add it
  461. // second time we remove it
  462. if (idx == -1) {
  463. add_allocation(elem)
  464. advance(1, total_mem)
  465. } else {
  466. advance(1, total_mem)
  467. const removed = current_data[idx]
  468. removed.timesteps.push(timestep)
  469. removed.offsets.push(removed.offsets.at(-1))
  470. current.splice(idx, 1)
  471. current_data.splice(idx, 1)
  472. if (idx < current.length) {
  473. for (let j = idx; j < current.length; j++) {
  474. const e = current_data[j]
  475. e.timesteps.push(timestep)
  476. e.offsets.push(e.offsets.at(-1))
  477. e.timesteps.push(timestep + 3)
  478. e.offsets.push(e.offsets.at(-1) - size)
  479. }
  480. advance(3, total_mem)
  481. }
  482. total_mem -= size
  483. }
  484. max_size = Math.max(total_mem, max_size)
  485. }
  486. for (const elem of current_data) {
  487. elem.timesteps.push(timestep)
  488. elem.offsets.push(elem.offsets.at(-1))
  489. }
  490. return {
  491. max_size: max_size,
  492. allocations_over_time: data,
  493. max_at_time: max_at_time,
  494. context_for_id: (elem) => {
  495. let strings = []
  496. let id = alloc_data.elements[elem].info
  497. while (id !== null) {
  498. const [sid, next_id] = alloc_data.suffix_table[id]
  499. strings.push(alloc_data.string_table[sid])
  500. id = next_id
  501. }
  502. return `${strings.join('\n')}\n`
  503. }
  504. }
  505. }
  506. function MemoryPlot(svg, data, left_pad, colors=schemeTableau10) {
  507. function format_points(d) {
  508. const size = d.size
  509. const xs = d.timesteps.map(t => xscale(t))
  510. const bottom = d.offsets.map(t => yscale(t))
  511. const top = d.offsets.map(t => yscale(t + size))
  512. const p0 = xs.map((x, i) => `${x},${bottom[i]}`)
  513. const p1 = xs.map((x, i) => `${x},${top[i]}`).reverse()
  514. return `${p0.join(' ')} ${p1.join(' ')}`
  515. }
  516. let max_timestep = data.max_at_time.length
  517. let max_size = data.max_size
  518. let width = svg.attr('width')
  519. let height = svg.attr('height')
  520. let plot_width = width - left_pad
  521. let plot_height = height
  522. let yscale = scaleLinear().domain([0, max_size]).range([plot_height, 0]);
  523. let heightscale = scaleLinear().domain([0, max_size]).range([0, plot_height]);
  524. let yaxis = axisLeft(yscale).tickFormat(d3.format("~s"))
  525. let xscale = scaleLinear().domain([0, max_timestep]).range([0, plot_width])
  526. let plot_coordinate_space = svg.append("g").attr("transform", `translate(${left_pad}, ${0})`)
  527. let plot_outer = plot_coordinate_space.append('g')
  528. function view_rect(a) {
  529. return a.append('rect').attr('x', 0).attr('y', 0)
  530. .attr('width', plot_width).attr('height', plot_height)
  531. .attr('fill', 'white')
  532. }
  533. view_rect(plot_outer)
  534. let cp = svg.append("clipPath").attr("id", "clip")
  535. view_rect(cp)
  536. plot_outer.attr('clip-path', "url(#clip)")
  537. let zoom_group = plot_outer.append("g")
  538. let scrub_group = zoom_group.append('g')
  539. let plot = scrub_group.selectAll("polygon")
  540. .data(data.allocations_over_time)
  541. .enter()
  542. .append("polygon")
  543. .attr('points', format_points)
  544. .attr('fill', d => colors[d.elem % colors.length])
  545. let axis = plot_coordinate_space.append('g').call(yaxis)
  546. let scale_mini = 0
  547. let translate_mini = 0
  548. function handleZoom(e) {
  549. const t = e.transform
  550. zoom_group.attr("transform", t)
  551. axis.call(yaxis.scale(e.transform.rescaleY(yscale)))
  552. }
  553. const thezoom = zoom().on('zoom', handleZoom)
  554. plot_outer.call(thezoom)
  555. return {
  556. select_window: (stepbegin, stepend, max) => {
  557. let begin = xscale(stepbegin)
  558. let size = xscale(stepend) - xscale(stepbegin);
  559. let scale = plot_width / size
  560. let translate = -begin
  561. let yscale = max_size/max
  562. scrub_group.attr("transform", `scale(${scale/yscale}, 1) translate(${translate}, 0)`)
  563. plot_outer.call(thezoom.transform, zoomIdentity.scale(yscale).translate(0, -(plot_height - plot_height/yscale)))
  564. },
  565. set_delegate: (delegate) => {
  566. plot.on('mouseover', function (e, d) { delegate.set_selected(d3.select(this)) } )
  567. .on('mousedown', function(e, d) { delegate.default_selected = d3.select(this)})
  568. .on('mouseleave', function (e, d) { delegate.set_selected(delegate.default_selected) } )
  569. }
  570. }
  571. }
  572. function ContextViewer(text, data) {
  573. let current_selected = null
  574. return {
  575. default_selected: null,
  576. set_selected: (d) => {
  577. if (current_selected !== null) {
  578. current_selected.attr('stroke', null).attr('stroke-width', null);
  579. }
  580. if (d === null) {
  581. text.text("")
  582. } else {
  583. const dd = d.datum()
  584. text.text(`${dd.elem} ${data.context_for_id(dd.elem)}`)
  585. d.attr('stroke', 'black').attr('stroke-width', 1).attr('vector-effect', 'non-scaling-stroke')
  586. }
  587. current_selected = d
  588. }
  589. }
  590. }
  591. function MiniMap(mini_svg, plot, data, left_pad, height=70) {
  592. let max_at_time = data.max_at_time
  593. let width = mini_svg.attr('width')
  594. let plot_width = width - left_pad
  595. let yscale = scaleLinear().domain([0, data.max_size]).range([height, 0]);
  596. let minixscale = scaleLinear().domain([0, max_at_time.length]).range([left_pad, width])
  597. let mini_points = [[max_at_time.length, 0], [0, 0]]
  598. for (const [i, m] of max_at_time.entries()) {
  599. let [lastx, lasty] = mini_points[mini_points.length - 1]
  600. if (m !== lasty) {
  601. mini_points.push([i, lasty])
  602. mini_points.push([i, m])
  603. } else if (i === max_at_time.length - 1) {
  604. mini_points.push([i, m])
  605. }
  606. }
  607. let points = mini_points.map(([t, o]) => `${minixscale(t)}, ${yscale(o)}`)
  608. points = points.join(' ')
  609. mini_svg.append('polygon').attr('points', points).attr('fill', schemeTableau10[0])
  610. let xscale = scaleLinear().domain([0, max_at_time.length]).range([0, plot_width])
  611. const brush = brushX()
  612. brush.extent([[left_pad, 0], [width, height]])
  613. brush.on('brush', function({selection}) {
  614. let [begin, end] = selection.map(x => x - left_pad)
  615. let stepbegin = Math.floor(xscale.invert(begin))
  616. let stepend = Math.floor(xscale.invert(end))
  617. let max = 0
  618. for (let i = stepbegin; i < stepend; i++) {
  619. max = Math.max(max, max_at_time[i])
  620. }
  621. plot.select_window(stepbegin, stepend, max)
  622. })
  623. mini_svg.call(brush)
  624. return {}
  625. }
  626. let left_pad = 70
  627. let width = 1024
  628. let height = 768
  629. let data = process_alloc_data()
  630. let body = d3.select("body")
  631. let plot = MemoryPlot(body.append("svg").attr('width', width).attr('height', height).attr('display', 'block'), data, left_pad)
  632. MiniMap(body.append("svg").attr('width', width).attr('height', 80).attr('display', 'block'), plot, data, left_pad)
  633. let delegate = ContextViewer(body.append("div").append("pre").text('none'), data)
  634. plot.set_delegate(delegate)
  635. </script>
  636. </body>
  637. </html>
  638. """
  639. if __name__ == "__main__":
  640. import os.path
  641. thedir = os.path.realpath(os.path.dirname(__file__))
  642. if thedir in sys.path:
  643. # otherwise we find cuda/random.py as random...
  644. sys.path.remove(thedir)
  645. import argparse
  646. fn_name = 'torch.cuda.memory._snapshot()'
  647. pickled = f'pickled memory statistics from {fn_name}'
  648. parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')
  649. subparsers = parser.add_subparsers(dest='action')
  650. def _output(p):
  651. p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')
  652. description = 'Prints overall allocation statistics and a visualization of how the allocators segments are currently filled.'
  653. stats_a = subparsers.add_parser('stats', description=description)
  654. stats_a.add_argument('input', help=pickled)
  655. description = 'Prints buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
  656. trace_a = subparsers.add_parser('trace', description=description)
  657. trace_a.add_argument('input', help=pickled)
  658. description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
  659. segments_a = subparsers.add_parser('segments', description=description)
  660. segments_a.add_argument('input', help=pickled)
  661. _output(segments_a)
  662. description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
  663. memory_a = subparsers.add_parser('memory', description=description)
  664. memory_a.add_argument('input', help=pickled)
  665. _output(memory_a)
  666. description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
  667. 'or removed between two different memorys snapshots.'
  668. compare_a = subparsers.add_parser('compare', description=description)
  669. compare_a.add_argument('before', help=pickled)
  670. compare_a.add_argument('after', help=pickled)
  671. _output(compare_a)
  672. description = "Generate a visualization over time of the memory usage recorded by the trace as an html file."
  673. trace_plot_a = subparsers.add_parser('trace_plot', description=description)
  674. trace_plot_a.add_argument('input', help=pickled)
  675. help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
  676. trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
  677. help = 'path to save the visualization(default: output.html)'
  678. trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
  679. help = 'visualize change to segments rather than individual allocations'
  680. trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)
  681. args = parser.parse_args()
  682. def _read(name):
  683. if name == '-':
  684. f = sys.stdin.buffer
  685. else:
  686. f = open(name, 'rb')
  687. data = pickle.load(f)
  688. if isinstance(data, list): # segments only...
  689. data = {'segments': data, 'traces': []}
  690. return data
  691. def _write(name, data):
  692. with open(name, 'w') as f:
  693. f.write(data)
  694. if args.action == 'segments':
  695. data = _read(args.input)
  696. _write(args.output, segments(data))
  697. elif args.action == 'memory':
  698. data = _read(args.input)
  699. _write(args.output, memory(data))
  700. elif args.action == 'stats':
  701. data = _read(args.input)
  702. print(segsum(data))
  703. elif args.action == 'trace':
  704. data = _read(args.input)
  705. print(trace(data))
  706. elif args.action == 'compare':
  707. before = _read(args.before)
  708. after = _read(args.after)
  709. _write(args.output, compare(before, after))
  710. elif args.action == 'trace_plot':
  711. data = _read(args.input)
  712. _write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))