import logging
import operator
from collections import defaultdict
from typing import Set

import torch
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .common import aot_autograd
from .registry import register_backend

log = logging.getLogger(__name__)


# Used with tree_map over the static outputs: clone tensors so callers never
# alias the static CUDA graph buffers, and pass everything else through.
def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out. First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            # replay
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r
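

# Sketch of how CudaGraphModule behaves across calls (not exercised in this
# file; `submod` and `x` are hypothetical): the first call runs the submodule
# eagerly on a side stream to warm up, the second records a CUDA graph into
# static input/output buffers, and every later call just copies fresh inputs
# into those buffers and replays the graph:
#
#   wrapped = CudaGraphModule(submod, mutated_inputs=set())
#   wrapped(x)  # warmup (eager, side stream)
#   wrapped(x)  # record, then replay to produce the result
#   wrapped(x)  # replay only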


# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


def find_input_mutations(g):
    """Return the set of input indices whose storage is mutated in place by the graph."""

    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs
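

# Illustrative example (hypothetical graph): if placeholder %x is recorded under
# input index 0 and a later call_function node targets an in-place op (i.e. an
# aten op whose mutated schema argument has alias_info.is_write set) with %x as
# that argument, then find_input_mutations returns a set containing 0.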


# Mutates the input GraphModule in place, wrapping each partitioned submodule
# in a CudaGraphModule
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)

# aot_cudagraphs only applies CUDA graphs to the graph; it performs no other
# optimization, which also makes it helpful for debugging and a useful perf
# baseline.
# TODO(jansel): rename to just "cudagraphs"?
register_backend(name="cudagraphs", compiler_fn=aot_cudagraphs)
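

# Usage sketch (assumes a CUDA-capable setup; `model` and `inp` are
# hypothetical): once this module is imported and the backend is registered,
# it can be selected by name through dynamo, e.g.
#
#   opt_model = torch._dynamo.optimize("cudagraphs")(model)
#   out = opt_model(inp)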


def cudagraphs_inner(model, inputs, copy_outputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    static_inputs = [torch.zeros_like(x) for x in inputs]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
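

# Usage sketch for cudagraphs_inner (hypothetical model and inputs; requires a
# CUDA device):
#
#   model = torch.nn.Linear(32, 32).cuda()
#   inputs = [torch.randn(8, 32, device="cuda")]
#   run = cudagraphs_inner(model, inputs)
#   out = run(*inputs)  # copies into the static buffers and replays the graph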