_utils.py

from collections import deque
from dataclasses import dataclass
import functools
import re
from typing import Dict, List

from torch.profiler import DeviceType
from torch.autograd.profiler import profile
from torch.autograd import _KinetoEvent


def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
    '''
    Iteratively traverses the event tree, yielding events in the order
    determined by next_fn: popping from the right of the deque gives a DFS,
    popping from the left gives a BFS.
    '''
    order = reversed if reverse else lambda x: x
    remaining = deque(order(tree))
    while remaining:
        curr_event = next_fn(remaining)
        yield curr_event
        for child_event in order(children_fn(curr_event)):
            remaining.append(child_event)


traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True)
traverse_bfs = functools.partial(_traverse, next_fn=lambda x: x.popleft(), reverse=False)
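
# A minimal usage sketch of the traversal helpers (the `Node` class here is
# hypothetical; `_traverse` only needs objects exposing `.children`, the
# default `children_fn`):
#
#     @dataclass
#     class Node:
#         name: str
#         children: List["Node"]
#
#     root = Node("root", [Node("a", [Node("a1", [])]), Node("b", [])])
#     [n.name for n in traverse_dfs([root])]  # ['root', 'a', 'a1', 'b']
#     [n.name for n in traverse_bfs([root])]  # ['root', 'a', 'b', 'a1']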


@dataclass
class EventMetrics:
    duration_time_ns: int = 0
    self_time_ns: int = 0
    idle_time_ns: int = 0
    queue_depth: int = 0

    @property
    def fraction_idle_time(self):
        if self.duration_time_ns == 0:
            return 0.0
        return self.idle_time_ns / self.duration_time_ns
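
# Example with made-up numbers: an event spanning 10us of which 4us were
# GPU-idle has fraction_idle_time == 4_000 / 10_000 == 0.4.
#
#     m = EventMetrics(duration_time_ns=10_000, self_time_ns=6_000, idle_time_ns=4_000)
#     m.fraction_idle_time  # 0.4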


@dataclass
class Interval:
    start: int
    end: int
    queue_depth: int = 0


class EventKey:
    def __init__(self, event):
        self.event = event

    def __hash__(self):
        return hash(self.event.id)

    def __eq__(self, other):
        return self.event.id == other.event.id

    def __repr__(self):
        return f"{self.event.name}"

    def intervals_overlap(self, intervals: List[Interval]):
        '''
        Returns the total time (ns) this event overlaps with the union of the
        given intervals, counting time covered by several intervals only once.
        '''
        overlap_time = 0
        intervals = sorted(intervals, key=lambda x: x.start)

        if intervals:
            overlap_start = max(self.event.start_time_ns, intervals[0].start)
            overlap_end = min(self.event.end_time_ns, intervals[0].end)
            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start

        # i tracks the interval reaching furthest right among those already
        # counted; j scans forward through the remaining intervals.
        i, j = 0, 1
        while j < len(intervals):
            prev_interval = intervals[i]
            curr_interval = intervals[j]
            j += 1
            if prev_interval.end > curr_interval.start:
                # Completely subsumed by the previous interval: nothing new
                # to count, and prev still reaches furthest right.
                if prev_interval.end > curr_interval.end:
                    continue
                # Partial overlap: clip off the part already counted.
                curr_interval.start = prev_interval.end
            # curr now reaches furthest right, so it becomes the new prev.
            i = j - 1
            overlap_start = max(self.event.start_time_ns, curr_interval.start)
            overlap_end = min(self.event.end_time_ns, curr_interval.end)
            if overlap_start < overlap_end:
                overlap_time += overlap_end - overlap_start
        return overlap_time
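
# Worked example (`FakeEvent` is hypothetical, standing in for a profiler
# event spanning [0, 100] ns): [10, 30] and [20, 50] merge to 40ns of
# overlap, and [90, 120] clips to the event's end for another 10ns.
#
#     @dataclass
#     class FakeEvent:
#         id: int = 0
#         name: str = "fake"
#         start_time_ns: int = 0
#         end_time_ns: int = 100
#
#     EventKey(FakeEvent()).intervals_overlap(
#         [Interval(10, 30), Interval(20, 50), Interval(90, 120)])  # == 50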


class BasicEvaluation:

    def __init__(self, prof: profile):
        self.profile = prof
        self.metrics: Dict[EventKey, EventMetrics] = {}
        self.compute_self_time()
        self.event_keys = sorted(self.metrics.keys(),
                                 key=lambda x: x.event.start_time_ns)
        self.events = [e.event for e in self.event_keys]
        self.cuda_events: List[_KinetoEvent] = []
        self.queue_depth_list = self.compute_queue_depth()
        self.compute_idle_time()

    def compute_self_time(self):
        '''
        Computes each event's self time (total time minus time spent in
        child ops).
        '''
        assert self.profile.kineto_results is not None
        stack = deque(self.profile.kineto_results.experimental_event_tree())

        # standard iterative dfs
        while stack:
            curr_event = stack.pop()
            self_time = curr_event.duration_time_ns
            for child_event in curr_event.children:
                self_time -= child_event.duration_time_ns
                stack.append(child_event)
            key = EventKey(curr_event)
            assert key not in self.metrics, (
                f"Duplicate id: {curr_event.id}, {curr_event.name}")
            self.metrics[key] = EventMetrics(self_time_ns=self_time)
            self.metrics[key].duration_time_ns = curr_event.duration_time_ns
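
    # Self-time sketch with made-up numbers: an op that runs for 10us total
    # and calls two children taking 3us and 4us has
    # self_time_ns == 10_000 - 3_000 - 4_000 == 3_000.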

    def compute_queue_depth(self):
        '''
        Computes the queue depth at each event, i.e. the number of CUDA
        kernels that have been launched but have not yet started. Records the
        depth for every event in the tree and returns a list of Intervals
        with the queue depth of CUDA launches and kernels.
        '''
        assert self.profile.kineto_results is not None
        cuda_event_list = self.profile.kineto_results.events()

        def is_cuda_launch_kernel(e):
            # TODO: find a better way to identify cudaLaunchKernel
            return e.name == "cudaLaunchKernel"

        def is_cuda_kernel(e):
            # TODO: find a better way to identify CUDA Kernel
            return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower()

        cuda_launch_events = sorted(
            (e for e in cuda_event_list if is_cuda_launch_kernel(e)),
            key=lambda x: x.start_us())
        cuda_kernel_events = sorted(
            (e for e in cuda_event_list if is_cuda_kernel(e)),
            key=lambda x: x.start_us())

        self.cuda_events = sorted(cuda_launch_events + cuda_kernel_events,
                                  key=lambda x: x.start_us())

        # Map each launch to the kernel it spawned via the correlation id.
        kernel_mapping: Dict[_KinetoEvent, int] = {}
        last_mapped_kernel = 0
        for cuda_launch_event in cuda_launch_events:
            index = index_of_first_match(
                cuda_kernel_events,
                lambda x: x.linked_correlation_id() ==
                cuda_launch_event.linked_correlation_id(),
                start=last_mapped_kernel)
            kernel_mapping[cuda_launch_event] = index
            last_mapped_kernel = index if index is not None else last_mapped_kernel

        current_kernel_index = 0
        spawned_kernel_index = -1

        all_events = cuda_launch_events + cuda_kernel_events + self.events

        def new_old_event_comparator(event):
            if hasattr(event, "start_us"):
                return event.start_us() * 1000
            if hasattr(event, "start_time_ns"):
                return event.start_time_ns
            raise Exception("Unknown Event Type")

        queue_depth_list: List[Interval] = []
        all_events.sort(key=new_old_event_comparator)
        for event in all_events:
            # Find latest cuda kernel event
            if hasattr(event, "start_us"):
                start_time = event.start_us() * 1000
                end_time = (event.start_us() + event.duration_us()) * 1000
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
                    spawned_kernel_index = kernel_mapping[event]
            elif hasattr(event, "start_time_ns"):
                start_time = event.start_time_ns  # type: ignore[attr-defined]
                end_time = event.end_time_ns  # type: ignore[attr-defined]

            while (current_kernel_index < len(cuda_kernel_events) and
                   cuda_kernel_events[current_kernel_index].start_us() * 1000
                   <= start_time):
                current_kernel_index += 1
            current_queue_depth = spawned_kernel_index - current_kernel_index + 1
            current_queue_depth = max(current_queue_depth, 0)

            if hasattr(event, "start_us"):
                queue_depth_list.append(
                    Interval(start_time, end_time, current_queue_depth))
            elif hasattr(event, "start_time_ns"):
                self.metrics[EventKey(event)].queue_depth = current_queue_depth
        return queue_depth_list
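
    # Queue-depth sketch with a made-up timeline: if three cudaLaunchKernel
    # calls have been issued but only one kernel has started running by time
    # t, the depth at t is 3 - 1 = 2; a depth of 0 means the GPU queue has
    # drained and the device may be idle.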

    def compute_idle_time(self):
        '''
        Computes idle time of the profile.
        '''
        # Based on queue_depth_list, we can calculate idle time for all the events
        idle = False
        idle_start = 0
        idle_intervals: List[Interval] = []
        if self.queue_depth_list and self.events:
            idle_intervals += [
                Interval(self.events[0].start_time_ns,
                         self.queue_depth_list[0].start),
                Interval(self.queue_depth_list[-1].end,
                         self.events[-1].end_time_ns)
            ]

        for data_point in self.queue_depth_list:
            if data_point.queue_depth == 0 and not idle:
                idle_start = data_point.end
                idle = True
            if data_point.queue_depth > 0 and idle:
                idle_intervals.append(Interval(idle_start, data_point.start))
                idle = False

        event_list = [e.event for e in self.metrics.keys()]
        for event in event_list:
            self.metrics[EventKey(event)].idle_time_ns = EventKey(
                event).intervals_overlap(idle_intervals)
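
    # Idle-interval sketch with made-up numbers: if an interval with queue
    # depth 0 ends at t=100 and the next interval with depth > 0 starts at
    # t=150, then [100, 150] is recorded as idle and charged to every event
    # overlapping it via intervals_overlap.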

    def rank_events(self, length):
        '''
        Filter and rank the events based on two heuristics:
        1) Events that are in the falling phase of the queue depth.
        2) Events that have a large idle_time / self_time difference.

        Parameters:
            length: The number of events to return.
        '''
        # Find the intervals when the queue depth is falling to 0
        import torch
        queue_depth_list = list(reversed(self.queue_depth_list))
        qd_values = [e.queue_depth for e in queue_depth_list]

        bottom_threshold = 0
        top_threshold = 4
        decrease_interval = []
        i = 0
        while i < len(qd_values):
            if qd_values[i] > bottom_threshold:
                i += 1
                continue
            for j in range(i + 1, len(qd_values)):
                # Find the next zero and if the max value between them exceeds
                # the threshold, then we have a falling interval
                next_minimum_idx = index_of_first_match(
                    qd_values, lambda x: x <= bottom_threshold, start=j)
                peak_idx = argmax(qd_values, start=j, end=next_minimum_idx)

                # if it is a valid peak, we add it to the list and continue
                if peak_idx is not None and qd_values[peak_idx] >= top_threshold:
                    decrease_interval.append(
                        Interval(queue_depth_list[peak_idx].start,
                                 queue_depth_list[i].start))
                    i = next_minimum_idx if next_minimum_idx is not None else i
                    break
            i += 1

        # Filter out events that are not in the decrease interval
        event_list = [
            event for event in self.metrics.keys()
            if event.intervals_overlap(decrease_interval)
        ]
        if event_list:
            self_time = torch.tensor(
                [self.metrics[event].self_time_ns for event in event_list],
                dtype=torch.float32)
            idle_time = torch.tensor(
                [self.metrics[event].fraction_idle_time for event in event_list],
                dtype=torch.float32)
            normalized_gain = (idle_time -
                               torch.mean(idle_time)) / torch.std(idle_time)
            normalized_self = (self_time -
                               torch.mean(self_time)) / torch.std(self_time)
            heuristic_score_list = normalized_gain + 0.6 * normalized_self

            # Sort events by heuristic score, highest first
            event_list = [
                event
                for _, event in sorted(zip(heuristic_score_list, event_list),
                                       key=lambda x: x[0],
                                       reverse=True)
            ]
            event_list = event_list[:length]
        return event_list
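
    # Scoring sketch: both features are z-score normalized, so an event with
    # fraction_idle_time one standard deviation above the mean and an average
    # self time scores 1.0 + 0.6 * 0.0 = 1.0; higher scores rank earlier.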

    def get_optimizable_events(self, length: int = 1, print_enable: bool = True):
        event_list = self.rank_events(length)
        if not print_enable:
            return event_list
        output = "Optimizable events:\n" if event_list else "No events to optimize\n"
        output += "\n".join([
            f"""{'-'*80}
Event: {event}
Source code location: {source_code_location(event.event)}
Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
{'-'*80}""" for event in event_list
        ])
        print(output)
        return event_list
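
# Example usage (a sketch; assumes a CUDA workload profiled under
# torch.profiler so that kineto_results is populated, and that the wrapper's
# underlying autograd profiler is exposed as `prof.profiler`):
#
#     with torch.profiler.profile() as prof:
#         run_workload()  # hypothetical user function
#     BasicEvaluation(prof.profiler).get_optimizable_events(length=3)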


def index_of_first_match(seq, predicate, start=0, end=None):
    if end is None or end >= len(seq):
        end = len(seq)
    for i in range(start, end):
        if predicate(seq[i]):
            return i
    return None
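
# Example with made-up data:
#
#     index_of_first_match([4, 1, 0, 2], lambda x: x < 2)           # 1
#     index_of_first_match([4, 1, 0, 2], lambda x: x < 2, start=2)  # 2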


def argmax(seq, key=lambda x: x, start=0, end=None):
    seq = seq[start:end]
    if len(seq) == 0:
        return None
    # Map the index in the slice back to an index in the original sequence.
    return seq.index(max(seq, key=key)) + start
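
# Example with made-up data (returned indices are absolute, not relative to
# `start`):
#
#     argmax([3, 1, 5, 2], start=1)  # 2, the index of 5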


def source_code_location(event):
    '''
    Walks up the event's parents until an event whose name contains a Python
    frame of the form "<file>.py(<line>): <function>" is found.
    '''
    while event is not None:
        match = re.search(r"\.py\(.*\)", event.name)
        if match is None:
            event = event.parent
            continue
        return event.name
    return "No source code location found"