import operator from copy import deepcopy from dataclasses import dataclass from functools import lru_cache from types import MappingProxyType from warnings import warn import torch import torch.overrides from torch._prims_common import ( _torch_dtype_to_nvfuser_dtype_map, getnvFuserDtype, Number, number_type, ) from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten if torch.cuda.is_available(): from nvfuser._C import ( # type: ignore[import] DataType, Fusion, FusionDefinition, Tensor, ) else: DataType = None import os @lru_cache(None) def get_nvprim_dump_nvtx(): return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX") DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType( { "use_python_fusion_cache": True, "allow_single_op_fusion": False, } ) # nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects # for cached construction of the nvFuser's Fusion # TODO: change what is stored in the cache for nvFuser's Tensor objects # https://github.com/pytorch/pytorch/issues/80551 @dataclass(frozen=True) class nvFuserTensorTemplate: symbolic_shape: tuple contiguity: tuple dtype: DataType is_cpu: bool @dataclass(frozen=True) class nvFuserScalarTemplate: dtype: DataType @lru_cache(maxsize=2048) def compute_symbolic_shape(shape): """Computes the symbolic shape of a tensor. nvFuser specializes on size-1 dimensions as broadcasted dimensions. -1 is used to represent any size.""" return tuple(1 if s == 1 else -1 for s in shape) @lru_cache(maxsize=2048) def compute_contiguity(shape, strides): """Computes the contiguity information to simplify internal indexing. Contiguous dimensions are represented by True, strided dimensions are represented by False. """ from nvfuser._C import compute_contiguity return compute_contiguity(shape, strides) def to_nvfuser_template_args(args): def to_nvfuser(arg): if isinstance(arg, torch.Tensor): return nvFuserTensorTemplate( compute_symbolic_shape(arg.size()), compute_contiguity(arg.size(), arg.stride()), getnvFuserDtype(arg.dtype), arg.is_cpu, # type: ignore[attr-defined] ) elif isinstance(arg, Number): return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg))) else: return arg return tree_map(to_nvfuser, args) def _any_get_attr_used(call_function_nodes): return any( filter( # bug in mypy https://github.com/python/mypy/issues/12682 lambda n: any( # type: ignore[arg-type] a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node) # type: ignore[attr-defined] ), call_function_nodes, ) ) # MyPy bug: https://github.com/python/mypy/issues/5107 @lru_cache(maxsize=1024) # type: ignore[arg-type] def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): if not torch.cuda.is_available(): raise RuntimeError( "Attempting to use nvFuser trace executor but CUDA is not available!" ) # Everything in the graph must support nvfuser for node in gm.graph.nodes: if node.op == "call_function" and node.target == operator.getitem: continue if ( node.op == "call_function" and getattr(node.target, "impl_nvfuser", None) is None ): raise ValueError( "All call_function nodes in the graph must support nvfuser. " f"Node {node} with target {node.target} does not support nvfuser" ) graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes)) call_function_nodes = list( filter(lambda n: n.op == "call_function", gm.graph.nodes) ) assert len(graph_input_nodes) == len( nv_args_templates ), "Number of placeholder nodes in the graph must match number of args" assert len(nv_args_templates) > 0, "There must be at least one argument" assert ( len(call_function_nodes) > 0 ), "Graph must contain at least one call_function node" assert not _any_get_attr_used( call_function_nodes ), "Constant tensors that are saved in the graph and used as arguments are not supported yet" # Checking output dtypes output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes)) orig_flat_out, _ = tree_flatten(output_node.args[0]) fusion = Fusion() with FusionDefinition(fusion) as fd: def _to_nvfuser_constant(arg): if isinstance(arg, Number): return fd.define_constant(arg) else: return arg class FusionInterpreter(torch.fx.Interpreter): def run_node(self, node): # Squeeze requires original shape of args[0] if node.target in [ torch.ops.nvprims.squeeze, torch.ops.nvprims.squeeze.default, ]: original_shape = list(node.args[0].meta["tensor_meta"].shape) assert len(node.args) == 2 args, kwargs = self.fetch_args_kwargs_from_env(node) args = [args[0], original_shape, args[1]] return self.call_function(node.target, args, node.kwargs) if node.target in [ torch.ops.nvprims.native_batch_norm, torch.ops.nvprims.native_batch_norm.default, ]: args, kwargs = self.fetch_args_kwargs_from_env(node) assert len(args) == 8 training = args[5] args6_end = tuple(map(_to_nvfuser_constant, args[6:])) args = args[:5] + (training,) + args6_end return node.target.impl_nvfuser(fd, *args, **kwargs) return super().run_node(node) def call_function(self, target, args, kwargs): # This handles tuple unpacking if target == operator.getitem: assert isinstance(args[0], tuple) return target(*args, **kwargs) args = tuple(map(_to_nvfuser_constant, args)) target = target.impl_nvfuser args = (fd,) + args return target(*args, **kwargs) def output(self, target, args, kwargs): flat_out, unflatten_spec = tree_flatten(args[0]) for o, orig_o in zip(flat_out, orig_flat_out): # casting outputs to the original data type # ensures outputs produced by fusion would always agree with original GraphModule out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype) # type: ignore[union-attr] assert isinstance( o, Tensor ), "output from codegen has to be tensor type" fd.add_output(fd.ops.cast(o, dtype=out_dtype)) return args[0] def templates_to_nvfuser_inputs(arg): if isinstance(arg, nvFuserTensorTemplate): x = fd.define_tensor( arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu ) return x elif isinstance(arg, nvFuserScalarTemplate): x = fd.define_scalar(arg.dtype) return x else: return arg # Transforms graph to call nvfuser lowerings nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates)) out = FusionInterpreter(gm).run(*nv_args) flat_out, unflatten_spec = tree_flatten(out) return fusion, unflatten_spec def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None): executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG flat_args, _ = tree_flatten(args) # check for cuda only fusion if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined] ( not isinstance(arg, torch.Tensor) or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined] or arg.is_cuda # type: ignore[attr-defined] ) for arg in flat_args ): # Construction of the fusion is expensive and cached based on the GraphModule # and symbolic nvFuser args. nv_template_args = to_nvfuser_template_args(flat_args) use_cache = executor_parameters.get( "use_python_fusion_cache", DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"], ) if use_cache: fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc] else: fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args) # type: ignore[misc] # Inputs to fusion.execute correspond to the same template/symbolic inputs # marked with `define_tensor/scalar` concrete_fusion_inputs = tuple( arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) ) if get_nvprim_dump_nvtx(): torch.cuda.nvtx.range_push( "fusion: {0}, graph: {1}".format( fusion.id(), str( [ { "op": n.op, "name": n.name, "args": n.args, "kwargs": n.kwargs, } for n in gm.graph.nodes ] ), ) ) result = tree_unflatten( fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] unflatten_spec, # type: ignore[has-type] ) if get_nvprim_dump_nvtx(): torch.cuda.nvtx.range_pop() return result else: warn( "nvfuser_executor is executed with non-cuda args, fallback to aten executor" ) return gm.forward(*args) class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # special case to stop lowering to nvprim when converting to an unsupported type if ( node.op == "call_function" and node.target == torch.ops.nvprims.convert_element_type.default ): return ( _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None and _torch_dtype_to_nvfuser_dtype_map.get( node.args[0].meta["tensor_meta"].dtype # type: ignore[union-attr] ) is not None ) return node.op == "call_function" and ( getattr(node.target, "impl_nvfuser", None) is not None or node.target == operator.getitem ) class PartitionedInterpreter(torch.fx.Interpreter): def call_module(self, target, args, kwargs): assert isinstance(target, str) assert len(kwargs) == 0 submod = self.fetch_attr(target) # CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id if target.startswith("fused_"): return nvfuser_execute(submod, *args) else: return super().call_module(target, args, kwargs) class NvfuserGraphModule(torch.nn.Module): def __init__(self, gm, use_python_fusion_cache): super().__init__() self.gm = gm self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache} def __call__(self, *args): return nvfuser_execute( self.gm, *args, executor_parameters=self.executor_parameters ) # A set of operators that are supported by nvFuser # but should not form a fusion group solely on their own _non_compute_ops = [ "torch.ops." + str(getattr(torch.ops.nvprims, prim).default) for prim in dir(torch.ops.nvprims) if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket) and getattr(torch.ops.nvprims, prim).return_type == torch._prims_common.RETURN_TYPE.VIEW ] _allowed_single_node_partition_ops = [ "torch.ops.nvprims.native_batch_norm.default", "torch.ops.nvprims.var_mean.default", "torch.ops.nvprims.var_mean.main", ] def _remove_empty_like_fill(gm: GraphModule): # Remove empty_like + fill nodes that prevent lowering to nvprims # This is a workaround for nonoptimal traces of C++ code `(1 - tensor)` # https://github.com/pytorch/pytorch/issues/86612 def pattern(scalar, tensor): # pattern for C++ trace of `scalar - tensor`. We are looking for the # pattern of aten and nvprims.sub specifically because we want to remove # the empty_like + fill nodes after lowering of AOT Autograd trace to # nvprims In the future, nvFuser might support fill, and empty_like and # this workaround can be removed. empty_like = torch.ops.aten.empty_like.default( tensor, memory_format=torch.preserve_format ) fill = torch.ops.aten.fill.Scalar(empty_like, scalar) sub = torch.ops.nvprims.sub.default(fill, tensor) return sub def replacement(scalar, tensor): return torch.ops.nvprims.sub.default(scalar, tensor) torch.fx.replace_pattern(gm, pattern, replacement) return gm # MyPy bug: https://github.com/python/mypy/issues/5107 @lru_cache(maxsize=1024) # type: ignore[arg-type] def maybe_partition_graph( gm: GraphModule, allow_single_op_fusion: bool, use_python_fusion_cache: bool ): gm = _remove_empty_like_fill(gm) supported_ops = NvfuserPrimOperatorSupport() call_function_nodes = list( filter(lambda n: n.op == "call_function", gm.graph.nodes) ) # the graph is partitioned only if at least one node is not supported by nvFuser any_unsupported = any( not supported_ops.is_node_supported(None, node) for node in call_function_nodes ) any_unsupported |= len(call_function_nodes) == 0 # When there are constant tensors in the graph, we can't partition it # because deepcopy fails. Here we just return the original graph to be # executed by eager mode # https://github.com/pytorch/pytorch/issues/84415 if ( _any_get_attr_used(call_function_nodes) or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0 ): return gm, True if any_unsupported: # CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph gm = deepcopy(gm) partitioner = CapabilityBasedPartitioner( gm, supported_ops, allows_single_node_partition=allow_single_op_fusion, non_compute_ops=_non_compute_ops, allowed_single_node_partition_ops=_allowed_single_node_partition_ops, ) partitions = partitioner.propose_partitions() partitioner.remove_bookend_non_compute_ops(partitions) if len(partitions) == 0: warn( "No partition found for the graph. " + "This is likely because the graph is not supported by nvFuser. " + "Please use the eager ATen mode to execute the graph.", category=RuntimeWarning, ) partitioned_graph = partitioner.fuse_partitions(partitions) # Replacing graph's fused submodules with a wrapper module with # __call__() method that calls nvfuser_execute. # This avoids the need to call the interpreter on the graph for node in partitioned_graph.graph.nodes: # TODO: use a better way to identify fused submodule if node.op == "call_module" and "fused_" in node.name: nvfuser_submodule = getattr(partitioned_graph, node.name) partitioned_graph.delete_submodule(node.target) gm.add_submodule( node.target, NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache), ) # Go through the graph and replace all the nodes that were converted to # nvprims but won't be sent to nvFuser with a call to PyTorch's eager # mode. This is necessary because torch.ops.* have higher overhead than # calling the eager mode directly. for node in partitioned_graph.graph.nodes: if node.op == "call_function" and str(node.target).startswith("nvprims."): if getattr(node.target, "impl_aten", None) is not None: node.target = node.target.impl_aten partitioned_graph.graph.eliminate_dead_code() partitioned_graph.recompile() return partitioned_graph, any_unsupported else: return gm, any_unsupported class NVTXInterpreter(torch.fx.Interpreter): def run_node(self, n): torch.cuda.nvtx.range_push( "name: {0}, args: {1}, op: {2}, kwargs: {3}".format( n.name, n.args, n.op, n.kwargs ) ) result = super().run_node(n) torch.cuda.nvtx.range_pop() return result def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None): executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG # maybe_partition_graph function is cached so we can't use non-hashable arguments allow_single_op_fusion = executor_parameters.get( "allow_single_op_fusion", DEFAULT_NVFUSER_PYTHON_CONFIG["allow_single_op_fusion"], ) use_python_fusion_cache = executor_parameters.get( "use_python_fusion_cache", DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"], ) # When possible it's better to use nvfuser_execute directly # because it avoids GraphModule's overhead gm, is_partitioned = maybe_partition_graph( gm, allow_single_op_fusion=allow_single_op_fusion, use_python_fusion_cache=use_python_fusion_cache, ) if is_partitioned: if get_nvprim_dump_nvtx(): return NVTXInterpreter(gm).run(*args) else: return gm(*args) else: return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)