# throughput_benchmark.py

import torch._C


def format_time(time_us=None, time_ms=None, time_s=None):
    '''Format a duration given as exactly one of microseconds, milliseconds, or seconds.'''
    assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1

    US_IN_SECOND = 1e6
    US_IN_MS = 1e3

    if time_us is None:
        if time_ms is not None:
            time_us = time_ms * US_IN_MS
        elif time_s is not None:
            time_us = time_s * US_IN_SECOND
        else:
            raise AssertionError("Shouldn't reach here :)")

    if time_us >= US_IN_SECOND:
        return '{:.3f}s'.format(time_us / US_IN_SECOND)
    if time_us >= US_IN_MS:
        return '{:.3f}ms'.format(time_us / US_IN_MS)
    return '{:.3f}us'.format(time_us)
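
# For illustration, a few sample conversions implied by the unit thresholds above:
#   format_time(time_us=250) -> '250.000us'
#   format_time(time_ms=1.5) -> '1.500ms'
#   format_time(time_s=2)    -> '2.000s'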


class ExecutionStats:
    def __init__(self, c_stats, benchmark_config):
        self._c_stats = c_stats
        self.benchmark_config = benchmark_config

    @property
    def latency_avg_ms(self):
        return self._c_stats.latency_avg_ms

    @property
    def num_iters(self):
        return self._c_stats.num_iters

    @property
    def iters_per_second(self):
        '''Return the total number of iterations per second across all calling threads.'''
        return self.num_iters / self.total_time_seconds

    @property
    def total_time_seconds(self):
        '''Return the approximate wall-clock duration of the benchmark.

        The total work (num_iters * average latency) is divided by the number of
        calling threads, since the threads execute iterations concurrently.
        '''
        return self.num_iters * (
            self.latency_avg_ms / 1000.0) / self.benchmark_config.num_calling_threads

    def __str__(self):
        return '\n'.join([
            "Average latency per example: " + format_time(time_ms=self.latency_avg_ms),
            "Total number of iterations: {}".format(self.num_iters),
            "Total number of iterations per second (across all threads): {:.2f}".format(
                self.iters_per_second),
            "Total time: " + format_time(time_s=self.total_time_seconds),
        ])
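
# A hypothetical reading of these stats: with num_calling_threads=4, num_iters=1000,
# and latency_avg_ms=2.0, total_time_seconds is 1000 * (2.0 / 1000.0) / 4 = 0.5 s,
# and iters_per_second is 1000 / 0.5 = 2000.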


class ThroughputBenchmark:
    '''
    This class is a wrapper around the C++ component
    throughput_benchmark::ThroughputBenchmark, responsible for executing a PyTorch
    module (nn.Module or ScriptModule) under an inference-server-like load. It can
    emulate multiple calling threads that all invoke the single module provided. In
    the future we plan to enhance this component to support inter- and intra-op
    parallelism as well as multiple models running in a single process.

    Please note that even though nn.Module is supported, it might incur an overhead
    from the need to hold the GIL every time we execute Python code or pass around
    inputs as Python objects. As soon as you have a ScriptModule version of your
    model for inference deployment, it is better to switch to using it in this
    benchmark.

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> from torch.utils import ThroughputBenchmark
        >>> bench = ThroughputBenchmark(my_module)
        >>> # Pre-populate the benchmark's data set with the inputs
        >>> for input in inputs:
        ...     # Both args and kwargs work, same as with any PyTorch Module / ScriptModule
        ...     bench.add_input(input[0], x2=input[1])
        >>> # Inputs supplied above are picked at random during execution
        >>> stats = bench.benchmark(
        ...     num_calling_threads=4,
        ...     num_warmup_iters=100,
        ...     num_iters=1000,
        ... )
        >>> print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        >>> print("Number of iterations: {}".format(stats.num_iters))
    '''

    def __init__(self, module):
        if isinstance(module, torch.jit.ScriptModule):
            self._benchmark = torch._C.ThroughputBenchmark(module._c)
        else:
            self._benchmark = torch._C.ThroughputBenchmark(module)

    def run_once(self, *args, **kwargs):
        '''
        Run the module once on the provided args and kwargs and return the prediction.

        This is useful for testing that the benchmark actually runs the module
        you want it to run.
        '''
        return self._benchmark.run_once(*args, **kwargs)

    def add_input(self, *args, **kwargs):
        '''
        Store a single input to the module in the benchmark's memory and keep it
        there. During benchmark execution every thread picks up a random input
        from all the inputs ever supplied via this function.
        '''
        self._benchmark.add_input(*args, **kwargs)

    def benchmark(
        self,
        num_calling_threads=1,
        num_warmup_iters=10,
        num_iters=100,
        profiler_output_path="",
    ):
        '''
        Args:
            num_warmup_iters (int): Warmup iterations are used to make sure we run the
                module a few times before actually measuring anything. This avoids cold
                caches and other similar problems. This is the number of warmup
                iterations for each of the calling threads, counted separately.

            num_iters (int): Number of iterations the benchmark should run for. This
                number is separate from the warmup iterations and is shared across all
                the calling threads. Once num_iters iterations have been reached across
                all threads, execution stops, though the actual total might end up
                slightly larger; the exact count is reported as stats.num_iters, where
                stats is the result of this function.

            profiler_output_path (str): Location to save the Autograd Profiler trace.
                If not empty, the Autograd Profiler will be enabled for the main
                benchmark execution (but not for the warmup phase), and the full trace
                will be saved to the file path provided by this argument.

        This function returns an ExecutionStats object wrapping the
        BenchmarkExecutionStats result defined via pybind11. Its two key fields are:
            - num_iters - number of iterations the benchmark actually made
            - latency_avg_ms - average time it took to infer on one input example, in milliseconds
        '''
        config = torch._C.BenchmarkConfig()
        config.num_calling_threads = num_calling_threads
        config.num_warmup_iters = num_warmup_iters
        config.num_iters = num_iters
        config.profiler_output_path = profiler_output_path
        c_stats = self._benchmark.benchmark(config)
        return ExecutionStats(c_stats, config)
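

# A minimal smoke-test sketch (not part of the public API): the Linear module,
# tensor shapes, and benchmark settings below are illustrative assumptions.
if __name__ == "__main__":
    import torch

    # Script a tiny module so the benchmark avoids GIL overhead (see the class docstring).
    model = torch.jit.script(torch.nn.Linear(16, 4))
    bench = ThroughputBenchmark(model)

    # Pre-populate the benchmark with a handful of random inputs.
    for _ in range(8):
        bench.add_input(torch.randn(1, 16))

    # Run with two calling threads and print the aggregated ExecutionStats.
    stats = bench.benchmark(num_calling_threads=2, num_warmup_iters=10, num_iters=100)
    print(stats)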