import logging
import operator
from collections import defaultdict
from typing import Set

import torch
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .common import aot_autograd
from .registry import register_backend

log = logging.getLogger(__name__)


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out. First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r


# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


def find_input_mutations(g):
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs


# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model
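
# How the pieces fit together (a sketch inferred from the code above; `gm` and
# `example_inputs` below are hypothetical names): partition_cudagraphs splits
# the FX graph into CUDA-graph-safe submodules, apply_cuda_graphs swaps each
# one for a CudaGraphModule, and each wrapper then moves through three phases
# across successive calls:
#
#   compiled = cudagraphs(gm, example_inputs)
#   compiled(*example_inputs)  # 1st call: warmup on a side stream
#   compiled(*example_inputs)  # 2nd call: record the CUDA graph, then replay it
#   compiled(*example_inputs)  # later calls: copy inputs in, replay, copy out
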

aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)

# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
# TODO(jansel): rename to just "cudagraphs"?
register_backend(name="cudagraphs", compiler_fn=aot_cudagraphs)


def cudagraphs_inner(model, inputs, copy_outputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    static_inputs = [torch.zeros_like(x) for x in inputs]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
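
# Example usage (a minimal sketch; `model` and `inp` are placeholder names,
# and the backend string resolves because register_backend above registered
# aot_cudagraphs under the name "cudagraphs"):
#
#   opt_model = torch.compile(model, backend="cudagraphs")
#   out = opt_model(inp)
#
#   # cudagraphs_inner wraps a callable directly: warmup and recording happen
#   # here, and every run() call afterwards just replays the captured graph.
#   run = cudagraphs_inner(model, [inp])
#   (out,) = run(inp)  # assumes the model returns a single tensor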