123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from contextlib import contextmanager
- from typing import Any, List, Tuple, cast
- import random
- import torch
- import time
- from torch.utils.benchmark import Timer
def _marker_line_prefix(preceding: str) -> str:
    """Return the text sharing a line with the marker that follows *preceding*."""
    pieces = preceding.splitlines(keepends=True)
    if not pieces:
        # The marker opens the file or directly follows another marker.
        return ""
    tail = pieces[-1]
    # When the last piece ends with a line break, the marker begins a fresh
    # line, so nothing precedes it on that line.
    if tail.splitlines()[0] != tail:
        return ""
    return tail


def extract_ir(filename: str) -> List[str]:
    """Collect every graph dumped between export markers in a log file.

    The file is read in full and scanned for ``<GRAPH_EXPORT>`` ...
    ``</GRAPH_EXPORT>`` pairs. Whatever text sits in front of the opening
    marker on the same line is treated as a per-line prefix, and that many
    characters are dropped from the start of every line of the enclosed text.

    Args:
        filename: Path of the log file to scan.

    Returns:
        One string per exported graph, in file order. An opening marker with
        no closing marker after it contributes nothing.
    """
    BEGIN = "<GRAPH_EXPORT>"
    END = "</GRAPH_EXPORT>"
    graphs = []
    with open(filename, "r") as f:
        segments = f.read().split(BEGIN)
    # segments[0] is the text before the first marker; every later segment
    # starts right after an opening marker, so pair each one with the segment
    # in front of it to recover the prefix.
    for preceding, segment in zip(segments, segments[1:]):
        end_loc = segment.find(END)
        if end_loc == -1:
            continue
        pfx_len = len(_marker_line_prefix(preceding))
        body_lines = segment[:end_loc].splitlines(keepends=True)
        graphs.append("".join(line[pfx_len:] for line in body_lines))
    return graphs
def make_tensor_from_type(inp_type: torch._C.TensorType):
    """Allocate a tensor whose layout matches a JIT tensor type.

    Args:
        inp_type: Tensor type queried for its sizes, strides, device and
            dtype. Each of the four must be known (not ``None``).

    Returns:
        The result of ``torch.empty_strided`` called with those four values.

    Raises:
        AssertionError: If any of the four properties is ``None``.
    """
    layout = {
        "size": inp_type.sizes(),
        "stride": inp_type.strides(),
        "device": inp_type.device(),
        "dtype": inp_type.dtype(),
    }
    # All four properties are queried first, then checked in the same order.
    for known in layout.values():
        assert known is not None
    return torch.empty_strided(**layout)
def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
    """Parse textual IR into a callable function plus generated inputs.

    The IR is parsed (tensor constants included), its outputs are packed into
    a tuple, and one value is produced for each graph input based on the
    input's type:

    * float  -> ``random.uniform(.1, 100)``
    * int    -> ``random.randint(1, 100)``
    * tensor -> ``make_tensor_from_type`` applied to the input's type
    * bool   -> a random ``True`` / ``False``

    Args:
        ir: Graph IR text accepted by ``torch._C.parse_ir``.

    Returns:
        ``(func, inputs)`` where ``func`` is created from the graph under the
        name ``"forward"`` with shape information erased from its graph, and
        ``inputs`` is the list of generated values in graph-input order.

    Raises:
        NotImplementedError: If an input has a type other than the four above.
    """
    graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
    graph.makeMultiOutputIntoTuple()

    def default_for(kind):
        # Checks run in the same order as the type list in the docstring.
        if isinstance(kind, torch._C.FloatType):
            return random.uniform(.1, 100)
        if isinstance(kind, torch._C.IntType):
            return random.randint(1, 100)
        if isinstance(kind, torch._C.TensorType):
            return make_tensor_from_type(cast(torch._C.TensorType, kind))
        if isinstance(kind, torch._C.BoolType):
            return random.randint(0, 1) == 1
        raise NotImplementedError(f"A default value is not implemented for type {kind}")

    inputs = [default_for(inp.type()) for inp in graph.inputs()]

    func = torch._C._create_function_from_graph("forward", graph)
    torch._C._jit_pass_erase_shape_information(func.graph)
    return (func, inputs)
def time_cuda(fn, inputs, test_runs):
    """Time ``fn(*inputs)`` with ``torch.utils.benchmark.Timer``.

    Args:
        fn: Callable to measure.
        inputs: Positional arguments unpacked into ``fn``.
        test_runs: Accepted for signature parity with ``time_cpu``; this
            function does not read it.

    Returns:
        The median reported by ``Timer.blocked_autorange()`` multiplied by
        1000 (time in ms).
    """
    timer = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs": inputs})
    measurement = timer.blocked_autorange()
    median = measurement.median
    return median * 1000  # time in ms
def time_cpu(fn, inputs, test_runs):
    """Measure the mean wall-clock time of ``fn(*inputs)``.

    Args:
        fn: Callable to measure.
        inputs: Positional arguments unpacked into ``fn``.
        test_runs: How many back-to-back calls to make.

    Returns:
        Average duration of one call, in milliseconds, taken from
        ``time.perf_counter`` readings around the whole loop.
    """
    started = time.perf_counter()
    for _ in range(test_runs):
        fn(*inputs)
    finished = time.perf_counter()
    seconds_per_call = (finished - started) / test_runs
    return seconds_per_call * 1000  # time in ms
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Build a function from IR, warm it up, and time it on ``inputs``.

    The inputs generated by ``load_graph_and_inputs`` are discarded; only the
    caller-supplied ``inputs`` are used. The first tensor found in ``inputs``
    decides the timer: ``time_cpu`` when it lives on a ``"cpu"`` device,
    ``time_cuda`` otherwise.

    Args:
        ir: Graph IR text.
        inputs: Positional arguments for the loaded function.
        warmup_runs: Number of untimed calls made before measuring.
        test_runs: Forwarded to the selected timer.

    Returns:
        Whatever the selected timer returns (time in ms).

    Raises:
        AssertionError: If ``inputs`` contains no tensor.
    """
    graph, _generated_inputs = load_graph_and_inputs(ir)

    for _ in range(warmup_runs):
        graph(*inputs)

    # Stops at the first tensor; None means there was no tensor at all.
    is_cpu = next(
        (arg.device.type == "cpu" for arg in inputs if isinstance(arg, torch.Tensor)),
        None,
    )
    assert is_cpu is not None

    timer = time_cpu if is_cpu else time_cuda
    return timer(graph, inputs, test_runs)
@contextmanager
def no_fuser(*args, **kwargs):
    """Context manager that runs its body with executor optimization off.

    Any positional or keyword arguments are accepted and ignored.

    On entry ``torch._C._get_graph_executor_optimize(False)`` is called and
    its return value is kept; on exit (including on error) the same function
    is called again with that saved value.
    NOTE(review): this relies on the getter applying the value it is passed
    and returning the previous one - confirm against the torch bindings.
    """
    previous_setting = torch._C._get_graph_executor_optimize(False)
    try:
        yield
    finally:
        # Hand the saved value back regardless of how the body finished.
        torch._C._get_graph_executor_optimize(previous_setting)
def run_baseline_no_fusion(ir, inputs) -> float:
    """Time the graph inside the ``no_fuser`` context.

    Args:
        ir: Graph IR text passed to ``run_test``.
        inputs: Arguments passed to ``run_test``.

    Returns:
        The value returned by ``run_test`` (time in ms).
    """
    with no_fuser():
        elapsed_ms = run_test(ir, inputs)
    return elapsed_ms
def run_nnc(ir, inputs, dynamic) -> float:
    """Time the graph under the ``"fuser1"`` fuser with a chosen strategy.

    Args:
        ir: Graph IR text passed to ``run_test``.
        inputs: Arguments passed to ``run_test``.
        dynamic: When truthy the fusion strategy is ``[("DYNAMIC", 10)]``,
            otherwise ``[("STATIC", 10)]``.

    Returns:
        The value returned by ``run_test`` (time in ms).

    The strategy returned by ``torch.jit.set_fusion_strategy`` is put back
    afterwards, whether or not the measurement succeeds.
    """
    strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
    # Install the strategy before entering the try block: if this call raises
    # there is nothing to restore, and the cleanup must not reference a name
    # that was never bound (which would mask the real error).
    old_strat = torch.jit.set_fusion_strategy(strat)
    try:
        with torch.jit.fuser("fuser1"):
            return run_test(ir, inputs)
    finally:
        torch.jit.set_fusion_strategy(old_strat)
def run_nvfuser(ir, inputs) -> float:
    """Time the graph inside the ``torch.jit.fuser("fuser2")`` context.

    Args:
        ir: Graph IR text passed to ``run_test``.
        inputs: Arguments passed to ``run_test``.

    Returns:
        The value returned by ``run_test`` (time in ms).
    """
    with torch.jit.fuser("fuser2"):
        elapsed_ms = run_test(ir, inputs)
    return elapsed_ms
|