log_extract.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from contextlib import contextmanager
  2. from typing import Any, List, Tuple, cast
  3. import random
  4. import torch
  5. import time
  6. from torch.utils.benchmark import Timer
  def extract_ir(filename: str) -> List[str]:
      """Collect every graph dump enclosed in export markers in a log file.

      The file is read in full and scanned for ``<GRAPH_EXPORT>`` ...
      ``</GRAPH_EXPORT>`` pairs. Whatever text sits on the same line in front
      of an opening marker is treated as a per-line log prefix, and that many
      leading characters are dropped from every line of the enclosed dump.

      Args:
          filename: Path of the log file to scan.

      Returns:
          One string per opening marker that has a closing marker before the
          next opening marker (or end of file), in file order. Opening
          markers without a closing marker are skipped.
      """
      BEGIN = "<GRAPH_EXPORT>"
      END = "</GRAPH_EXPORT>"
      graphs = []
      with open(filename, "r") as f:
          split_strs = f.read().split(BEGIN)

      # Pair each chunk that follows a BEGIN marker with the chunk before it;
      # the chunk before supplies the log prefix of the marker's line.
      for preceding, split_str in zip(split_strs, split_strs[1:]):
          end_loc = split_str.find(END)
          if end_loc == -1:
              continue
          s = split_str[:end_loc]

          # The prefix is only the unterminated tail of the preceding text.
          # Empty text, or text ending in a line break, means the marker
          # opens its own line and there is nothing to strip.
          preceding_lines = preceding.splitlines(keepends=True)
          tail = preceding_lines[-1] if preceding_lines else ""
          pfx = tail if tail and tail.splitlines()[0] == tail else ""

          # Strip by length rather than by content, as before: every line of
          # the dump loses len(pfx) leading characters.
          lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
          graphs.append(''.join(lines))
      return graphs
  def make_tensor_from_type(inp_type: torch._C.TensorType):
      """Build a tensor whose layout matches the given TorchScript tensor type.

      The sizes, strides, device and dtype are read from ``inp_type`` and
      passed to ``torch.empty_strided``; this function does not fill in the
      tensor's values.

      Args:
          inp_type: Tensor type to read sizes, strides, device and dtype from.

      Returns:
          The tensor produced by ``torch.empty_strided``.

      Raises:
          AssertionError: If any of the four properties is ``None``.
      """
      layout = {
          "size": inp_type.sizes(),
          "stride": inp_type.strides(),
          "device": inp_type.device(),
          "dtype": inp_type.dtype(),
      }
      # All four properties must be present on the type before allocating.
      for described in layout.values():
          assert described is not None
      return torch.empty_strided(**layout)
  def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
      """Parse an IR string into a callable and generate one argument per input.

      Args:
          ir: Graph text accepted by ``torch._C.parse_ir``.

      Returns:
          A ``(function, inputs)`` pair. ``function`` is created from the
          parsed graph under the name ``"forward"``. ``inputs`` holds one
          generated value per graph input: a random float in [0.1, 100], a
          random int in [1, 100], a random bool, or a tensor built by
          ``make_tensor_from_type``.

      Raises:
          NotImplementedError: If a graph input has any other type.
      """
      parsed = torch._C.parse_ir(ir, parse_tensor_constants=True)
      parsed.makeMultiOutputIntoTuple()

      def sample_for(value_type):
          # Produce one value suitable for an input of the given type.
          if isinstance(value_type, torch._C.FloatType):
              return random.uniform(.1, 100)
          if isinstance(value_type, torch._C.IntType):
              return random.randint(1, 100)
          if isinstance(value_type, torch._C.TensorType):
              return make_tensor_from_type(cast(torch._C.TensorType, value_type))
          if isinstance(value_type, torch._C.BoolType):
              return random.randint(0, 1) == 1
          raise NotImplementedError(f"A default value is not implemented for type {value_type}")

      sample_inputs = [sample_for(value.type()) for value in parsed.inputs()]

      compiled = torch._C._create_function_from_graph("forward", parsed)
      # Shape information is erased from the function's graph only after the
      # sample inputs have been built from the typed graph inputs.
      torch._C._jit_pass_erase_shape_information(compiled.graph)
      return (compiled, sample_inputs)
  def time_cuda(fn, inputs, test_runs):
      """Time ``fn(*inputs)`` with ``torch.utils.benchmark.Timer``.

      Args:
          fn: Callable to measure.
          inputs: Positional arguments unpacked into ``fn``.
          test_runs: Not used here; present so the signature matches
              ``time_cpu``.

      Returns:
          The median reported by ``blocked_autorange``, multiplied by 1000
          (time in ms).
      """
      timer = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs": inputs})
      measurement = timer.blocked_autorange()
      median_ms = measurement.median * 1000
      return median_ms
  def time_cpu(fn, inputs, test_runs):
      """Measure the mean wall-clock time of ``fn(*inputs)`` over several calls.

      Args:
          fn: Callable to measure.
          inputs: Positional arguments unpacked into ``fn``.
          test_runs: Number of back-to-back calls to time.

      Returns:
          Mean time per call in milliseconds.
      """
      started = time.perf_counter()
      for _run in range(test_runs):
          fn(*inputs)
      elapsed = time.perf_counter() - started
      # perf_counter reports seconds; convert the per-call mean to ms.
      return elapsed / test_runs * 1000
  def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
      """Compile ``ir``, warm it up, and time it on the given inputs.

      The device of the first tensor found in ``inputs`` selects the timer:
      ``time_cpu`` when that tensor is on the CPU, ``time_cuda`` otherwise.

      Args:
          ir: Graph text passed to ``load_graph_and_inputs``.
          inputs: Arguments the compiled graph is called with; the inputs
              generated by ``load_graph_and_inputs`` are discarded.
          warmup_runs: Number of untimed calls made before measuring.
          test_runs: Forwarded to the selected timing function.

      Returns:
          The value returned by the selected timing function (ms).

      Raises:
          AssertionError: If ``inputs`` contains no ``torch.Tensor``.
      """
      compiled, _generated = load_graph_and_inputs(ir)

      for _warmup in range(warmup_runs):
          compiled(*inputs)

      # Only the first tensor argument is consulted to pick the timer.
      first_tensor = next(
          (arg for arg in inputs if isinstance(arg, torch.Tensor)),
          None,
      )
      assert first_tensor is not None

      if first_tensor.device.type == "cpu":
          return time_cpu(compiled, inputs, test_runs)
      return time_cuda(compiled, inputs, test_runs)
  @contextmanager
  def no_fuser(*args, **kwargs):
      """Context manager that runs its body with graph-executor optimization off.

      ``torch._C._get_graph_executor_optimize`` is used here as a get-and-set:
      it is passed the new value and its return value is kept as the previous
      setting, which is passed back on exit (including when the body raises).

      Args:
          *args: Accepted and ignored.
          **kwargs: Accepted and ignored.
      """
      swap_optimize = torch._C._get_graph_executor_optimize
      previous = swap_optimize(False)
      try:
          yield
      finally:
          swap_optimize(previous)
  def run_baseline_no_fusion(ir, inputs) -> float:
      """Time ``ir`` on ``inputs`` inside the ``no_fuser`` context.

      Args:
          ir: Graph text passed to ``run_test``.
          inputs: Arguments passed to ``run_test``.

      Returns:
          The timing returned by ``run_test`` with its default run counts.
      """
      with no_fuser():
          baseline_ms = run_test(ir, inputs)
      return baseline_ms
  def run_nnc(ir, inputs, dynamic) -> float:
      """Time ``ir`` on ``inputs`` under the ``"fuser1"`` fuser.

      The fusion strategy is set to ``[("DYNAMIC", 10)]`` or
      ``[("STATIC", 10)]`` for the duration of the run and the previous
      strategy is restored afterwards, even if the run raises.

      Args:
          ir: Graph text passed to ``run_test``.
          inputs: Arguments passed to ``run_test``.
          dynamic: Truthy to select the ``"DYNAMIC"`` strategy, falsy for
              ``"STATIC"``.

      Returns:
          The timing returned by ``run_test`` with its default run counts.
      """
      strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
      # Install the strategy before entering ``try``: if this call itself
      # fails there is nothing to restore, and the ``finally`` clause must not
      # reference an unassigned ``old_strat`` (which would raise
      # UnboundLocalError and hide the real error).
      old_strat = torch.jit.set_fusion_strategy(strat)
      try:
          with torch.jit.fuser("fuser1"):
              return run_test(ir, inputs)
      finally:
          torch.jit.set_fusion_strategy(old_strat)
  def run_nvfuser(ir, inputs) -> float:
      """Time ``ir`` on ``inputs`` under the ``"fuser2"`` fuser.

      Args:
          ir: Graph text passed to ``run_test``.
          inputs: Arguments passed to ``run_test``.

      Returns:
          The timing returned by ``run_test`` with its default run counts.
      """
      with torch.jit.fuser("fuser2"):
          fused_ms = run_test(ir, inputs)
      return fused_ms