import dataclasses
import functools
import itertools
import logging
import sys
import warnings
from typing import Any, Dict, List, Optional

import functorch
from functorch.compile import min_cut_rematerialization_partition

import torch._dynamo.config as dynamo_config
import torch.fx
from torch._dynamo import logging as dynamo_logging, utils as dynamo_utils
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._subclasses.fake_tensor import FakeTensor

from .._dynamo.backends.common import aot_autograd
from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
from .mkldnn import convert_outplace_to_inplace
from .utils import developer_warning, get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V

log = logging.getLogger(__name__)
ALIGNMENT = 16
@dataclasses.dataclass
class BoxedBool:
    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        if isinstance(obj, BoxedBool):
            obj.value = False
            return obj
        return False
# copy_ fails when trying to write to tensors with internal memory overlap.
# For an expanded dimension (a dimension broadcast from size 1 to size > 1),
# we can instead select a single element along that dimension and write to it;
# since every position along that dimension shares the same storage, this
# effectively writes to all values of that dimension of the input tensor.
def get_expanded_dims(t):
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t, expanded_dims):
    for expanded_dim in expanded_dims:
        t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
    return t
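# Illustrative example: for x = torch.zeros(4, 1).expand(4, 8), dim 1 has
# stride 0 and size 8, so get_expanded_dims(x) == [1] and
# index_expanded_dims(x, [1]) is the (4, 1) view x[:, 0:1]; copy_ into that
# view writes, for each row, the single backing element shared by all 8 columns.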
def complex_memory_overlap(t):
    # if torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, let's dig deeper to find out whether it's true.
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
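# The deeper check above sorts dimensions by stride and reports overlap when
# some dimension's stride is smaller than the extent already covered by the
# smaller-stride dimensions, i.e. when two distinct indices can alias the same
# storage. E.g. torch.empty(8).as_strided((4, 4), (1, 1)) genuinely aliases and
# returns True, while a column slice like torch.empty(4, 4)[:, :2] makes
# torch._debug_has_internal_overlap report a nonzero "maybe" value but passes
# this deeper check and returns False.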
@functools.lru_cache(None)
def _step_logger():
    return dynamo_logging.get_step_logger(log)


@functools.lru_cache(None)
def _warn_tf32_disabled():
    if (
        torch.cuda.is_available()
        and not torch.backends.cuda.matmul.allow_tf32
        and torch.cuda.get_device_capability() >= (8, 0)
    ):
        warnings.warn(
            "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
            "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
        )
def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target in tf32_ops
            and isinstance(node.meta.get("val", None), torch.Tensor)
            and node.meta["val"].dtype == torch.float32
            and node.meta["val"].device.type == "cuda"
        ):
            return True
    return False
@DebugContext.wrap
def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
    shape_env = _shape_env_from_inputs(example_inputs)

    graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
    with V.set_graph_handler(graph):
        graph.run(*example_inputs)
        num_bytes, nodes_num_elem = graph.count_bytes()
        metrics.num_bytes_accessed += num_bytes
        metrics.nodes_num_elem += nodes_num_elem
    return make_boxed_func(gm.forward)
@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs=None,
    num_fixed=0,
    is_backward=False,
    graph_id=None,
):
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    if dynamo_utils.count_calls(gm.graph) == 0:
        return make_boxed_func(gm.forward)

    # lift the maximum depth of the Python interpreter stack
    # to accommodate large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)

    if cudagraphs is None:
        cudagraphs = config.triton.cudagraphs

    shape_env = _shape_env_from_inputs(example_inputs)
    fake_mode = fake_mode_from_tensors(
        example_inputs
    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)

    with V.set_fake_mode(fake_mode):
        pattern_matcher.fx_passes(gm)
        V.debug.fx_graph_transformed(gm, example_inputs)

        graph = GraphLowering(
            gm,
            shape_env=shape_env,
            num_static_inputs=num_fixed,
            graph_id=graph_id,
        )
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            compiled_fn = graph.compile_to_fn()

    if cudagraphs:
        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t) for t in example_inputs
        )

        if (
            set(graph.device_types) == {"cuda"}
            and not graph.mutated_inputs
            and not has_incompatible_cudagraph_ops(gm)
            and not complex_memory_overlap_inputs
        ):
            compiled_fn = cudagraphify(
                compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
            )
        else:
            BoxedBool.disable(cudagraphs)

            if len(set(graph.device_types)) > 1:
                developer_warning("skipping cudagraphs due to multiple devices")
            elif set(graph.device_types) == {"cuda"}:
                if graph.mutated_inputs:
                    developer_warning("skipping cudagraphs due to input mutation")
                elif complex_memory_overlap_inputs:
                    developer_warning(
                        "skipping cudagraphs due to complex input striding"
                    )

    result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    result._boxed_call = True
    return result
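# Note: `cudagraphs` is a BoxedBool shared between the forward and backward
# compilations created in compile_fx below, so BoxedBool.disable(cudagraphs)
# above also prevents the matching backward graph from being cudagraphified.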
def clone_preserve_strides(x):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
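# needed_size is the largest reachable linear index plus one: each dimension
# contributes at most (size - 1) * stride to the offset. E.g. for
# x.size() == (2, 3) and x.stride() == (3, 1), the flat backing buffer needs
# (2 - 1) * 3 + (3 - 1) * 1 + 1 == 6 elements.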
def align_inputs(model, inputs, static_input_idxs=()):
    def is_aligned(storage_offset, dtype):
        return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

    check_inputs = [
        i
        for i in range(len(inputs))
        if (
            i not in static_input_idxs
            or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
        )
        and inputs[i].device.type == "cuda"
    ]

    if len(check_inputs) == 0:
        return model

    def run(new_inputs):
        for i in check_inputs:
            if new_inputs[i].data_ptr() % ALIGNMENT:
                new_inputs[i] = clone_preserve_strides(new_inputs[i])
        return model(new_inputs)

    return run
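# The returned wrapper clones, at call time, any CUDA input whose data pointer
# is not ALIGNMENT (16-byte) aligned, presumably so the generated kernels can
# assume aligned pointers; static inputs that were already aligned at compile
# time are skipped entirely.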
@dynamo_utils.dynamo_timed
def cudagraphify(model, inputs, static_input_idxs=()):
    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
    if not any(isinstance(inp, FakeTensor) for inp in inputs):
        return cudagraphify_impl(model, inputs, static_input_idxs)

    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is None:
            with dynamo_utils.preserve_rng_state():
                compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run
def remove_unaligned_input_idxs(inputs, static_input_idxs):
    """
    We require all static inputs to be aligned. Drop any index whose input is
    not ALIGNMENT-aligned from the static set, so that input is treated as a
    regular input and copied into an aligned static buffer at run time.
    """
    aligned_static_input_idxs = {
        idx for idx in static_input_idxs if (inputs[idx].data_ptr() % ALIGNMENT) == 0
    }
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs
def cudagraphify_impl(model, inputs, static_input_idxs=()):
    """
    Assumes inputs[static_input_idxs[i]] are always the same memory address
    """
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)

    def static_input(x):
        """
        Allocate a fresh buffer matching an input's size and strides
        """
        # TODO(jansel): figure out why this version doesn't work:
        # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
        needed_size = (
            sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
        )
        buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
        return torch.as_strided(buffer, x.size(), x.stride())

    assert isinstance(inputs, (list, tuple))
    static_inputs = [
        static_input(x) if idx not in static_input_idxs else x.detach()
        for idx, x in enumerate(inputs)
    ]
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # pass a copy of static_inputs because the model clears the list it is given
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    dst = index_expanded_dims(dst, expanded_dims)
                    src = index_expanded_dims(src, expanded_dims)
                    dst.copy_(src)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                # copy from the caller's tensor into the recorded static buffer
                src = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
                dst = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
                dst.copy_(src)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return run
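# Rough usage sketch (hypothetical names, for illustration only): given a
# compiled callable `fn` that takes a list of CUDA tensors,
#   graphed = cudagraphify_impl(fn, [a, b], static_input_idxs=(0,))
#   out = graphed([a, b_new])
# Each call copies the non-static inputs into the recorded static buffers,
# clears the input list, replays the captured CUDA graph, and returns the
# recorded output tensors (which are overwritten in place on the next call).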
def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_not_gradout(x):
        return "tangents" not in x.name

    arg_count = 0
    static_arg_idxs = []
    for n in fx_g.graph.nodes:
        if n.op == "placeholder":
            if is_not_gradout(n):
                static_arg_idxs.append(arg_count)
            arg_count += 1

    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
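# E.g. a backward graph with placeholders (primals_1, getitem_2, tangents_1)
# (typical AOTAutograd naming, used here only as an illustration) yields 2:
# the saved values come first and are treated as static inputs, while the
# "tangents_*" gradient inputs are not.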
_graph_counter = itertools.count(0)
def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile=compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    """Main entry point for compiling a given FX graph"""
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
            )

    assert not config._raise_error_for_testing
    functorch.compile.config.use_functionalize = True
    functorch.compile.config.use_fake_tensor = True

    with overrides.patch_functions():
        model_ = overrides.replace_fx(model_)
        model_ = overrides.fuse_fx(model_, example_inputs_)
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(
        config.triton.cudagraphs and not dynamo_config.dynamic_shapes
    )

    graph_id = next(_graph_counter)

    @dynamo_utils.dynamo_timed
    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
        fixed = len(example_inputs) - num_example_inputs
        # Why convert out-of-place ops to in-place? Inductor supports in-place
        # operations well, and for custom in-place ops lowered as ExternKernel
        # it is better for performance to use the in-place implementation when
        # one is available.
        model = convert_outplace_to_inplace(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
        )

    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
        )

    with overrides.patch_functions():
        # TODO: can add logging before/after the call to create_aot_dispatcher_function
        # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
        # once torchdynamo is merged into pytorch
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            decompositions=select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)
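# Note: compile_fx is normally reached indirectly, as the Inductor backend of
# torch._dynamo, rather than being called by hand. A rough, illustrative direct
# use would be:
#   compiled = torch._dynamo.optimize(compile_fx)(fn)
#   compiled(*example_args)
# where dynamo traces `fn` into the fx.GraphModule that arrives here as `model_`.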
def _shape_env_from_inputs(inputs):
    shape_env = None
    fake_mode = fake_mode_from_tensors(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    #     assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # TODO(voz): Should we always have one anyway?
    return None