123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685 |
- import builtins
- import functools
- import inspect
- import itertools
- import logging
- import sys
- import textwrap
- from io import StringIO
- from typing import Any, List
- from unittest.mock import patch
- import sympy
- import torch
- from torch._dynamo.testing import rand_strided
- from torch._dynamo.utils import counters, identity
- from . import ir
- from .codecache import code_hash, DiskCache, PyCodeCache
- from .codegen.common import IndentedBuffer
- from .codegen.triton import config_of, signature_of, texpr, TritonKernel, TritonPrinter
- from .utils import do_bench, sympy_dot, sympy_product
- from .virtualized import V
log = logging.getLogger(name=__name__)

# Output verification during autotuning. Disabled because correctness checks
# struggle with fp16/tf32. To enable it, set this to a dict of keyword
# arguments for torch.testing.assert_close, e.g. dict(atol=1, rtol=0.05).
VERIFY = False

# Whether autotuning timings are written to stderr (see log_results).
PRINT_AUTOTUNE = True
class KernelNamespace:
    """
    Plain attribute container.

    Kernels are attached to instances with ``setattr`` and fetched back with
    ``getattr`` under their registered names.
    """
# These two namespaces are imported by the generated wrapper code, which
# refers to kernels as template_kernels.<name> / extern_kernels.<name>.
template_kernels, extern_kernels = KernelNamespace(), KernelNamespace()
class TritonTemplateKernel(TritonKernel):
    """
    TritonKernel whose body is produced by a jinja2 template.

    The template drives code generation by calling the hooks returned from
    ``template_env()``:

    - ``def_kernel``: the signature plus argument rebinding.
    - ``size`` / ``stride``: layout queries.
    - ``make_load``: tensor loads.
    - ``store_output``: the output store plus any fused epilogue.
    """

    def __init__(
        self,
        kernel_name,
        input_nodes,
        output_node,
        defines,
        num_stages,
        num_warps,
        grid_fn,
        meta,
        call_sizes,
        use_jit=True,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
    ):
        """
        Args:
            kernel_name: name of the generated Triton function.
            input_nodes: IR input buffers, in kernel-argument order.
            output_node: IR buffer the kernel stores its result into.
            defines: pre-rendered ``tl.constexpr`` lines that ``def_kernel``
                places at the top of the kernel body.
            num_stages, num_warps: Triton launch parameters.
            grid_fn: called as ``grid_fn(*call_sizes, meta)`` in the wrapper
                code to compute the launch grid.
            meta: template parameters forwarded to ``grid_fn``.
            call_sizes: size expressions passed to ``grid_fn``.
            use_jit: if True emit only ``@triton.jit``; otherwise also emit
                the ``@template(...)`` decorator.
            prefix_args / suffix_args: number of leading / trailing inputs
                the template does not name; they are passed to
                ``epilogue_fn`` instead.
            epilogue_fn: applied to the computed value (and the prefix/suffix
                inputs) before it is stored.
        """
        # presumably (numel, rnumel): the whole output, no reduction dimension
        super().__init__(sympy_product(output_node.get_size()), sympy.Integer(1))
        self.input_nodes = input_nodes
        self.output_node = output_node
        # template-visible name -> input node, filled in by def_kernel()
        self.named_input_nodes = {}
        self.defines = defines
        self.kernel_name = kernel_name
        # mask / indices given to the first store_output() call; None until then
        self.template_mask = None
        self.template_indices = None
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes
        # for templates with fixed epilogues
        self.prefix_args = prefix_args
        self.suffix_args = suffix_args
        self.epilogue_fn = epilogue_fn

    def jit_line(self):
        """
        Return the decorator line(s) placed above the generated ``def``.

        With ``use_jit`` only ``@triton.jit`` is emitted. Otherwise the kernel
        is also wrapped in ``@template(...)``, which carries the launch
        parameters and the triton metadata derived from the current args.
        """
        if self.use_jit:
            return "@triton.jit"
        _, _, signature = self.args.python_argdefs()
        triton_meta = {
            "signature": dict(enumerate(map(signature_of, signature))),
            "device": V.graph.scheduler.current_device.index,
            "constants": {},
            "configs": [config_of(signature)],
        }
        return (
            f"@template(num_stages={self.num_stages}, num_warps={self.num_warps}, meta={triton_meta!r})\n"
            + "@triton.jit"
        )

    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate the function def and the
        needed args.

        ``argnames`` name the template-visible inputs. These are all inputs
        except the ``prefix_args`` leading and ``suffix_args`` trailing ones.

        Each named input is passed as ``arg_<name>`` and rebound to ``<name>``
        at the top of the body. The rebinding adds the buffer's layout offset
        when it is nonzero.

        Returns the imports, decorator(s), ``def`` line, constexpr defines and
        those rebindings as a single string.
        """
        assert all(isinstance(x, str) for x in argnames)
        renames = IndentedBuffer(initial_indent=1)
        named_args = self.input_nodes[
            self.prefix_args : len(self.input_nodes) - self.suffix_args
        ]
        assert len(argnames) == len(named_args), (
            len(argnames),
            len(named_args),
            self.prefix_args,
            len(self.input_nodes),
        )
        # register prefix inputs first so kernel args follow input_nodes order
        for input_node in self.input_nodes[: self.prefix_args]:
            self.args.input(input_node.get_name())
        for name, input_node in zip(argnames, named_args):
            arg_name = f"arg_{name}"
            self.named_input_nodes[name] = input_node
            self.args.input_buffers[input_node.get_name()] = arg_name
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
            else:
                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
                renames.writeline(f"{name} = {arg_name} + {offset}")
        # suffix inputs come last; slicing from len-suffix_args is empty for 0
        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
            self.args.input(input_node.get_name())

        arg_defs, *_ = self.args.python_argdefs()
        return "\n".join(
            [
                "import triton.language as tl",
                "import triton",
                "from torch._inductor.triton_ops.autotune import template",
                "from torch._inductor.utils import instance_descriptor",
                "",
                self.jit_line(),
                f"def {self.kernel_name}({', '.join(arg_defs)}):",
                self.defines,
                renames.getvalue(),
            ]
        )

    def size(self, name: str, index: int):
        """
        Hook called from template code to get the size of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_size()[index]
        return texpr(self.rename_indexing(val))

    def stride(self, name, index):
        """
        Hook called from template code to get the stride of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_stride()[index]
        return texpr(self.rename_indexing(val))

    def store_output(self, indices, val, mask):
        """
        Hook called from template code to store the final output
        (if the buffer hasn't been optimized away), then append any
        epilogue fusions.

        Only the first call generates the store. Later calls must pass the
        same ``mask`` and simply re-render the accumulated body.

        Args:
            indices: per-output-dimension index expressions (strings).
            val: name of the value to store.
            mask: mask expression used for all generated loads/stores.

        Returns:
            The body code, indented for placement inside the kernel function.
            The first line is left unindented because the template supplies it.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(val, str)
        assert isinstance(mask, str)
        if self.template_mask is None:
            indices = list(map(TritonPrinter.paren, indices))
            index_symbols = [sympy.Symbol(x) for x in indices]
            lengths = [
                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
            ]
            assert len(indices) == len(lengths)

            # glue to make generated code use same indexing from template
            for name, range_tree_entry in zip(
                indices, self.range_trees[0].construct_entries(lengths)
            ):
                range_tree_entry.set_name(name)
            contiguous_index = sympy_dot(
                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
            )
            self.body.writeline("xindex = " + texpr(contiguous_index))
            self.range_trees[0].lookup(
                sympy.Integer(1), sympy_product(lengths)
            ).set_name("xindex")
            self.template_mask = mask
            self.template_indices = indices
            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
            if output_index == contiguous_index:
                # contiguous output: reuse the xindex variable written above
                output_index = sympy.Symbol("xindex")

            epilogue_args = [val]
            for input_node in itertools.chain(
                self.input_nodes[: self.prefix_args],
                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
            ):
                input_node.freeze_layout()
                epilogue_args.append(input_node.make_loader()(index_symbols))

            V.ops.store(
                self.output_node.get_name(),
                output_index,
                self.epilogue_fn(*epilogue_args),
            )
        assert self.template_mask == mask
        self.codegen_body()
        # Indent with four spaces so every continuation line matches the
        # function-body indentation used by the rest of the generated kernel.
        # A single space produces an IndentationError in the generated code.
        return textwrap.indent(self.body.getvalue(), "    ").strip()

    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from a tensor.

        The address is ``name + sum(stride_i * index_i)`` using the named
        input's strides.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(name, str)
        assert isinstance(mask, str)
        stride = self.named_input_nodes[name].get_stride()
        indices = list(map(TritonPrinter.paren, indices))
        assert len(indices) == len(stride)
        index = " + ".join(
            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
        )
        return f"tl.load({name} + ({index}), {mask})"

    def template_env(self):
        """
        Generate the namespace visible in the template.

        Maps each hook method's ``__name__`` to the bound method.
        """
        return {
            fn.__name__: fn
            for fn in [
                self.def_kernel,
                self.size,
                self.stride,
                self.store_output,
                self.make_load,
            ]
        }

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
    ):
        """
        Override the default indexing to use the template's mask.

        ``dense_indexing`` is accepted for signature compatibility but is
        ignored: ``False`` is always passed to the parent. Instead,
        ``+ tl.zeros(<mask>.shape, tl.int32)`` is appended to the index
        expression so that it broadcasts to the mask's full shape.
        """
        result, *mask = super().indexing(
            index,
            dense_indexing=False,
            copy_shape=copy_shape,
            override_mask=self.template_mask,
        )
        result += f" + tl.zeros({self.template_mask}.shape, tl.int32)"
        return (result, *mask)

    def initialize_range_tree(self, pid_cache):
        """Set up range trees as the parent does, then drop the code it emitted."""
        super().initialize_range_tree(pid_cache)
        # ignore default codegen; the template writes its own index code
        self.body.clear()
        self.indexing_code.clear()

    def call_kernel(self, code, name: str):
        """
        Write the wrapper-code line that launches kernel ``name``.

        The launch uses the current device's stream, and the grid is computed
        in the wrapper as ``grid_fn(*call_sizes, meta)``.
        """
        _, call_args, _ = self.args.python_argdefs()
        # args flagged by is_unspec_arg() are passed as Python scalars via .item()
        call_args = ", ".join(
            f"{arg}.item()" if V.graph.is_unspec_arg(arg) else arg for arg in call_args
        )
        stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
        V.graph.wrapper_code.add_import_once(f"import {self.grid_fn.__module__}")
        meta = V.graph.wrapper_code.add_meta_once(self.meta)
        grid_args = [texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes]
        grid_args.append(meta)
        grid_call = (
            f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_args)})"
        )
        code.writeline(
            f"{name}.run({call_args}, grid={grid_call}, stream={stream_name})"
        )
@functools.lru_cache(None)
def _jinja2_env():
    """
    Return the shared jinja2 Environment used to compile templates.

    Returns None when jinja2 cannot be imported. The Environment uses
    ``StrictUndefined``, so referencing an undefined template variable raises
    instead of rendering silently. The result is cached, so every caller
    receives the same object.
    """
    try:
        import jinja2

        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    except ImportError:
        env = None
    return env
class TritonTemplate:
    """
    A named jinja2 template that renders a Triton kernel.

    Each instance registers itself in ``TritonTemplate.all_templates`` under
    its name; names must be unique. ``generate()`` renders, compiles and
    registers one concrete kernel for a given set of inputs and output
    layout, and returns a ``TritonTemplateCaller`` for it.
    """

    # shared counter giving every generated kernel a unique attribute name
    index_counter = itertools.count()
    # template name -> TritonTemplate, shared by all instances
    all_templates = {}

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` with the shared jinja2 env; None if jinja2 is missing."""
        env = _jinja2_env()
        if env is not None:
            return env.from_string(source)
        return None

    def __init__(self, name: str, grid: Any, source: str, debug=False):
        super().__init__()
        self.name = name
        # called as grid(*output_sizes, meta) to compute the launch grid
        self.grid = grid
        # None when jinja2 is unavailable; generate() asserts on this
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug

    def generate(
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        **kwargs,
    ):
        """
        Render, compile and register one kernel instance of this template.

        Args:
            input_nodes: IR input buffers, in kernel-argument order.
            layout: layout of the output buffer.
            num_stages, num_warps: Triton launch parameters.
            prefix_args / suffix_args: number of leading / trailing inputs
                the template does not name in ``def_kernel``; they are passed
                to ``epilogue_fn`` together with the computed value.
            epilogue_fn: applied to the value before it is stored.
            **kwargs: template parameters. They are rendered as
                ``tl.constexpr`` defines, passed to the template, and passed
                as ``meta`` to the grid function.

        Returns:
            A TritonTemplateCaller for the compiled kernel, which is also
            attached to ``template_kernels``.
        """
        assert self.template, "requires jinja2"
        # Each template parameter becomes a tl.constexpr at the top of the
        # kernel body. They are written right before the renames that
        # def_kernel() emits through IndentedBuffer(initial_indent=1), so they
        # must use the same four-space indent. A one-space indent makes the
        # generated function fail to parse.
        defines = "".join(
            f"    {name} : tl.constexpr = {val}\n" for name, val in kwargs.items()
        )

        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=layout.size,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
        )
        # fake_get_dtype() captures the real get_dtype before patching, so
        # only the fake output buffer's dtype is answered locally
        with patch.object(
            V.graph, "get_dtype", self.fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=True,
            **kernel_options,
        ) as kernel:
            # Render twice. The first pass registers every argument the
            # template touches, so the second pass sees the full signature.
            self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            code = self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            if self.debug:
                print("Generated Code:\n", code)
            mod = PyCodeCache.load(code)
            run = getattr(mod, kernel_name).run
            _, call_args, _ = kernel.args.python_argdefs()

        expected_args = [x.get_name() for x in input_nodes] + [fake_out.get_name()]
        assert list(call_args) == expected_args, (call_args, expected_args)
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )
        assert not extra_args, "TODO: dynamic shapes"

        def call(*args, out):
            return run(
                *args,
                out,
                *extra_args,
                grid=self.grid(*out.size(), kwargs),
                num_stages=num_stages,
                num_warps=num_warps,
            )

        call.key = mod.key
        call.__file__ = mod.__file__

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
        setattr(template_kernels, kernel_hash_name, call)

        def make_kernel_render(out_node):
            # builds the non-jit kernel (with @template metadata) for the real
            # output node, plus a renderer bound to its hooks
            kernel = TritonTemplateKernel(
                kernel_name="KERNEL_NAME",
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                self.template.render,
                **kernel.template_env(),
                **kwargs,
            )
            return kernel, render

        return TritonTemplateCaller(
            kernel_hash_name, input_nodes, layout, make_kernel_render
        )

    @staticmethod
    def fake_get_dtype(fake_out):
        """
        Return a ``get_dtype`` replacement that answers for ``fake_out``.

        All other names are delegated to the real ``V.graph.get_dtype``, as
        captured when this method is called.
        """
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype
class ExternKernelChoice:
    """
    A non-Triton kernel (any Python callable) offered as an autotuning choice.

    The callable is registered on the module-level ``extern_kernels``
    namespace under ``name``. Generated wrapper code can then reference it as
    ``extern_kernels.<name>``.
    """

    def __init__(self, kernel, cpp_kernel=None, *, name=None):
        super().__init__()
        # validate before touching kernel.__name__
        assert callable(kernel)
        name = name or kernel.__name__
        assert not hasattr(extern_kernels, name), "duplicate extern kernel"
        self.name = name
        self.cpp_kernel = cpp_kernel
        # Lazily filled by hash_key(). It is cached per instance instead of
        # via lru_cache on the method, which would keep every instance alive
        # in a global cache for the life of the process.
        self._hash_key = None
        setattr(extern_kernels, name, kernel)

    def to_callable(self):
        """Return the callable registered under ``self.name``."""
        return getattr(extern_kernels, self.name)

    def call_name(self):
        """Return the expression wrapper code uses to reference this kernel."""
        return f"extern_kernels.{self.name}"

    def hash_key(self):
        """
        Return a hash identifying this kernel for autotune caching.

        The hash combines the registered name, the callable's
        ``__name__``/``__module__`` and, when obtainable, its source text. It
        is computed on first use and then reused.
        """
        cached = getattr(self, "_hash_key", None)
        if cached is None:
            fn = self.to_callable()
            parts = [
                self.name,
                getattr(fn, "__name__", ""),
                getattr(fn, "__module__", ""),
            ]
            try:
                parts.append(inspect.getsource(fn))
            except Exception:
                # Deliberate best effort: source is unavailable for builtins,
                # C extensions, partials and so on. Fall back to name/module.
                pass
            cached = code_hash("-".join(parts))
            self._hash_key = cached
        return cached

    def bind(self, input_nodes, layout, **kwargs):
        """Bind this choice to concrete inputs/layout as an ExternKernelCaller."""
        return ExternKernelCaller(self, input_nodes, layout, kwargs)
class ChoiceCaller:
    """
    Base class for one candidate implementation considered by autotuning.

    It stores the choice's name, its input nodes and the output layout.
    Subclasses in this module add ``to_callable``, ``hash_key`` and
    ``output_node``.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name, self.input_nodes, self.layout = name, input_nodes, layout
class TritonTemplateCaller(ChoiceCaller):
    """ChoiceCaller for a generated Triton template kernel on ``template_kernels``."""

    def __init__(self, name, input_nodes, layout, make_kernel_render):
        super().__init__(name, input_nodes, layout)
        # factory handed to the TemplateBuffer built by output_node()
        self.make_kernel_render = make_kernel_render

    def __str__(self):
        source_file = self.to_callable().__file__
        return f"TritonTemplateCaller({source_file})"

    def call_name(self):
        """Expression used by wrapper code to reference this kernel."""
        return f"template_kernels.{self.name}"

    def to_callable(self):
        """Return the launcher registered on ``template_kernels``."""
        return getattr(template_kernels, self.name)

    def hash_key(self):
        """Cache key: the compiled module's ``key`` stored on the launcher."""
        return self.to_callable().key

    def output_node(self):
        """Wrap an ir.TemplateBuffer for this choice in a TensorBox."""
        template_buffer = ir.TemplateBuffer(
            layout=self.layout,
            inputs=self.input_nodes,
            make_kernel_render=self.make_kernel_render,
        )
        return ir.TensorBox.create(template_buffer)
class ExternKernelCaller(ChoiceCaller):
    """
    ChoiceCaller that binds an ExternKernelChoice to concrete inputs.

    It also holds the output layout and extra keyword arguments for the
    kernel.
    """

    def __init__(self, choice: ExternKernelChoice, input_nodes, layout, kwargs=None):
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        # extra keyword arguments forwarded to the kernel (may be empty)
        self.kwargs = kwargs or {}

    def to_callable(self):
        """Return the registered kernel, with ``self.kwargs`` bound if any."""
        fn = self.choice.to_callable()
        return functools.partial(fn, **self.kwargs) if self.kwargs else fn

    def hash_key(self):
        """Cache key: the choice's hash plus the repr of the bound kwargs."""
        return "/".join((self.choice.hash_key(), repr(self.kwargs)))

    def output_node(self):
        """Wrap an ir.ExternKernelOut that calls this kernel in a TensorBox."""
        extern_out = ir.ExternKernelOut(
            layout=self.layout,
            inputs=self.input_nodes,
            kernel=self.choice.call_name(),
            cpp_kernel=self.choice.cpp_kernel,
            kwargs=self.kwargs,
        )
        return ir.TensorBox.create(extern_out)
class AlgorithmSelectorCache(DiskCache):
    """
    Autotuner that benchmarks interchangeable ChoiceCallers and returns the
    output node of the fastest one.

    The winning index is memoized through ``DiskCache.lookup``. The key is
    built from every choice's ``hash_key()`` and ``key_of()`` of every input
    node.
    """

    def __call__(self, choices: List[ChoiceCaller], input_nodes, layout):
        """
        Select among ``choices`` and return the chosen one's ``output_node()``.

        A single choice is returned without benchmarking. A choice that raises
        a RuntimeError mentioning "invalid argument" is skipped with a
        warning, or re-raised as a RuntimeError when VERIFY is enabled. Any
        other RuntimeError propagates.

        Raises:
            AssertionError: a choice's result failed the VERIFY comparison.
            RuntimeError: no choice could be benchmarked.
        """
        if len(choices) == 1:
            return choices[0].output_node()

        def autotune():
            benchmark_fn = self.make_benchmark_fn(choices, input_nodes, layout)
            timings = {}
            for index, choice in enumerate(choices):
                try:
                    timings[choice] = benchmark_fn(
                        choice.to_callable(), isinstance(choice, ExternKernelCaller)
                    )
                except RuntimeError as e:
                    if "invalid argument" not in str(e):
                        raise
                    # Built explicitly rather than via textwrap.dedent on a
                    # triple-quoted f-string, whose output depended on the
                    # indentation and line breaks inside str(e).
                    msg = (
                        f"{e}\n"
                        f"From choice {index}: {choice}\n"
                        "This may mean this GPU is too small for max_autotune mode."
                    )
                    if VERIFY:
                        raise RuntimeError(msg) from e
                    log.warning(msg)
                except AssertionError as e:
                    raise AssertionError(
                        f"Incorrect result from choice {index} {choice}\n\n{e}"
                    ) from e

            if not timings:
                # Every choice was skipped. Without this check, log_results /
                # min() would fail with an unhelpful IndexError / ValueError.
                raise RuntimeError(
                    f"autotuning failed: none of the {len(choices)} choices "
                    "could be benchmarked"
                )

            self.log_results(choices[0].name, input_nodes, timings)
            best_choice = builtins.min(timings, key=timings.__getitem__)
            return choices.index(best_choice)

        counters["inductor"]["select_algorithm_autotune"] += 1
        key = [x.hash_key() for x in choices] + [self.key_of(x) for x in input_nodes]
        return choices[self.lookup(key, autotune)].output_node()

    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
    ):
        """
        Build ``benchmark(algo, is_extern)``.

        The returned function times one choice on random example tensors and
        returns the ``do_bench`` result.

        Template (non-extern) kernels receive the base tensors. Extern
        kernels receive ``as_strided`` views that include each buffer's
        layout offset.

        When VERIFY is set, ``choices[0]`` is run once up front to produce a
        reference result. Every benchmarked output is then compared against
        it.
        """
        example_inputs = [cls.benchmark_example_value(x) for x in input_nodes]
        example_inputs_extern = list(example_inputs)
        for i, node in enumerate(input_nodes):
            if node.get_layout().offset != 0:
                offset = V.graph.sizevars.size_hint(node.get_layout().offset)
                data = example_inputs_extern[i]
                example_inputs_extern[i] = torch.as_strided(
                    data, data.size(), data.stride(), offset
                )

        out = cls.benchmark_example_value(layout)
        # out_extern is a view on out's storage, shifted by the layout offset
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        if VERIFY:
            choices[0].to_callable()(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        def benchmark(algo, is_extern):
            # clear the output so a stale result from an earlier choice
            # cannot pass the VERIFY comparison
            out.zero_()
            if is_extern:
                result = do_bench(lambda: algo(*example_inputs_extern, out=out_extern))
            else:
                result = do_bench(lambda: algo(*example_inputs, out=out))
            if VERIFY:
                # out_extern shares storage with out, so this also checks
                # template kernels that wrote through `out`
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        return benchmark

    @staticmethod
    def log_results(name, input_nodes, timings):
        """
        Write up to the 10 fastest timings to stderr when PRINT_AUTOTUNE is set.

        ``result[0]`` of each timing is the reported time. The percentage
        shown is best_time / time.
        NOTE(review): do_bench results may be in milliseconds, which would
        make the "s" suffix misleading. Confirm before relying on the unit.
        """
        if not PRINT_AUTOTUNE or not timings:
            return
        sizes = ", ".join(
            [
                "x".join(map(str, V.graph.sizevars.size_hints(n.get_size())))
                for n in input_nodes
            ]
        )
        top_k = sorted(timings, key=timings.__getitem__)[:10]
        best = top_k[0]
        best_time = timings[best][0]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            sys.stderr.write(
                f" {choice.name} {result[0]:.4f}s {best_time/result[0]:.1%}\n"
            )

    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.

        An ir.Layout is first wrapped in a throwaway Buffer. Sizes, strides
        and the offset are replaced by their size hints. The offset is passed
        as ``extra_size`` (presumably so offset views stay within the
        allocation).
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        return rand_strided(
            V.graph.sizevars.size_hints(node.get_size()),
            V.graph.sizevars.size_hints(node.get_stride()),
            device=node.get_device(),
            dtype=node.get_dtype(),
            extra_size=V.graph.sizevars.size_hint(node.get_layout().offset),
        )

    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.

        Returns a tuple of: device type, dtype string, size hints, stride
        hints and the offset hint.
        """
        sizevars = V.graph.sizevars
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(node.get_size()),
            *sizevars.size_hints(node.get_stride()),
            sizevars.size_hint(node.get_layout().offset),
        )
# Module-level autotuner instance used by this module's callers. __name__ is
# passed to the DiskCache constructor, presumably to namespace the on-disk
# cache per module. TODO confirm against codecache.DiskCache.
autotune_select_algorithm = AlgorithmSelectorCache(__name__)
def realize_inputs(*args):
    """
    Realize each argument and require stride 1.

    Each argument goes through ``ir.ExternKernel.realize_input`` and then
    ``ir.ExternKernel.require_stride1``.

    With exactly one argument, the converted node is returned directly.
    Otherwise a list with one converted node per argument is returned; this
    is an empty list when called with no arguments.
    """
    if len(args) != 1:
        return [realize_inputs(arg) for arg in args]
    (node,) = args
    return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(node))
- # ensure lowering is imported so that `extern_kernels.*` is populated
- from . import lowering # noqa: F401
|