profiler_legacy.py

import itertools
from warnings import warn

import torch
import torch.cuda
from torch.autograd import (
    _disable_profiler_legacy,
    _enable_profiler_legacy,
    DeviceType,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
)

__all__ = ["profile"]


class profile:
    """DEPRECATED: use torch.profiler instead."""

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        self.function_events = None
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        # FLOPs estimation needs input shapes, so with_flops implies record_shapes
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules

        if self.use_cuda and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False

        if self.use_cuda:
            self.profiler_kind = ProfilerState.CUDA
        else:
            self.profiler_kind = ProfilerState.CPU

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            # avoid exposing _ExperimentalConfig in the legacy public API
            torch._C._profiler._ExperimentalConfig(),
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._start_trace()
        return self

    def _start_trace(self):
        _enable_profiler_legacy(self.config())

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()
        records = _disable_profiler_legacy()
        parsed_results = _parse_legacy_records(records)
        self.function_events = EventList(
            parsed_results,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished profiler_legacy.profile>"
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.export_chrome_trace(path)

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Return the total CPU time, computed as the sum of the self times
        of all recorded events.
        """
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total
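

# Example usage (an illustrative sketch, not part of the original file; this
# class is deprecated, and torch.profiler should be preferred in new code):
#
#     with profile(record_shapes=True) as prof:
#         y = model(x)  # `model` and `x` stand in for the caller's own objects
#     print(prof.key_averages(group_by_input_shape=True).table())
#     prof.export_chrome_trace("trace.json")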


def _parse_legacy_records(thread_records):
    def _get_record_key(record):
        """Returns a tuple to be used by _parse_legacy_records for correlating
        start and end records.
        """
        return (record.handle(), record.node_id())

    next_id = 0
    start_record = None
    functions = []
    record_stack = []

    # '__start_profile' is not guaranteed to be first, so we must find it here
    for record in itertools.chain(*thread_records):
        name = record.name()
        if start_record is None and name == "__start_profile":
            start_record = record

    assert start_record is not None and not start_record.is_remote()
    for thread_record_list in thread_records:
        # accumulated memory allocations per handle
        cpu_memory_allocs = {}
        cuda_memory_allocs = {}
        # ranges per handle
        range_starts = {}

        filtered_handles = set()
        prev_record = None
        for record in thread_record_list:
            record_key = _get_record_key(record)
            if _filter_name(record.name()) or record_key in filtered_handles:
                filtered_handles.add(record_key)
                continue

            if record.kind() == "push":
                # workaround to reduce double logging from operator
                # wrappers and redispatch
                if prev_record is not None:
                    duplicate = (
                        prev_record.name() == record.name()
                        and prev_record.kind() == record.kind()
                        and prev_record.node_id() == record.node_id()
                    )
                    if duplicate:
                        filtered_handles.add(record_key)
                        continue

                range_starts[record_key] = record
                cpu_memory_allocs[record_key] = 0
                cuda_memory_allocs[record_key] = 0
            elif record.kind() == "pop":
                assert (
                    record_key in range_starts
                ), """Expected record with key {} to exist in range_starts.
                This means that the pop event did not have a corresponding push.""".format(
                    record_key
                )

                start = range_starts[record_key]

                cpu_memory_usage = cpu_memory_allocs[record_key]
                cuda_memory_usage = cuda_memory_allocs[record_key]
                is_async = start.is_async() or (
                    start.thread_id() != record.thread_id()
                )
                is_remote_event = record.is_remote()
                start_flops = start.flops()

                fe = FunctionEvent(
                    id=record.handle(),
                    node_id=record.node_id(),
                    name=_rewrite_name(name=start.name(), with_wildcard=True),
                    trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
                    thread=start.thread_id(),
                    start_us=start_record.cpu_elapsed_us(start),
                    end_us=start_record.cpu_elapsed_us(record),
                    fwd_thread=start.fwd_thread_id(),
                    input_shapes=start.shapes(),
                    stack=[
                        entry for entry in start.stack() if _filter_stack_entry(entry)
                    ],
                    scope=start.scope(),
                    cpu_memory_usage=cpu_memory_usage,
                    cuda_memory_usage=cuda_memory_usage,
                    is_async=is_async,
                    is_remote=is_remote_event,
                    sequence_nr=start.sequence_nr(),
                    device_type=DeviceType.CPU,
                    is_legacy=True,
                    flops=start_flops,
                )
                # note: async events have only cpu total time
                if not is_async and start.has_cuda():
                    duration = start.cuda_elapsed_us(record)
                    if duration > 0:
                        fe.append_kernel(start.name(), start.device(), duration)
                functions.append(fe)
                del range_starts[record_key]
                del cpu_memory_allocs[record_key]
                del cuda_memory_allocs[record_key]
            elif record.kind() == "memory_alloc":
                num_open_handles_cpu = len(cpu_memory_allocs)
                num_open_handles_cuda = len(cuda_memory_allocs)
                assert num_open_handles_cpu == num_open_handles_cuda
                # attribute the allocation to every currently open range
                for handle in cpu_memory_allocs.keys():
                    cpu_memory_allocs[handle] += record.cpu_memory_usage()
                for handle in cuda_memory_allocs.keys():
                    cuda_memory_allocs[handle] += record.cuda_memory_usage()
                if num_open_handles_cpu == 0:
                    # output event as a top-level memory event
                    fe = FunctionEvent(
                        id=0,
                        name=MEMORY_EVENT_NAME,
                        trace_name=None,
                        thread=0,
                        start_us=0,
                        end_us=0,
                        stack=[],
                        cpu_memory_usage=record.cpu_memory_usage(),
                        cuda_memory_usage=record.cuda_memory_usage(),
                        is_legacy=True,
                    )
                    functions.append(fe)
            prev_record = record

    # Sort functions by start time ascending, then by end time descending, so
    # that when nested events share a start time (which can happen at the
    # granularity of the clock tick) the outermost call is always listed
    # first. This keeps the ordering of FunctionEvents stable.
    functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
    return functions
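

# A minimal smoke test (an illustrative addition, not part of the original
# file). It assumes a CPU-only environment and simply checks that a profiled
# region yields printable events. Run with `python profiler_legacy.py`.
if __name__ == "__main__":
    with profile(use_cuda=False) as prof:
        x = torch.randn(32, 32)
        (x @ x).sum().item()
    # Events are available once the context manager exits.
    print(prof.table(sort_by="self_cpu_time_total", row_limit=10))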