- """Example of Timer and Compare APIs:
- $ python -m examples.compare
- """
- import pickle
- import sys
- import time
- import torch
- import torch.utils.benchmark as benchmark_utils
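
# Workflow in this example: build a grid of Timers, collect a Measurement from
# each with `blocked_autorange`, serialize the results with pickle (standing in
# for separate worker processes), then aggregate everything with Compare.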


class FauxTorch:
    """Emulate different versions of pytorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements, but this simplifies that model to
    make the example clearer.
    """
    def __init__(self, real_torch, extra_ns_per_element):
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element

    def extra_overhead(self, result):
        # time.sleep has a ~65 us overhead, so only fake a
        # per-element overhead if numel is large enough.
        numel = int(result.numel())
        if numel > 5000:
            time.sleep(numel * self._extra_ns_per_element * 1e-9)
        return result

    def add(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.add(*args, **kwargs))

    def mul(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))

    def cat(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))

    def matmul(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
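

# Illustrative sketch (commented out; not part of the benchmark run below):
# FauxTorch is a drop-in stand-in for the `torch` module for the ops it wraps:
#
#   slow_torch = FauxTorch(torch, extra_ns_per_element=5)
#   y = slow_torch.add(torch.ones(10_000), torch.ones(10_000))  # ~50 us of added sleep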


def main():
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]

    serialized_results = []
    repeats = 2
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": torch.ones((size, 4)),
                "y": torch.ones((1, 4)),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000]
        for num_threads in [1, 4]
    ]
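
    # The comprehension above builds 3 branches x 2 stmts x 6 sizes x 2 thread
    # counts = 72 Timers. (`overhead_ns` is unused for "master", which gets the
    # real torch module.) Compare lays out its tables from this metadata:
    # `label` titles a table, `sub_label` and `env` identify rows, `description`
    # identifies columns, and each `num_threads` value gets its own section.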

    for i, timer in enumerate(timers * repeats):
        serialized_results.append(pickle.dumps(
            timer.blocked_autorange(min_run_time=0.05)
        ))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()
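
    # Deserializing here mirrors the workflow described in the FauxTorch
    # docstring: separate worker processes write pickled Measurements, and a
    # single process loads and aggregates them with Compare.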
    comparison = benchmark_utils.Compare([
        pickle.loads(i) for i in serialized_results
    ])
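
    # The first table shows raw Measurements; the second applies Compare's
    # formatting helpers: `trim_significant_figures` rounds values to the
    # precision the measurement noise supports, and `colorize` shades cells
    # so outliers such as the "severe_regression" env stand out.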
- print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
- comparison.print()
- print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
- comparison.trim_significant_figures()
- comparison.colorize()
- comparison.print()


if __name__ == "__main__":
    main()