# profiler.py — TorchDynamo capture-coverage profiling utilities.
  1. import dataclasses
  2. import os
  3. from typing import Any, List
  4. import torch
  5. from . import config
  6. from .utils import print_once
  7. @dataclasses.dataclass
  8. class ProfileMetrics:
  9. microseconds: float = 0.0
  10. operators: int = 0
  11. fusions: int = 0
  12. graphs: int = 0
  13. def __iadd__(self, other: "ProfileMetrics"):
  14. self.microseconds += other.microseconds
  15. self.operators += other.operators
  16. self.fusions += other.fusions
  17. return self
  18. def __add__(self, other: "ProfileMetrics"):
  19. assert isinstance(other, ProfileMetrics)
  20. return ProfileMetrics(
  21. self.microseconds + other.microseconds,
  22. self.operators + other.operators,
  23. self.fusions + other.fusions,
  24. )
  25. def __truediv__(self, other):
  26. if isinstance(other, int):
  27. other = ProfileMetrics(other, other, other)
  28. return ProfileMetrics(
  29. self.microseconds / max(1, other.microseconds),
  30. self.operators / max(1, other.operators),
  31. self.fusions / max(1, other.fusions),
  32. )
  33. def __str__(self):
  34. return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"
  35. def tocsv(self):
  36. return [self.operators, self.microseconds]
  37. class ProfileResult:
  38. def __init__(self, captured, total, unique_graphs):
  39. self.captured: ProfileMetrics = captured or ProfileMetrics()
  40. self.total: ProfileMetrics = total or ProfileMetrics()
  41. self.unique_graphs: int = unique_graphs
  42. def __iadd__(self, other: ProfileMetrics):
  43. self.captured += other.captured
  44. self.total += other.total
  45. self.unique_graphs += other.unique_graphs
  46. return self
  47. def percent(self):
  48. return self.captured / self.total
  49. def __str__(self):
  50. return (
  51. f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
  52. f"{self.captured.operators:4}/{self.total.operators:4} = "
  53. + str(self.percent())
  54. )
  55. def tocsv(self):
  56. return [
  57. self.unique_graphs,
  58. self.captured.graphs,
  59. self.captured.operators,
  60. self.total.operators,
  61. ] + self.percent().tocsv()
  62. def should_print_missing():
  63. return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"
  64. def print_missing(stack):
  65. if any("/torch/autograd/profiler.py" in x for x in stack):
  66. return
  67. stack = [
  68. x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
  69. ]
  70. print_once("MISSING", " >> ".join(stack[-3:]))
class Profiler:
    """Wraps torch.profiler to measure how much work ran inside captured graphs."""

    # Incremented externally each time a new graph is instrumented
    # (see fx_insert_profiling); consumed and reset by results().
    unique_graphs = 0

    def __init__(self):
        # CPU-only profiling; stacks are recorded only when the user asked for
        # missing-op reporting (TORCHDYNAMO_PRINT_MISSING=1), since with_stack
        # adds overhead.
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_stack=should_print_missing(),
        )

    def results(self):
        """Summarize recorded events into a ProfileResult.

        Walks events in start-time order and classifies each top-level op as
        captured (it ended before the end of the most recent TORCHDYNAMO
        region) or not. Ops that start before the previous top-level op ended
        are treated as nested calls and ignored to avoid double counting.
        """
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0
        # End time of the last counted top-level op; ops starting earlier
        # than this are nested inside it.
        last_op_end_time = -1
        # End time of the most recent TORCHDYNAMO marker region.
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                # Marker emitted by the record_function region in
                # fx_insert_profiling; ops ending before this time count
                # as captured.
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                # Top-level op: does not overlap the previous top-level op.
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    # Op ran outside any captured region; report its stack.
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass  # ops recursively called from other ops (ignored)
        # Hand the class-level counter to this result and reset it for the
        # next profiling run.
        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                # Captured ops in excess of one per region (ops merged away).
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                # Total ops in excess of one hypothetical single graph.
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )
  120. def shapes_of(it):
  121. if it:
  122. return [tuple(getattr(x, "shape", [])) for x in it]
  123. def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
  124. input_shapes = shapes_of(example_inputs)
  125. output_shapes = None
  126. def debug_print(extra):
  127. gm.graph.print_tabular()
  128. return f"shape mismatch in={input_shapes} out={output_shapes} got={extra}"
  129. def _wrapped(*args):
  130. nonlocal output_shapes
  131. with torch.profiler.record_function("TORCHDYNAMO"):
  132. assert (
  133. shapes_of(args) == input_shapes or config.dynamic_shapes
  134. ), debug_print(shapes_of(args))
  135. result = gm.forward(*args)
  136. if output_shapes is None:
  137. output_shapes = shapes_of(result)
  138. else:
  139. assert (
  140. shapes_of(result) == output_shapes or config.dynamic_shapes
  141. ), debug_print(shapes_of(result))
  142. return result
  143. Profiler.unique_graphs += 1
  144. return _wrapped