profiler.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. from typing import Any, Dict, List, Optional
  2. from collections import defaultdict
  3. from warnings import warn
  4. import torch
  5. import torch.cuda
  6. from torch._C._profiler import _ExperimentalConfig
  7. from torch.autograd import (
  8. _disable_profiler,
  9. _enable_profiler,
  10. _kineto_step,
  11. _prepare_profiler,
  12. _ProfilerResult,
  13. _supported_activities,
  14. DeviceType,
  15. kineto_available,
  16. ProfilerActivity,
  17. ProfilerConfig,
  18. ProfilerState,
  19. )
  20. from torch.autograd.profiler_util import (
  21. _filter_name,
  22. _filter_stack_entry,
  23. _rewrite_name,
  24. EventList,
  25. FunctionEvent,
  26. MEMORY_EVENT_NAME,
  27. MemRecordsAcc,
  28. OUT_OF_MEMORY_EVENT_NAME,
  29. )
  30. from torch.futures import Future
  31. __all__ = ["profile", "record_function", "emit_itt", "emit_nvtx", "load_nvprof", "EnforceUnique",
  32. "parse_nvprof_trace", "KinetoStepTracker", "EventList", "FunctionEvent", "MemRecordsAcc"]
  33. try:
  34. # Available in Python >= 3.2
  35. from contextlib import ContextDecorator as _ContextDecorator
  36. except ImportError:
  37. import functools
  38. class _ContextDecorator: # type: ignore[no-redef]
  39. def __enter__(self):
  40. raise NotImplementedError
  41. def __exit__(self, exc_type, exc_val, exc_tb):
  42. raise NotImplementedError
  43. def __call__(self, func):
  44. @functools.wraps(func)
  45. def wrapped(*args, **kwargs):
  46. with self:
  47. return func(*args, **kwargs)
  48. return wrapped
  49. class profile:
  50. """Context manager that manages autograd profiler state and holds a summary of results.
  51. Under the hood it just records events of functions being executed in C++ and
  52. exposes those events to Python. You can wrap any code into it and it will
  53. only report runtime of PyTorch functions.
  54. Note: profiler is thread local and is automatically propagated into the async tasks
  55. Args:
  56. enabled (bool, optional): Setting this to False makes this context manager a no-op.
  57. use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
  58. Adds approximately 4us of overhead to each tensor operation.
  59. record_shapes (bool, optional): If shapes recording is set, information
  60. about input dimensions will be collected. This allows one to see which
  61. dimensions have been used under the hood and further group by them
  62. using prof.key_averages(group_by_input_shape=True). Please note that
  63. shape recording might skew your profiling data. It is recommended to
  64. use separate runs with and without shape recording to validate the timing.
  65. Most likely the skew will be negligible for bottom most events (in a case
  66. of nested function calls). But for higher level functions the total
  67. self cpu time might be artificially increased because of the shape
  68. collection.
  69. with_flops (bool, optional): If with_flops is set, the profiler will estimate
  70. the FLOPs (floating point operations) value using the operator's input shape.
  71. This allows one to estimate the hardware performance. Currently,
  72. this option only works for the matrix multiplication and 2D convolution operators.
  73. profile_memory (bool, optional): track tensor memory allocation/deallocation.
  74. with_stack (bool, optional): record source information (file and line number) for the ops.
  75. with_modules (bool): record module hierarchy (including function names)
  76. corresponding to the callstack of the op. e.g. If module A's forward call's
  77. module B's forward which contains an aten::add op,
  78. then aten::add's module hierarchy is A.B
  79. Note that this support exist, at the moment, only for TorchScript models
  80. and not eager mode models.
  81. use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.
  82. use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
  83. ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
  84. experimental_config (_ExperimentalConfig) : A set of experimental options
  85. used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
  86. .. warning:
  87. Enabling memory profiling or source attribution incurs additional profiler
  88. overhead
  89. .. warning:
  90. This context managers should not be called recursively, i.e. no nested
  91. instances are allowed
  92. .. warning:
  93. Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
  94. one cannot use the profiler with ``use_cuda = True`` to benchmark
  95. DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
  96. please use ``use_cuda = False`` or ``num_workers = 0``.
  97. Example:
  98. >>> # xdoctest: +SKIP
  99. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
  100. >>> x = torch.randn((1, 1), requires_grad=True)
  101. >>> with torch.autograd.profiler.profile() as prof:
  102. >>> for _ in range(100): # any normal python code, really!
  103. >>> y = x ** 2
  104. >>> y.backward()
  105. >>> # NOTE: some columns were removed for brevity
  106. >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
  107. ----------------------------------- --------------- --------------- ---------------
  108. Name Self CPU total CPU time avg Number of Calls
  109. ----------------------------------- --------------- --------------- ---------------
  110. mul 32.048ms 32.048ms 200
  111. pow 27.041ms 27.041ms 200
  112. PowBackward0 9.727ms 55.483ms 100
  113. torch::autograd::AccumulateGrad 9.148ms 9.148ms 100
  114. torch::autograd::GraphRoot 691.816us 691.816us 100
  115. ----------------------------------- --------------- --------------- ---------------
  116. """
  117. def __init__(
  118. self,
  119. enabled=True,
  120. *,
  121. use_cuda=False,
  122. record_shapes=False,
  123. with_flops=False,
  124. profile_memory=False,
  125. with_stack=False,
  126. with_modules=False,
  127. use_kineto=False,
  128. use_cpu=True,
  129. experimental_config=None):
  130. self.enabled: bool = enabled
  131. if not self.enabled:
  132. return
  133. self.use_cuda = use_cuda
  134. self.function_events: Optional[EventList] = None
  135. self.entered = False
  136. self.record_shapes = record_shapes
  137. self.with_flops = with_flops
  138. self.record_shapes |= self.with_flops
  139. self.profile_memory = profile_memory
  140. self.with_stack = with_stack
  141. self.with_modules = with_modules
  142. self.use_cpu = use_cpu
  143. if experimental_config is None:
  144. experimental_config = _ExperimentalConfig()
  145. self.experimental_config = experimental_config
  146. self.kineto_results: Optional[_ProfilerResult] = None
  147. if not self.use_cpu:
  148. assert use_kineto, \
  149. "Device-only events supported only with Kineto (use_kineto=True)"
  150. if self.use_cuda and not torch.cuda.is_available():
  151. warn("CUDA is not available, disabling CUDA profiling")
  152. self.use_cuda = False
  153. self.kineto_activities = set()
  154. if self.use_cpu:
  155. self.kineto_activities.add(ProfilerActivity.CPU)
  156. self.profiler_kind = ProfilerState.KINETO
  157. if self.use_cuda:
  158. if (not use_kineto or ProfilerActivity.CUDA not in
  159. _supported_activities()):
  160. assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
  161. self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
  162. else:
  163. self.kineto_activities.add(ProfilerActivity.CUDA)
  164. assert len(self.kineto_activities) > 0, \
  165. "No activities specified for the profiler"
  166. def config(self):
  167. return ProfilerConfig(
  168. self.profiler_kind,
  169. self.record_shapes,
  170. self.profile_memory,
  171. self.with_stack,
  172. self.with_flops,
  173. self.with_modules,
  174. self.experimental_config)
  175. def __enter__(self):
  176. if not self.enabled:
  177. return
  178. if self.entered:
  179. raise RuntimeError("Profiler context manager is not reentrant")
  180. self._prepare_trace()
  181. self._start_trace()
  182. return self
  183. def _prepare_trace(self):
  184. self.entered = True
  185. _prepare_profiler(self.config(), self.kineto_activities)
  186. def _start_trace(self):
  187. self.entered = True
  188. _enable_profiler(self.config(), self.kineto_activities)
  189. def __exit__(self, exc_type, exc_val, exc_tb):
  190. if not self.enabled:
  191. return
  192. if self.use_cuda:
  193. torch.cuda.synchronize()
  194. self.kineto_results = _disable_profiler()
  195. parsed_results = self._parse_kineto_results(self.kineto_results)
  196. self.function_events = EventList(
  197. parsed_results,
  198. use_cuda=self.use_cuda,
  199. profile_memory=self.profile_memory,
  200. with_flops=self.with_flops)
  201. self.function_events._build_tree()
  202. return False
  203. def __repr__(self):
  204. if self.function_events is None:
  205. return '<unfinished torch.autograd.profile>'
  206. return repr(self.function_events)
  207. def __str__(self):
  208. if self.function_events is None:
  209. return '<unfinished torch.autograd.profile>'
  210. return str(self.function_events)
  211. def _check_finish(self):
  212. if self.function_events is None:
  213. raise RuntimeError("Profiler didn't finish running")
  214. def table(
  215. self,
  216. sort_by=None,
  217. row_limit=100,
  218. max_src_column_width=75,
  219. max_name_column_width=55,
  220. max_shapes_column_width=80,
  221. header=None,
  222. top_level_events_only=False
  223. ):
  224. self._check_finish()
  225. assert self.function_events is not None
  226. return self.function_events.table(
  227. sort_by=sort_by,
  228. row_limit=row_limit,
  229. max_src_column_width=max_src_column_width,
  230. max_name_column_width=max_name_column_width,
  231. max_shapes_column_width=max_shapes_column_width,
  232. header=header,
  233. top_level_events_only=top_level_events_only
  234. )
  235. table.__doc__ = EventList.table.__doc__
  236. def export_chrome_trace(self, path):
  237. self._check_finish()
  238. if kineto_available():
  239. self.kineto_results.save(path) # type: ignore[union-attr]
  240. else:
  241. return self.function_events.export_chrome_trace(path) # type: ignore[union-attr]
  242. export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
  243. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  244. self._check_finish()
  245. assert self.function_events is not None, "Expected profiling results"
  246. assert self.with_stack, "export_stacks() requires with_stack=True"
  247. return self.function_events.export_stacks(path, metric)
  248. def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
  249. self._check_finish()
  250. assert self.function_events is not None, "Expected profiling results"
  251. return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
  252. key_averages.__doc__ = EventList.key_averages.__doc__
  253. def total_average(self):
  254. self._check_finish()
  255. assert self.function_events is not None, "Expected profiling results"
  256. return self.function_events.total_average()
  257. total_average.__doc__ = EventList.total_average.__doc__
  258. @property
  259. def self_cpu_time_total(self):
  260. """ Returns total time spent on CPU obtained as a sum of
  261. all self times across all the events.
  262. """
  263. self._check_finish()
  264. assert self.function_events is not None
  265. return self.function_events.self_cpu_time_total
  266. def _parse_kineto_results(self, result):
  267. # result.events() has most of the events - PyTorch op-level and device-level events
  268. trace_start_us = result.trace_start_us()
  269. mem_records = [[evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME]
  270. oom_records = [evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME]
  271. mem_records_acc = MemRecordsAcc(mem_records)
  272. def _cpu_memory_usage(mem_record):
  273. return mem_record.nbytes() if \
  274. mem_record.device_type() in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP] \
  275. else 0
  276. def _cuda_memory_usage(mem_record):
  277. return mem_record.nbytes() if \
  278. mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] \
  279. else 0
  280. # Create and return FunctionEvent list
  281. function_events = []
  282. cuda_corr_map: Dict[int, List[FunctionEvent]] = {}
  283. max_evt_id = 0
  284. for kineto_event in result.events():
  285. if _filter_name(kineto_event.name()):
  286. continue
  287. rel_start_us = kineto_event.start_us() - trace_start_us
  288. rel_end_us = rel_start_us + kineto_event.duration_us()
  289. abs_end_us = kineto_event.start_us() + kineto_event.duration_us()
  290. cpu_memory_usage = 0
  291. cuda_memory_usage = 0
  292. if kineto_event.device_type() == DeviceType.CPU:
  293. # find the corresponding memory allocation events
  294. for mem_record in mem_records_acc.in_interval(kineto_event.start_us(), abs_end_us):
  295. cpu_memory_usage += _cpu_memory_usage(mem_record[0])
  296. cuda_memory_usage += _cuda_memory_usage(mem_record[0])
  297. mem_record[1] = True
  298. is_async = kineto_event.is_async() or (
  299. kineto_event.start_thread_id() != kineto_event.end_thread_id()
  300. )
  301. fe = FunctionEvent(
  302. id=kineto_event.correlation_id(),
  303. name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
  304. trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
  305. thread=kineto_event.start_thread_id(),
  306. start_us=rel_start_us,
  307. end_us=rel_end_us,
  308. fwd_thread=kineto_event.fwd_thread_id(),
  309. input_shapes=kineto_event.shapes(),
  310. stack=[entry for entry in kineto_event.stack() if _filter_stack_entry(entry)],
  311. scope=kineto_event.scope(),
  312. cpu_memory_usage=cpu_memory_usage,
  313. cuda_memory_usage=cuda_memory_usage,
  314. is_async=is_async,
  315. sequence_nr=kineto_event.sequence_nr(),
  316. device_type=kineto_event.device_type(),
  317. device_index=kineto_event.device_index(),
  318. flops=kineto_event.flops(),
  319. )
  320. max_evt_id = fe.id if fe.id > max_evt_id else max_evt_id
  321. if fe.device_type == DeviceType.CPU and not fe.is_async:
  322. # Check if we have CUDA time as a fallback
  323. cuda_time = kineto_event.cuda_elapsed_us()
  324. if cuda_time > 0:
  325. fe.append_kernel(
  326. fe.name,
  327. fe.device_index,
  328. cuda_time)
  329. fe.is_legacy = True
  330. function_events.append(fe)
  331. corr_id = kineto_event.linked_correlation_id()
  332. if corr_id > 0:
  333. if corr_id not in cuda_corr_map:
  334. cuda_corr_map[corr_id] = []
  335. cuda_corr_map[corr_id].append(fe)
  336. # associate CUDA kernels and CUDA runtime (CPU) with CPU events
  337. for fe in function_events:
  338. if (fe.device_type == DeviceType.CPU and not fe.is_async and
  339. fe.id in cuda_corr_map):
  340. for f_evt in cuda_corr_map[fe.id]:
  341. if f_evt.device_type == DeviceType.CUDA:
  342. fe.append_kernel(
  343. f_evt.name,
  344. f_evt.device_index,
  345. f_evt.time_range.end - f_evt.time_range.start)
  346. elif f_evt.device_type == DeviceType.CPU:
  347. # make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated
  348. # with the 'thread' of the corresponding linked PyTorch event to properly track
  349. # parents and children
  350. f_evt.thread = fe.thread
  351. def createFunctionEventForMemoryEvents(evt):
  352. rel_start_us = evt.start_us() - trace_start_us
  353. fe = FunctionEvent(
  354. id=max_evt_id,
  355. name=evt.name(),
  356. trace_name=None, # not outputting in the trace
  357. thread=evt.start_thread_id(),
  358. start_us=rel_start_us,
  359. end_us=rel_start_us, # no duration
  360. fwd_thread=evt.start_thread_id(),
  361. input_shapes=[],
  362. stack=[],
  363. scope=0, # RecordScope::FUNCTION
  364. cpu_memory_usage=_cpu_memory_usage(evt),
  365. cuda_memory_usage=_cuda_memory_usage(evt),
  366. is_async=False,
  367. sequence_nr=-1,
  368. device_type=DeviceType.CPU,
  369. device_index=0,
  370. )
  371. return fe
  372. # output top-level memory events
  373. for mem_record in mem_records:
  374. if not mem_record[1]:
  375. max_evt_id += 1
  376. fe = createFunctionEventForMemoryEvents(mem_record[0])
  377. function_events.append(fe)
  378. for oom_record in oom_records:
  379. max_evt_id += 1
  380. fe = createFunctionEventForMemoryEvents(oom_record)
  381. function_events.append(fe)
  382. function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
  383. return function_events
  384. class record_function(_ContextDecorator):
  385. """Context manager/function decorator that adds a label to a block of
  386. Python code (or function) when running autograd profiler. It is
  387. useful when tracing the code profile.
  388. Args:
  389. name (str): Label assigned to the block of code.
  390. node_id (int): ID of node, for distributed profiling. Unset in
  391. non-distributed cases.
  392. Example:
  393. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
  394. >>> x = torch.randn((1, 1), requires_grad=True)
  395. >>> with torch.autograd.profiler.profile() as prof:
  396. ... y = x ** 2
  397. ... with torch.autograd.profiler.record_function("label-z"): # label the block
  398. ... z = y ** 3
  399. ... y.backward()
  400. ...
  401. >>> # xdoctest: +IGNORE_WANT
  402. >>> # NOTE: some columns were removed for brevity
  403. >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
  404. ----------------------------------- --------------- --------------- ---------------
  405. Name Self CPU total % CPU time avg Number of Calls
  406. ----------------------------------- --------------- --------------- ---------------
  407. pow 60.77% 47.470us 3
  408. mul 21.73% 25.465us 2
  409. PowBackward0 12.03% 121.891us 1
  410. torch::autograd::AccumulateGrad 2.70% 6.324us 1
  411. label-z 2.13% 12.421us 1
  412. torch::autograd::GraphRoot 0.64% 1.503us 1
  413. ----------------------------------- --------------- --------------- ---------------
  414. Self CPU time total: 234.344us
  415. CUDA time total: 0.000us
  416. """
  417. def __init__(self, name: str, args: Optional[str] = None):
  418. self.name: str = name
  419. self.args: Optional[str] = args
  420. # Whether or not we should run record function's end callbacks when exiting.
  421. self.run_callbacks_on_exit: bool = True
  422. # TODO: TorchScript ignores standard type annotation here
  423. # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
  424. self.record = torch.jit.annotate(Optional["torch.classes.profiler._RecordFunction"], None)
  425. def __enter__(self):
  426. self.record = torch.ops.profiler._record_function_enter_new(self.name, self.args)
  427. return self
  428. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
  429. if not self.run_callbacks_on_exit:
  430. return
  431. # Local variable is needed by TorchScript to refine Optional[T] to T
  432. record = self.record
  433. assert record is not None
  434. # TODO: Too slow with __torch_function__ handling enabled
  435. # See https://github.com/pytorch/pytorch/issues/76410
  436. if not torch.jit.is_scripting():
  437. with torch._C.DisableTorchFunctionSubclass():
  438. torch.ops.profiler._record_function_exit._RecordFunction(record)
  439. else:
  440. torch.ops.profiler._record_function_exit(record)
  441. def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
  442. """
  443. _call_end_callbacks_on_future is meant to be used for profiling async
  444. calls that return a future. Calling this function will extend recording
  445. beyond this scope, until the future is satisfied. It is useful for profiling
  446. the end to end time of asynchronous calls. This function should only be called
  447. once to attach the callback onto the future, and will throw if called multiple
  448. times.
  449. Args:
  450. fut: (torch._C.Future): future for which to schedule
  451. callback for.
  452. Returns:
  453. A future that completes with the value of the passed in future when
  454. the profiling callbacks have ran.
  455. """
  456. # Throw if we have already attached a callback onto the future.
  457. if not self.run_callbacks_on_exit:
  458. raise RuntimeError("_call_end_callbacks_on_future can only be called once.")
  459. # We are scheduling to run this RecordFunction's end callbacks when the
  460. # passed in future completes, so don't run end callbacks on exit.
  461. self.run_callbacks_on_exit = False
  462. # Local variable is needed by TorchScript to refine Optional[T] to T
  463. record = self.record
  464. assert record is not None
  465. # TODO: Too slow with __torch_function__ handling enabled
  466. # See https://github.com/pytorch/pytorch/issues/76410
  467. if not torch.jit.is_scripting():
  468. with torch._C.DisableTorchFunctionSubclass():
  469. profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
  470. record, fut)
  471. else:
  472. profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut)
  473. return profiled_future
  474. class emit_itt:
  475. """Context manager that makes every autograd operation emit an ITT range.
  476. It is useful when running the program under Intel(R) VTune Profiler::
  477. vtune <--vtune-flags> <regular command here>
  478. The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
  479. control the collection of trace data during its execution across different Intel tools.
  480. This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager,
  481. you will be able to see labled ranges in Intel(R) VTune Profiler GUI.
  482. .. warning:
  483. This context manager should not be called recursively, i.e. at most one
  484. instance should be enabled at any given time.
  485. Args:
  486. enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
  487. Default: ``True``.
  488. record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
  489. each autograd op will append information about the sizes of Tensor arguments received
  490. by that op, in the following format:
  491. ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
  492. Non-tensor arguments will be represented by ``[]``.
  493. Arguments will be listed in the order they are received by the backend op.
  494. Please note that this order may not match the order in which those arguments were passed
  495. on the Python side. Also note that shape recording may increase the overhead of itt range creation.
  496. Default: ``False``
  497. Example:
  498. >>> # xdoctest: +SKIP("Undefined variables")
  499. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
  500. >>> with torch.autograd.profiler.emit_itt():
  501. ... model(x)
  502. """
  503. def __init__(self, enabled=True, record_shapes=False):
  504. self.enabled = enabled
  505. self.entered = False
  506. self.record_shapes = record_shapes
  507. def __enter__(self):
  508. if not self.enabled:
  509. return
  510. if self.entered:
  511. raise RuntimeError("ITT annotation context manager is not reentrant")
  512. self.entered = True
  513. _enable_profiler(
  514. ProfilerConfig(
  515. ProfilerState.ITT,
  516. self.record_shapes,
  517. False,
  518. False,
  519. False,
  520. False,
  521. _ExperimentalConfig()),
  522. set()
  523. )
  524. return self
  525. def __exit__(self, exc_type, exc_val, exc_tb):
  526. if not self.enabled:
  527. return
  528. _disable_profiler()
  529. return False
  530. class emit_nvtx:
  531. """Context manager that makes every autograd operation emit an NVTX range.
  532. It is useful when running the program under nvprof::
  533. nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
  534. Unfortunately, there's no way to force nvprof to flush the data it collected
  535. to disk, so for CUDA profiling one has to use this context manager to annotate
  536. nvprof traces and wait for the process to exit before inspecting them.
  537. Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
  538. :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
  539. e.g. in Python REPL.
  540. .. warning:
  541. This context manager should not be called recursively, i.e. at most one
  542. instance should be enabled at any given time.
  543. Args:
  544. enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
  545. Default: ``True``.
  546. record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
  547. each autograd op will append information about the sizes of Tensor arguments received
  548. by that op, in the following format:
  549. ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
  550. Non-tensor arguments will be represented by ``[]``.
  551. Arguments will be listed in the order they are received by the backend op.
  552. Please note that this order may not match the order in which those arguments were passed
  553. on the Python side. Also note that shape recording may increase the overhead of nvtx range creation.
  554. Default: ``False``
  555. Example:
  556. >>> # xdoctest: +SKIP("undefined variables")
  557. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
  558. >>> with torch.cuda.profiler.profile():
  559. ... model(x) # Warmup CUDA memory allocator and profiler
  560. ... with torch.autograd.profiler.emit_nvtx():
  561. ... model(x)
  562. **Forward-backward correlation**
  563. When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
  564. correlating each backward-pass op with the corresponding forward-pass op can be difficult.
  565. To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
  566. generates.
  567. During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running
  568. counter, incremented each time a new backward Function object is created and stashed for backward.
  569. Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
  570. if a backward Function object is created by this forward function,
  571. the backward object will receive sequence number N.
  572. During the backward pass, the top-level range wrapping each C++ backward Function's
  573. ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that
  574. the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq``
  575. numbers in forward, you can track down which forward op created each backward Function.
  576. Any functions executed during the backward pass are also decorated with ``seq=<N>``. During
  577. default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
  578. ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with
  579. backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
  580. objects with the earlier forward pass.
  581. **Double-backward**
  582. If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
  583. if you are setting up for a double-backward), each function's execution during backward
  584. is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects
  585. to be executed later during double-backward, just as the original functions in the forward pass did.
  586. The relationship between backward and double-backward is conceptually the same as the relationship
  587. between forward and backward: The functions still emit current-sequence-number-tagged ranges,
  588. the Function objects they create still stash those sequence numbers, and during the eventual
  589. double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
  590. numbers, which can be compared to `seq` numbers from the backward pass.
  591. .. warning:
  592. The sequence number is thread-local, and some forward functions don't create an associated
  593. backward Function object (instead delegating that to sub-functions further down the call chain).
  594. For these reasons, the correspondence of stashed sequence numbers in
  595. backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is
  596. not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully
  597. disambiguate which forward function created which
  598. backward Function object. You may need to make a judgment based on analytic knowledge of what
  599. the expected correspondence should be.
  600. """
  601. def __init__(self, enabled=True, record_shapes=False):
  602. self.enabled = enabled
  603. self.entered = False
  604. self.record_shapes = record_shapes
  def __enter__(self):
      """Enable NVTX-mode autograd profiling.

      Returns:
          ``self`` when the instance is enabled, ``None`` when it is disabled.

      Raises:
          RuntimeError: if ``self.entered`` is already set.
      """
      if not self.enabled:
          return
      if self.entered:
          raise RuntimeError("NVTX annotation context manager is not reentrant")
      self.entered = True
      # Synchronize before enabling; presumably so GPU work queued earlier is
      # not attributed to ranges emitted from here on (confirm).
      torch.cuda.synchronize()
      # Only the state and record_shapes are configurable here; the four False
      # flags keep the remaining ProfilerConfig options off (their exact meaning
      # is defined by ProfilerConfig — verify there).
      config = ProfilerConfig(
          ProfilerState.NVTX,
          self.record_shapes,
          False,
          False,
          False,
          False,
          _ExperimentalConfig(),
      )
      # Second argument: an empty set (presumably no extra activities requested).
      _enable_profiler(config, set())
      return self
  def __exit__(self, exc_type, exc_val, exc_tb):
      """Disable NVTX-mode autograd profiling started by ``__enter__``.

      A disabled instance does nothing. Otherwise CUDA is synchronized, the
      profiler is disabled, and the ``entered`` guard is cleared. Clearing the
      guard means ``__enter__`` rejects only nested use, not reuse of the same
      instance in a later ``with`` block.

      Returns:
          ``False`` (``None`` on the disabled path), so exceptions raised inside
          the ``with`` body are never suppressed.
      """
      if not self.enabled:
          return
      # Synchronize before disabling; presumably so GPU work launched inside
      # the block finishes while annotation is still on (confirm).
      torch.cuda.synchronize()
      _disable_profiler()
      # Reset only after the profiler is actually off, so a later __enter__
      # starts from a clean state.
      self.entered = False
      return False
  def load_nvprof(path):
      """Open an nvprof trace file and parse its autograd annotations.

      Args:
          path (str): path to nvprof trace

      Returns:
          EventList: wraps the events returned by ``parse_nvprof_trace(path)``.
      """
      events = parse_nvprof_trace(path)
      return EventList(events)
  class EnforceUnique:
      """Remember keys and raise as soon as any key is seen a second time."""

      def __init__(self):
          # Key tuples recorded so far.
          self.seen = set()

      def see(self, *key):
          """Record the positional arguments, as one tuple, as a key.

          Raises:
              RuntimeError: if the same key tuple was already recorded.
          """
          if key not in self.seen:
              self.seen.add(key)
              return
          raise RuntimeError("duplicate key: " + str(key))
  def parse_nvprof_trace(path):
      """Parse an nvprof SQLite trace into ``FunctionEvent`` objects.

      One ``FunctionEvent`` is created per NVTX range, i.e. per pair of
      ``CUPTI_ACTIVITY_KIND_MARKER`` rows that share an id, where the start row
      has a non-zero name and the end row has name 0. Every kernel whose
      launching runtime call lies strictly inside a range is then attached to
      that range's event via ``append_kernel``.

      Args:
          path (str): path to the nvprof trace (an SQLite database file).

      Returns:
          list: the ``FunctionEvent`` objects, sorted by start time.

      Raises:
          RuntimeError: if a marker id, or a (marker id, runtime call id) pair,
              is seen more than once.
          AssertionError: if a correlated runtime call is not callback id 211
              (only checked when assertions are enabled).
      """
      import sqlite3
      from contextlib import closing

      # sqlite3.Connection's own ``with`` support only commits/rolls back and
      # never closes the database; closing() releases the handle even if
      # parsing raises.
      with closing(sqlite3.connect(path)) as conn:
          conn.row_factory = sqlite3.Row

          # Map string-table ids to demangled names.
          strings = {}
          for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
              strings[r["id"]] = torch._C._demangle(r["value"])

          # First, find all functions and create FunctionEvents for them
          marker_query = """
          SELECT
              start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
          FROM
              CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
              ON start.id = end.id
          WHERE
              start.name != 0 AND end.name = 0
          """
          functions = []
          functions_map = {}
          unique = EnforceUnique()
          for row in conn.execute(marker_query):
              unique.see(row['marker_id'])
              evt = FunctionEvent(id=row['marker_id'],
                                  node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
                                              # that pytorch doesn't crash when creating a FunctionEvent() object
                                  name=strings[row['name']],
                                  start_us=row['start_time'],
                                  end_us=row['end_time'],
                                  thread=0)  # TODO: find in sqlite database
              functions.append(evt)
              functions_map[evt.id] = evt

          # Now, correlate all kernels with FunctionEvents. The WHERE clause
          # pairs each range's start row with its end row exactly as
          # marker_query does, so every marker_id returned here is a key of
          # functions_map and no start/start or end/end self-pairs are joined.
          kernel_query = """
          SELECT
              start.id AS marker_id, start.name, start.timestamp, end.timestamp,
              runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
              kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
          FROM
              CUPTI_ACTIVITY_KIND_MARKER AS start
              INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
                  ON start.id = end.id
              INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
                  ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
              INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
                  ON kernel.correlationId = runtime.correlationId
          WHERE
              start.name != 0 AND end.name = 0
          """
          unique = EnforceUnique()
          for row in conn.execute(kernel_query):
              unique.see(row['marker_id'], row['runtime_id'])
              # 211 is cudaLaunchKernel for cuda >= 9.2
              assert (row['cbid'] == 211)
              evt = functions_map[row['marker_id']]
              evt.append_kernel(row['kernel_name'],
                                0,
                                row['kernel_end'] - row['kernel_start'])

      functions.sort(key=lambda evt: evt.time_range.start)
      return functions
  class KinetoStepTracker:
      """Global step counter shared by every component that can signal a step.

      The profiler's own ``step()`` used to be the only place a step was
      marked. Step hooks are now also added to the Optimizer class
      (https://github.com/pytorch/pytorch/issues/88446), which creates two ways
      to over-count:

      - a program that already calls ``profiler.step()`` every iteration would
        advance the count twice per iteration;
      - a model driven by several optimizers would advance it once per
        optimizer.

      To prevent this, each requester keeps its own count in a dict::

          {
              "ProfilerStep": 100,    # triggered by profiler step() call
              "Optimizer1Step": 100,  # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
              "Optimizer2Step": 100,
          }

      The global step is the maximum of the values (here 100). When any one
      requester moves ahead, the maximum rises::

          {
              "ProfilerStep": 100,
              "Optimizer1Step": 101,  # Optimizer1 got incremented first say
              "Optimizer2Step": 100,
          }

      and the global step becomes 101. The Kineto ``step()`` is called only
      when the global count increases.

      NOTE: Please do not use the KinetoStepTracker in modules beside the
      Optimizer for now. The result could be incorrect increments of the step
      count.
      """

      # Global step; starts at -1 and only ever moves up to the dict maximum.
      _current_step = -1
      # Maps requester name to that requester's own step count.
      _step_dict: Dict[str, int] = defaultdict(int)

      @classmethod
      def init_step_count(cls, requester: str):
          """Start (or restart) tracking ``requester`` at the current global step."""
          cls._step_dict[requester] = cls._current_step

      @classmethod
      def erase_step_count(cls, requester: str) -> bool:
          """Stop tracking ``requester``; return True if it was being tracked."""
          removed = cls._step_dict.pop(requester, None)
          return removed is not None

      @classmethod
      def increment_step(cls, requester: str) -> int:
          """Advance ``requester``'s count and return the global step count.

          A requester seen for the first time is registered at the current
          global step. If the maximum over all requester counts now exceeds the
          global step, ``_kineto_step()`` is called once per unit of increase
          (with a warning when the increase is more than 1), and the global
          step is raised to that maximum.
          """
          if requester not in cls._step_dict:
              cls.init_step_count(requester)
          cls._step_dict[requester] += 1

          highest = max(cls._step_dict.values())
          gap = highest - cls._current_step
          if gap > 0:
              if gap > 1:
                  warn("Profiler step count has increased more than 1 - "
                       f"current_step = {cls._current_step} step dict = {cls._step_dict}")
              for _ in range(gap):
                  _kineto_step()
              cls._current_step = highest
          return cls._current_step

      @classmethod
      def current_step(cls) -> int:
          """Return the current global step count."""
          return cls._current_step