# compare.py
"""Example of Timer and Compare APIs:

$ python -m examples.compare
"""
import pickle
import sys
import time

import torch
import torch.utils.benchmark as benchmark_utils
  9. class FauxTorch:
  10. """Emulate different versions of pytorch.
  11. In normal circumstances this would be done with multiple processes
  12. writing serialized measurements, but this simplifies that model to
  13. make the example clearer.
  14. """
  15. def __init__(self, real_torch, extra_ns_per_element):
  16. self._real_torch = real_torch
  17. self._extra_ns_per_element = extra_ns_per_element
  18. def extra_overhead(self, result):
  19. # time.sleep has a ~65 us overhead, so only fake a
  20. # per-element overhead if numel is large enough.
  21. numel = int(result.numel())
  22. if numel > 5000:
  23. time.sleep(numel * self._extra_ns_per_element * 1e-9)
  24. return result
  25. def add(self, *args, **kwargs):
  26. return self.extra_overhead(self._real_torch.add(*args, **kwargs))
  27. def mul(self, *args, **kwargs):
  28. return self.extra_overhead(self._real_torch.mul(*args, **kwargs))
  29. def cat(self, *args, **kwargs):
  30. return self.extra_overhead(self._real_torch.cat(*args, **kwargs))
  31. def matmul(self, *args, **kwargs):
  32. return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
  33. def main():
  34. tasks = [
  35. ("add", "add", "torch.add(x, y)"),
  36. ("add", "add (extra +0)", "torch.add(x, y + zero)"),
  37. ]
  38. serialized_results = []
  39. repeats = 2
  40. timers = [
  41. benchmark_utils.Timer(
  42. stmt=stmt,
  43. globals={
  44. "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
  45. "x": torch.ones((size, 4)),
  46. "y": torch.ones((1, 4)),
  47. "zero": torch.zeros(()),
  48. },
  49. label=label,
  50. sub_label=sub_label,
  51. description=f"size: {size}",
  52. env=branch,
  53. num_threads=num_threads,
  54. )
  55. for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
  56. for label, sub_label, stmt in tasks
  57. for size in [1, 10, 100, 1000, 10000, 50000]
  58. for num_threads in [1, 4]
  59. ]
  60. for i, timer in enumerate(timers * repeats):
  61. serialized_results.append(pickle.dumps(
  62. timer.blocked_autorange(min_run_time=0.05)
  63. ))
  64. print(f"\r{i + 1} / {len(timers) * repeats}", end="")
  65. sys.stdout.flush()
  66. print()
  67. comparison = benchmark_utils.Compare([
  68. pickle.loads(i) for i in serialized_results
  69. ])
  70. print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
  71. comparison.print()
  72. print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
  73. comparison.trim_significant_figures()
  74. comparison.colorize()
  75. comparison.print()
  76. if __name__ == "__main__":
  77. main()