import collections import contextlib import dataclasses import functools import itertools import logging import math import operator from typing import Dict, List, Set import sympy import torch from ..._dynamo import config as dynamo_config from .. import config, ir, scheduler from ..ir import ReductionHint from ..optimize_indexing import indexing_dtype_strength_reduction from ..utils import ( get_fused_kernel_name, instance_descriptor, sympy_product, sympy_subs, sympy_symbol, ) from ..virtualized import ops, V from .common import ( CSEVariable, DeferredLine, free_symbol_startswith, IndentedBuffer, index_prevent_reordering, Kernel, OpOverrides, PythonPrinter, SizeArg, TensorArg, ) log = logging.getLogger(__name__) def signature_of(arg): from triton.runtime.jit import JITFunction if isinstance(arg, TensorArg): tye = JITFunction._type_of(arg.dtype) if V.graph.is_unspec_arg(arg.buffer): # had unwrapped 0d tensor as scalar new_tye = tye.lstrip("*") if new_tye in ["fp16", "bf16"]: return "fp32" else: return new_tye else: return tye if isinstance(arg, SizeArg): return JITFunction._key_of(V.graph.sizevars.size_hint(arg.expr)) raise NotImplementedError(f"unhandled {type(arg)}: {arg}") def config_of(args): from ..compile_fx import ALIGNMENT def is_aligned(x): if isinstance(x, TensorArg): return x.buffer not in V.graph.unaligned_buffers if isinstance(x, SizeArg): return V.graph.sizevars.maybe_guard_multiple_of(x.expr, ALIGNMENT) raise NotImplementedError(f"unhandled {type(x)}: {x}") divisible_by_16 = [i for i, arg in enumerate(args) if is_aligned(arg)] return instance_descriptor(tuple(divisible_by_16), ()) class TritonPrinter(PythonPrinter): def _print_floor(self, expr): assert len(expr.args) == 1 return f"tl.libdevice.floor({self.paren(self._print(expr.args[0]))})" texpr = TritonPrinter().doprint pexpr = PythonPrinter().doprint def triton_compute_type(dtype): triton_type_name = str(dtype).split(".")[-1] if triton_type_name == "bool": triton_type_name = "int1" if triton_type_name in ("float16", "bfloat16"): # float16 math is done in float32 inside the kernel triton_type_name = "float32" return f"tl.{triton_type_name}" def triton_constant(value): if value == float("inf"): return 'float("inf")' elif value == float("-inf"): return 'float("-inf")' elif math.isnan(value): return 'float("nan")' return repr(value) class TritonCSEVariable(CSEVariable): def __init__(self, name): super().__init__(name) # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars: Set[str] = set() def update_on_args(self, name, args, kwargs): # When making a variable that is going to be used in indirect indexing # if a where clause is used it should mean that the result is always a # valid index, so you shouldn't include any of the dependent variables # in the resulting load mask if name == "where": return for arg in args: if isinstance(arg, TritonCSEVariable): self.mask_vars.update(arg.mask_vars) class TritonOverrides(OpOverrides): """Map element-wise ops to Triton""" @staticmethod def to_dtype(x, dtype: torch.dtype): if dtype == torch.bool: return f"({x} != 0)" elif dtype == torch.uint8: # to work around llvm uint conversion semantics # that produces 0's for negative values return f"{x}.to(tl.int8).to(tl.uint8)" return f"{x}.to({triton_compute_type(dtype)})" @staticmethod def constant(value, dtype): type_ = torch._prims_common.dtype_to_type(dtype) return triton_constant(type_(value)) @staticmethod def abs(x): return f"tl.abs({x})" @staticmethod def libdevice_abs(x): return f"tl.libdevice.abs({x})" @staticmethod def exp(x): return f"tl.exp({x})" @staticmethod def libdevice_exp(x): return f"tl.libdevice.exp({x})" @staticmethod def exp2(x): return f"tl.libdevice.exp2({x})" @staticmethod def expm1(x): return f"tl.libdevice.expm1({x})" @staticmethod def sqrt(x): return f"tl.sqrt({x})" @staticmethod def libdevice_sqrt(x): return f"tl.libdevice.sqrt({x})" @staticmethod def relu(x): return ops.maximum("0", x) @staticmethod def minimum(a, b): return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))" @staticmethod def maximum(a, b): return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))" @staticmethod def where(a, b, c): return f"tl.where({a}, {b}, {c})" @staticmethod def cos(x): return f"tl.cos({x})" @staticmethod def libdevice_cos(x): return f"tl.libdevice.cos({x})" @staticmethod def sin(x): return f"tl.sin({x})" @staticmethod def libdevice_sin(x): return f"tl.libdevice.sin({x})" @staticmethod def index_expr(expr, dtype): return V.kernel.indexing(expr)[0] @staticmethod def masked(mask, body, other): with V.kernel.mask_loads(mask) as new_mask: result = body() return ops.where(new_mask, result, triton_constant(other)) @staticmethod def lgamma(x): return f"tl.libdevice.lgamma({x})" @staticmethod def erf(x): return f"tl.libdevice.erf({x})" @staticmethod def cosh(x): return f"tl.libdevice.cosh({x})" @staticmethod def sinh(x): return f"tl.libdevice.sinh({x})" @staticmethod def acos(x): return f"tl.libdevice.acos({x})" @staticmethod def acosh(x): return f"tl.libdevice.acosh({x})" @staticmethod def asin(x): return f"tl.libdevice.asin({x})" @staticmethod def asinh(x): return f"tl.libdevice.asinh({x})" @staticmethod def atan2(x, y): return f"tl.libdevice.atan2({x}, {y})" @staticmethod def atan(x): return f"tl.libdevice.atan({x})" @staticmethod def atanh(x): return f"tl.libdevice.atanh({x})" @staticmethod def copysign(x, y): return f"tl.libdevice.copysign({x}, {y})" @staticmethod def erfc(x): return f"tl.libdevice.erfc({x})" @staticmethod def hypot(x, y): return f"tl.libdevice.hypot({x}, {y})" @staticmethod def log10(x): return f"tl.libdevice.log10({x})" @staticmethod def nextafter(x, y): return f"tl.libdevice.nextafter({x}, {y})" @staticmethod def logical_and(a, b): return f"{a} & {b}" @staticmethod def logical_or(a, b): return f"{a} | {b}" @staticmethod def rand(seed, offset, _): # _ here to keep the contract identical to CPU rand op return f"tl.rand({seed}, {offset})" @staticmethod def randn(seed, offset, _): # _ here to keep the contract identical to CPU randn op return f"tl.randn({seed}, {offset})" @staticmethod def rsqrt(x): return f"tl.libdevice.rsqrt({x})" @staticmethod def log1p(x): return f"tl.libdevice.log1p({x})" @staticmethod def tan(x): return f"tl.libdevice.tan({x})" @staticmethod def tanh(x): return f"tl.libdevice.tanh({x})" @staticmethod def sigmoid(x): return f"tl.sigmoid({x})" @staticmethod def libdevice_sigmoid(x): return f"1/(1 + tl.libdevice.exp(-({x})))" @staticmethod def signbit(x): # XX: This is wrong for the value -0.0 in floating point return f"tl.libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" @staticmethod def fmod(a, b): return f"tl.libdevice.fmod({a}, {b})" @staticmethod def pow(a, b): return f"tl.libdevice.pow({a}, {b})" @staticmethod def log(x): return f"tl.log({x})" @staticmethod def libdevice_log(x): return f"tl.libdevice.log({x})" @staticmethod def isinf(x): return f"tl.libdevice.isinf({x})" @staticmethod def isnan(x): return f"tl.libdevice.isnan({x})" @staticmethod def round(x): return f"tl.libdevice.nearbyint({x})" @staticmethod def floor(x): return f"tl.libdevice.floor({x})" @staticmethod def floordiv(a, b): # See the comment in lowering.div_mode. a and b are integer type. # Similar to div_floor_kernel_cuda in pytorch core. # Notice that // in triton behaves as truncdiv instead of floordiv quot = f"{a} // {b}" rem = f"{a} % {b}" return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" @staticmethod def trunc(x): return f"tl.libdevice.trunc({x})" @staticmethod def truncdiv(a, b): # See the comment in lowering.div_mode. a and b are integer type. # Notice that // in triton behaves as truncdiv instead of floordiv return f"{a} // {b}" @staticmethod def ceil(x): return f"tl.libdevice.ceil({x})" @dataclasses.dataclass class IterationRanges: """ Each range tree represents multiple sets of iteration indexing in a single tiled dimension in the output kernel. If you have two loops ranges one (4, 3, 2) and another (4, 6), then the range tree will be: 4 (i0) 3 (i1) 6 (i3) 2 (i2) Where i0 is shared between both loops, but then the split into different indexing vars. All loop ranges must iterate over the same number of elements. """ def __init__( self, name: str, var_list: List[sympy.Symbol], var_ranges: Dict[sympy.Symbol, sympy.Expr], numel: sympy.Expr, prefix: str, *, kernel: "Kernel", divisor=sympy.Integer(1), length=sympy.Integer(1), ): super().__init__() self.name = name self.var_list = var_list self.var_ranges = var_ranges self.numel = numel self.prefix = prefix self.divisor = divisor self.length = length self.kernel = kernel def is_loop(self): return self.prefix == "r" and not self.kernel.persistent_reduction class IterationRangesRoot(IterationRanges): def __init__( self, name: str, numel: sympy.Expr, prefix: str, index: int, kernel: "Kernel", pid_cache=None, ): if pid_cache is None: pid_cache = {} super().__init__( name=name, var_list=[], var_ranges={}, numel=numel, prefix=prefix, kernel=kernel, ) self.index = index # Store all the nodes in one flat list self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} # This is for re-ordering program ID in triton mm template # pid_cache["tl.program_id(0)"] = pid_m self.pid_cache: Dict[str, str] = pid_cache def cache_clear(self): for node in self.nodes.values(): node.cache_clear() def lookup(self, divisor, length): """ Lookup a given RangeTreeEntry, creating it if needed """ if V.graph.sizevars.maybe_guard_equals(divisor * length, self.numel): expr = ir.FloorDiv(sympy_symbol(f"{self.prefix}index"), divisor) else: expr = ir.ModularIndexing( sympy_symbol(f"{self.prefix}index"), divisor, length ) if expr not in self.nodes: node = IterationRangesEntry( f"{self.prefix}{next(V.kernel.iter_vars_count)}", divisor, length, expr, self, ) V.kernel.range_tree_nodes[node.symbol()] = node self.var_list.append(node.symbol()) self.var_ranges[node.symbol()] = length self.nodes[expr] = node return self.nodes[expr] def construct_entries(self, lengths: List[sympy.Expr]): divisor = sympy.Integer(1) itervars = [] for length in reversed(lengths): itervars.append(self.lookup(divisor, length)) divisor = divisor * length return list(reversed(itervars)) def construct(self, lengths: List[sympy.Expr]): return [e.symbol() for e in self.construct_entries(lengths)] def vars_and_sizes(self, index: sympy.Expr): """Figure out vars from this tree used in index""" nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] nodes = [n for n in nodes if n and n.prefix == self.prefix] nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor)) divisor = sympy.Integer(1) index_vars = [] sizes = [] def add(node): nonlocal divisor index_vars.append(node.symbol()) sizes.append(node.length) divisor = divisor * node.length for node in nodes: if not V.graph.sizevars.maybe_guard_equals(node.divisor, divisor): # fill in unused index var add(self.lookup(divisor, ir.FloorDiv(node.divisor, divisor))) divisor = node.divisor add(node) if not V.graph.sizevars.maybe_guard_equals(self.numel, divisor): # fill in unused index var add(self.lookup(divisor, ir.FloorDiv(self.numel, divisor))) return list(reversed(index_vars)), list(reversed(sizes)) def ranges_code(self): size = self.kernel.indexing_size_str(self.index, self.prefix) return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}" def pid_cache_lookup(self, key): if key in self.pid_cache: return self.pid_cache[key] return key def codegen_header(self, code): x = self.prefix if self.is_loop(): code.writeline(f"{self.name} = {x}offset + {x}base") elif x == "r" and self.kernel.persistent_reduction: # no need to "roffset = " code.writeline( f"{self.name} = {self.ranges_code()}", ) else: pid = self.pid_cache_lookup(f"tl.program_id({self.index})") code.writelines( [ f"{x}offset = {pid} * {x.upper()}BLOCK", f"{self.name} = {x}offset + {self.ranges_code()}", ] ) code.writeline(f"{x}mask = {self.name} < {x}numel") class IterationRangesEntry(IterationRanges): def __init__( self, name: str, divisor: sympy.Expr, length: sympy.Expr, expr: sympy.Expr, parent: IterationRanges, ): super().__init__( name=name, numel=parent.numel / length, var_list=parent.var_list, var_ranges=parent.var_ranges, prefix=parent.prefix, divisor=divisor, length=length, kernel=parent.kernel, ) self.parent = parent self.codegen = functools.lru_cache(None)(self._codegen) self.expr = expr def set_name(self, name): self.codegen = lambda: name self.codegen.cache_clear = lambda: None self.name = name def cache_clear(self): self.codegen.cache_clear() def writeline(self, line): if self.is_loop(): V.kernel.indexing_code.writeline(line) else: # lift non-reduction stores outside loop V.kernel.body.writeline(line) def _codegen(self): self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr))) return self.name def precomputed_args(self): # for dynamic shapes, find parts of indexing expressions that have to be precomputed precomputed_args = [] if isinstance(self.expr, sympy.Symbol): return precomputed_args assert isinstance(self.expr, (ir.FloorDiv, ir.ModularIndexing)), type(self.expr) for arg in self.expr.args[1:]: if not isinstance(arg, (sympy.Integer, sympy.Symbol)): symbols = arg.free_symbols if len(symbols) > 0 and all(s.name.startswith("s") for s in symbols): precomputed_args.append(arg) return precomputed_args def symbol(self): return sympy_symbol(self.name) def __hash__(self): return hash(self.name) def __eq__(self, other): return self.name == other.name class TritonKernel(Kernel): overrides = TritonOverrides sexpr = pexpr def __init__( self, *groups, mutations=None, pid_cache=None, reduction_hint=ReductionHint.DEFAULT, ): if pid_cache is None: pid_cache = {} super().__init__() self.numels = [V.graph.sizevars.simplify(s) for s in groups] self.mutations = mutations self.range_trees = [] self.range_tree_nodes = {} self.iter_vars_count = itertools.count() self.inside_reduction = self.numels[-1] != 1 self._load_mask = None self.body = IndentedBuffer() self.indexing_code = IndentedBuffer() self.suffix = IndentedBuffer() self.outside_loop_vars = set() self.reduction_hint = reduction_hint self.persistent_reduction = self.should_use_persistent_reduction() self.initialize_range_tree(pid_cache) # define this in a closure to make cache local to object @functools.lru_cache(None) def simplify_indexing(index: sympy.Expr): index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) for tree in self.range_trees: index = self.combine_contiguous_dims(index, tree) return index self.simplify_indexing = simplify_indexing def should_use_persistent_reduction(self): """ Heuristic to set self.persistent_reduction and add guards if needed. """ if not (self.inside_reduction and config.triton.persistent_reductions): return False threshold = { ReductionHint.INNER: 1024, }.get(self.reduction_hint, 64) hint = V.graph.sizevars.size_hint(self.numels[-1]) if hint > threshold: return False from triton import next_power_of_2 # will need to recompile if we cross a larger power of 2 boundary V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint)) return True def initialize_range_tree(self, pid_cache): names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"] for i in range(len(self.numels)): self.range_trees.append( IterationRangesRoot( names[i], self.numels[i], names[i][0], i, self, pid_cache ) ) for tree in self.range_trees: # reduction indexing goes inside a loop if not tree.is_loop(): tree.codegen_header(self.body) if self.inside_reduction and self.range_trees[-1].is_loop(): # workaround for this issue: # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}") def disable_reduction(self): @contextlib.contextmanager def ctx(): if self.numels[-1] == 1: assert not self.inside_reduction yield return if not self.persistent_reduction: # calling codegen_body() will flush all the pending buffers # and write out a reduction loop self.codegen_body() self.inside_reduction = False yield if not self.persistent_reduction: # flush out any code before opening the next loop self.codegen_body() self.inside_reduction = True return ctx() def set_ranges(self, *lengths): assert len(lengths) == len(self.range_trees) return [ ranges.construct(length) for length, ranges in zip(lengths, self.range_trees) ] @staticmethod def _split_iteration_ranges( groups: List[sympy.Expr], lengths: List[List[sympy.Expr]] ): sv = V.graph.sizevars new_ranges = [[] for _ in groups] remaining = [sv.simplify(g) for g in groups] var_count = itertools.count() def add_range(i, expr): expr = sv.simplify(expr) if not sv.maybe_guard_multiple_of(remaining[i], expr): raise CantSplit() # guard on the last item out sv.maybe_guard_equals(remaining[i], expr) remaining[i] = ir.FloorDiv(remaining[i], expr) new_ranges[i].append(expr) return next(var_count) def make_combined(size, idx1, idx2): def getter(flat_vars): return size * flat_vars[idx1] + flat_vars[idx2] return getter return_getters_groups = [] current_group = 0 for length_group in lengths: return_getters = [] for size in length_group: if sv.maybe_guard_equals(size, 1): return_getters.append(lambda _: sympy.Integer(0)) continue while ( current_group < len(remaining) and sv.size_hint(remaining[current_group]) == 1 ): # scroll to next group with remaining elements current_group += 1 if sv.size_hint(size) > sv.size_hint(remaining[current_group]): # need to break size in two if not sv.maybe_guard_multiple_of(size, remaining[current_group]): raise CantSplit() size1 = remaining[current_group] size2 = ir.FloorDiv(size, remaining[current_group]) return_getters.append( make_combined( size2, add_range(current_group, size1), add_range(current_group + 1, size2), ) ) else: return_getters.append( operator.itemgetter(add_range(current_group, size)) ) return_getters_groups.append(return_getters) assert all( V.graph.sizevars.size_hint(s) == 1 for s in remaining ), f"failed to set ranges {remaining} {lengths}" return new_ranges, return_getters_groups @classmethod def is_compatible(cls, groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]): try: cls._split_iteration_ranges(groups, lengths) return True except CantSplit: return False def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): """ We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). To do this we need to split up the iteration space of i0 into something like: for i1 in s0: for i2 in s1: i0 = i1*s1 + i2 .... This function matches and resplits lengths to the groups of this kernel to enable tiled + non-tiled fusions. """ groups = [rt.numel for rt in self.range_trees] if not self.inside_reduction: groups[-1] = sympy.Integer(1) if len(lengths) == len(self.range_trees) and all( V.graph.sizevars.simplify(sympy_product(x) - g) == 0 for x, g in zip(lengths, groups) ): return self.set_ranges(*lengths) new_ranges, return_getters_groups = self._split_iteration_ranges( groups, lengths ) itervars = list(itertools.chain(*self.set_ranges(*new_ranges))) return [[fn(itervars) for fn in fns] for fns in return_getters_groups] def is_indirect_indexing(self, index: sympy.Expr): # tmpX means indirect indexing return free_symbol_startswith(index, "tmp") def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): """ More aggressive simplification to merge contiguous dims """ if isinstance(index, (sympy.Integer, sympy.Symbol)): return index index_vars, sizes = tree.vars_and_sizes(index) if len(sizes) <= 1: return index new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) ) if new_sizes == sizes: return index new_index_vars = tree.construct(new_sizes) new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) return new_index def indexing( self, index: sympy.Expr, *, copy_shape=None, dense_indexing=False, override_mask=None, ): """ Compute the index and mask to pass to tl.load() or tl.store() """ index = self.simplify_indexing(index) index_vars = index.free_symbols index_str = texpr(self.rename_indexing(self.codegen_indexing(index))) mask_vars: Set[str] = set() for var in index_vars: if override_mask: pass elif var.name.startswith("tmp"): # indirect indexing cse_var = self.cse.varname_map[var.name] mask_vars.update(cse_var.mask_vars) elif var.name.startswith("s"): pass else: # var is one of xN, yN or rN assert var.name[0] in "xyr", var.name mask_vars.add(f"{var.name[0]}mask") need_dense = ( config.triton.dense_indexing or dense_indexing or self._load_mask is not None ) and index != 0 have_dense = True have_loop_vars = False dense_mask_vars = set() for tree in self.range_trees: if tree.prefix == "r" and not self.inside_reduction: continue if index_vars.intersection(tree.var_list): have_loop_vars = True have_dense = False dense_mask_vars.add(f"{tree.prefix}mask") if (need_dense and not have_dense) or isinstance(index, sympy.Integer): if copy_shape: index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)" else: index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)" if isinstance(index, sympy.Integer): return index_str, set(), "None" else: mask_vars = dense_mask_vars elif not have_loop_vars and copy_shape: mask_vars = dense_mask_vars index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)" if override_mask: mask_vars = {override_mask} if self._load_mask: mask_vars.add(self._load_mask) self.filter_masks(mask_vars) mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None" return index_str, mask_vars, mask_str def filter_masks(self, mask_vars): for tree in self.range_trees: # Masks are superfluous if we only have one element if V.graph.sizevars.maybe_guard_equals(tree.numel, 1): mask_vars.discard(f"{tree.prefix}mask") def var_ranges(self): return dict( itertools.chain.from_iterable( tree.var_ranges.items() for tree in self.range_trees ) ) def codegen_indexing(self, expr: sympy.Expr): expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) for sym in sorted(expr.free_symbols, key=str): if sym in self.range_tree_nodes: # if indexing expression is complicated, we precompute it on the host side # and send the result as a kernel argument replacements = {} for ps in self.range_tree_nodes[sym].precomputed_args(): replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) if len(replacements) > 0: self.range_tree_nodes[sym].expr = sympy_subs( self.range_tree_nodes[sym].expr, replacements ) self.range_tree_nodes[sym].codegen() return expr @contextlib.contextmanager def mask_loads(self, mask): """Context manager to add an additional mask to tl.load/store""" prior = self._load_mask if prior: mask = self.cse.generate(self.compute, f"{mask} & {prior}") self._load_mask = mask with self.swap_buffers(self.compute, self.compute): # TODO(jansel): do we need a reshape here? yield mask self._load_mask = prior def load(self, name: str, index: sympy.Expr): var = self.args.input(name) indirect_indexing = self.is_indirect_indexing(index) original_index = index index, mask_vars, mask = self.indexing(index) if "rmask" in mask and not self.persistent_reduction: # This eviction policy heuristic is untested. # ptillet suggested we should try only doing this for # the first N-1 loops and not for the final loop. ep = ", eviction_policy='evict_last'" else: ep = "" # "other" below is a workaround for https://github.com/openai/triton/issues/737 # for bool, even though it's likely subject to the same bug, setting `other` leads # to LLVM errors so we are skipping it for now if ("tmp" in mask or "rmask" in mask) and V.graph.get_dtype(name) != torch.bool: other = ", other=0" else: other = "" append_broadcast = None if V.graph.is_unspec_arg(name): line = var else: if isinstance(original_index, sympy.Integer): dense_size = self.dense_size_str() line = f"tl.load({var} + ({original_index}))" append_broadcast = dense_size else: line = f"tl.load({var} + ({index}), {mask}{ep}{other})" if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16): line += ".to(tl.float32)" if ( self.inside_reduction and not self.persistent_reduction and "rmask" not in mask and "tmp" not in mask and not indirect_indexing ): # can lift a common load outside of reduction loop # One exception is when this is an indirect_load. result_var = self.cse.generate( self.body, line, append_broadcast=append_broadcast ) else: result_var = self.cse.generate( self.loads, line, append_broadcast=append_broadcast ) result_var.mask_vars = mask_vars if not self.inside_reduction or "rmask" not in mask: self.outside_loop_vars.add(result_var) return result_var def store(self, name, index, value, mode=None): var = self.args.output(name) index, mask_vars, mask = self.indexing(index, dense_indexing=True) if mode is None: line = f"tl.store({var} + ({index}), {value}, {mask})" elif mode == "atomic_add": line = f"tl.atomic_add({var} + ({index}), {value}, {mask})" else: raise NotImplementedError(f"store mode={mode}") self.stores.writeline(name, line) if not self.inside_reduction: self.outside_loop_vars.add(value) def reduction(self, name, dtype, src_dtype, reduction_type, index, value): assert self.inside_reduction default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype)) masks = {f"{tree.prefix}mask" for tree in self.range_trees} self.filter_masks(masks) masks = sorted(masks) if self._load_mask: masks.append(self._load_mask) sizes = [":" for _ in self.range_trees] sizes[-1] = "None" reduction_range_prefix = self.range_trees[-1].prefix reduction_sizes = ["None" for _ in self.range_trees] reduction_sizes[-1] = ":" if reduction_type == "any": reduction_type = "max" dim = len(self.range_trees) - 1 result_var = self.cse.newvar() result_var.mask_vars = {var for var in masks if var[0] != "r"} if self.persistent_reduction: cond = " & ".join(masks) masked_value = self.cse.generate( self.compute, f"tl.where({cond}, {value}, {default})" ) result_var = self.cse.generate( self.compute, f"tl.{reduction_type}({masked_value}, {dim})[{', '.join(sizes)}]", ) elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache: self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var accumulator = f"_{result_var}" default_value = f" + {default}" if default != 0 else "" self.body.writeline( f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}" ) accumulator_index = None if reduction_type in {"argmax", "argmin"}: accumulator_index = f"_{result_var}_index" self.body.writeline( f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)" ) updated = value if reduction_type in {"min", "argmin"}: masks.append(f"({accumulator} > {value})") elif reduction_type in {"max", "argmax"}: masks.append(f"({accumulator} < {value})") elif reduction_type == "sum": updated = f"{accumulator} + {value}" else: raise NotImplementedError(f"reduction_type {reduction_type}") cond = " & ".join(masks) if accumulator_index: # argmax or argmin self.compute.writeline( f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})", ) self.compute.writeline( f"{accumulator} = tl.where({cond}, {updated}, {accumulator})" ) if accumulator_index: # argmax, argmin self.suffix.writelines( [ f"{accumulator_index}_reduce = " f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)", f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)" f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce", f"{result_var} = tl.sum(" f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]", ] ) else: self.suffix.writeline( f"{result_var} = tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}]" ) else: var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)] self.suffix.writeline(f"{result_var} = {var_name}") result_var.mask_vars = var_name.mask_vars self.inside_reduction = False index, mask_vars, mask = self.indexing(index) assert "rmask" not in index self.inside_reduction = True self.outside_loop_vars.add(result_var) self.cse.store_cache[name] = result_var if name not in V.graph.removed_buffers: var = self.args.output(name) self.suffix.writeline( DeferredLine(name, f"tl.store({var} + {index}, {result_var}, {mask})") ) def codegen_body(self): """ Concat output code from index_code, loads, compute, stores, suffix into self.body. For pointwise kernels, this is called just once at the end. For reduction kernels, this generates a loop over the reduction axis. """ if not ( self.indexing_code or self.loads or self.stores or self.compute or self.suffix ): return if self.inside_reduction and not self.persistent_reduction: self.body.writeline("for roffset in range(0, rnumel, RBLOCK):") with self.body.indent(): # last range tree is always reduction self.range_trees[-1].codegen_header(self.body) self.body.splice(self.indexing_code) self.body.splice(self.loads) self.body.splice(self.compute) self.body.splice(self.stores) # invalidate any caches that came from inside the reduction loop self.cse.invalidate(self.outside_loop_vars) self.range_trees[-1].cache_clear() else: self.body.splice(self.indexing_code) self.body.splice(self.loads) self.body.splice(self.compute) self.body.splice(self.stores) self.body.splice(self.suffix) self.indexing_code.clear() self.loads.clear() self.compute.clear() self.stores.clear() self.suffix.clear() def codegen_kernel(self, name=None): from triton import next_power_of_2 code = IndentedBuffer() size_hints = [ next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels ] if self.persistent_reduction: assert self.inside_reduction heuristics = "persistent_reduction" elif self.inside_reduction: heuristics = "reduction" else: size_hints.pop() heuristics = "pointwise" if name is None: code.splice( f""" import triton import triton.language as tl from torch._inductor.ir import ReductionHint from torch._inductor.ir import TileHint from torch._inductor.triton_ops.autotune import {heuristics} from torch._inductor.utils import instance_descriptor """ ) argdefs, _, signature = self.args.python_argdefs() # maps actual expression to SizeArg if its in sizevars replacements for i, arg in enumerate(signature): if ( isinstance(arg, SizeArg) and arg.expr in V.graph.sizevars.inv_precomputed_replacements ): signature[i] = SizeArg( arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr] ) mutated_args = set() for mutation in self.mutations: if mutation in self.args.input_buffers: mutated_args.add(self.args.input_buffers[mutation]) if mutation in self.args.inplace_buffers: mutated_args.add(self.args.inplace_buffers[mutation].inner_name) if mutation in self.args.output_buffers: mutated_args.add(self.args.output_buffers[mutation]) mutated_args = sorted(mutated_args) triton_meta = { "signature": dict(enumerate(map(signature_of, signature))), "device": V.graph.scheduler.current_device.index, "constants": {}, "mutated_arg_names": mutated_args, } for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) signature.append(sizearg) triton_meta["signature"][len(argdefs)] = signature_of(sizearg) argdefs.append(f"{tree.prefix}numel") # constexpr version causes issues, see # https://github.com/pytorch/torchdynamo/pull/1362 # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( # tree.numel # ) # argdefs.append(f"{tree.prefix}numel: tl.constexpr") triton_meta["configs"] = [config_of(signature)] for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") if self.inside_reduction: reduction_hint = self.reduction_hint heuristics_line = f""" @{heuristics}( size_hints={size_hints!r}, reduction_hint={reduction_hint}, filename=__file__, meta={triton_meta!r} ) @triton.jit """ else: tile_hint = "" if len(size_hints) == 2: if len(signature) == 4: # input, output and 2 args tile_hint = "tile_hint=TileHint.SQUARE," else: tile_hint = "tile_hint=TileHint.DEFAULT," heuristics_line = f""" @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r}) @triton.jit """ code.splice(heuristics_line) code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):") self.codegen_body() with code.indent(): if not dynamo_config.dynamic_shapes: self.codegen_static_numels(code) for old, new in self.args.aliases(): code.writeline(f"{old} = {new}") code.splice(self.body) if name is not None: return code.getvalue() wrapper = IndentedBuffer() wrapper.writeline("async_compile.triton('''") wrapper.splice(code.getvalue(), strip=True) wrapper.writeline("''')") return wrapper.getvalue() def codegen_template_wrapper(self, src_code): wrapper = IndentedBuffer() wrapper.writeline("async_compile.triton('''") wrapper.splice(src_code, strip=True) wrapper.writeline("''')") return wrapper.getvalue() def codegen_static_numels(self, code): """ We get a small speedup from hard coding numels if they are static. """ for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer): code.writeline( f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}" ) elif not dynamo_config.dynamic_shapes: code.writeline( f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)} # dynamic_shapes=False" ) def indexing_size_str(self, i=None, x=None): sizes = ["None"] * (len(self.range_trees) - int(self.numels[-1] == 1)) if i is not None: sizes[i] = ":" return f"[{', '.join(sizes)}]" def dense_size_str(self): sizes = [] for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: sizes.append(f"{tree.prefix.upper()}BLOCK") elif tree.prefix == "r" and tree.numel != 1: sizes.append("1") return f"[{', '.join(sizes)}]" def call_kernel(self, code, name: str): _, call_args, _ = self.args.python_argdefs() # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar for i in range(len(call_args)): if V.graph.is_unspec_arg(call_args[i]): call_args[i] = call_args[i] + ".item()" grid = [] # TODO(jansel): if there are constants, we shouldn't bother passing them as args for tree in self.range_trees: if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): expr = pexpr(tree.numel) else: expr = f"{name}_{tree.prefix}numel" code.writeline(f"{expr} = {pexpr(tree.numel)}") if tree.prefix != "r" or self.inside_reduction: call_args.append(expr) if tree.prefix != "r": grid.append(expr) call_args = ", ".join(call_args) stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index) code.writeline( f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})" ) def create_cse_var(self, *args, **kwargs): return TritonCSEVariable(*args, **kwargs) class TritonScheduling: def __init__(self, scheduler): self.scheduler = scheduler def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) def can_fuse(self, node1, node2): """ Hook called by Scheduler to determine if the Triton backend can fuse node1 and node2. These nodes might already be FusedSchedulerNodes. """ _, (numel1, rnumel1) = node1.group _, (numel2, rnumel2) = node2.group if node1.is_reduction() and node2.is_reduction(): return numel1 == numel2 and rnumel1 == rnumel2 if not node1.is_reduction() and not node2.is_reduction(): if not (numel1 == numel2 and rnumel1 == rnumel2): return False if node1.is_template(): return True # skip checks for compatible tiling # check for a bad combined tiling tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) tiling3 = self.select_tiling( node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 ) if config.triton.tiling_prevents_pointwise_fusion: if len(tiling1) > 2: if len(tiling2) > 2: return tiling1 == tiling2 == tiling3 else: return tiling1 == tiling3 elif len(tiling2) > 2: return tiling2 == tiling3 return True if not node1.is_reduction() and node2.is_reduction(): assert rnumel1 == 1 and rnumel2 != 1 if numel1 == numel2 * rnumel2: if not all( TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges()) for n in node1.get_nodes() ): return False if ( config.triton.tiling_prevents_reduction_fusion and not node1.is_template() ): return self.select_tiling(node1.get_nodes(), numel1) in ( (numel1, 1), (numel2, rnumel2, 1), ) return True return numel1 == numel2 assert node1.is_reduction() and not node2.is_reduction() # swap args to hit the case above return self.can_fuse_horizontal(node2, node1) can_fuse_vertical = can_fuse can_fuse_horizontal = can_fuse def codegen_nodes(self, nodes): """ Given a set of pre-fused nodes, generate a Triton kernel. """ _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group node_schedule = [] current_loop_writes = set() is_current_reductions = set() done = set() def fits_in_main_body(n): _, (node_numel, node_rnumel) = n.group return (node_numel == numel and node_rnumel == rnumel) or ( node_numel == numel * rnumel and node_rnumel == 1 ) def fits_outside_reduction(n): _, (node_numel, node_rnumel) = n.group return node_numel == numel and node_rnumel == 1 and rnumel != 1 @contextlib.contextmanager def end_current_reduction_loop(): if current_loop_writes: # flush out any other runnable nodes to reduce number of loops for other_node in nodes[index + 1 :]: if ( node not in done and fits_in_main_body(other_node) and not ( current_loop_writes & other_node.recursive_predecessors ) ): done.add(node) current_loop_writes.add(node.get_name()) is_current_reductions.add(node.is_reduction()) node_schedule.append(node) if node_schedule and node_schedule[-1] is EnableReduction: node_schedule.pop() else: node_schedule.append(DisableReduction) yield node_schedule.append(EnableReduction) current_loop_writes.clear() is_current_reductions.clear() for index, node in enumerate(nodes): if node in done: continue done.add(node) def requires_closing_previous_reduction(node, node_schedule): if rnumel == 1: return False if not current_loop_writes & node.recursive_predecessors: return False assert node_schedule and not isinstance( node_schedule[-1], (EnableReduction, DisableReduction) ) return True in is_current_reductions if fits_in_main_body(node): if requires_closing_previous_reduction(node, node_schedule): with end_current_reduction_loop(): pass # need to start a new reduction loop current_loop_writes.add(node.get_name()) is_current_reductions.add(node.is_reduction()) node_schedule.append(node) elif fits_outside_reduction(node): with end_current_reduction_loop(): node_schedule.append(node) else: raise NotImplementedError( f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" ) if dynamo_config.output_code: log.info("schedule: %s", node_schedule) return self.codegen_node_schedule(node_schedule, numel, rnumel) @staticmethod def reduction_hint(node): assert node.is_reduction() if all( dep.is_contiguous() for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) ): return ReductionHint.INNER else: return node.node.data.reduction_hint def codegen_node_schedule(self, node_schedule, numel, reduction_numel): tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) reductions = list( filter( lambda n: n not in (EnableReduction, DisableReduction) and n.is_reduction(), node_schedule, ) ) if len(reductions) > 0: hints = [self.reduction_hint(n) for n in reductions] if hints.count(hints[0]) == len(hints): reduction_hint_val = hints[0] else: reduction_hint_val = ReductionHint.DEFAULT else: reduction_hint_val = ReductionHint.DEFAULT mutations = set() for node in node_schedule: if hasattr(node, "get_mutations"): mutations.update(node.get_mutations()) with TritonKernel( *tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations ) as kernel: stack = contextlib.ExitStack() for node in node_schedule: if node not in (EnableReduction, DisableReduction): node.mark_run() for node in node_schedule: if node is DisableReduction: stack.enter_context(kernel.disable_reduction()) elif node is EnableReduction: stack.close() else: # TODO - mostly works but needs a couple fixes if not dynamo_config.dynamic_shapes: # TODO - use split ranges ? indexing_dtype_strength_reduction(node._body) index_vars = kernel.split_and_set_ranges(node.get_ranges()) node.codegen(index_vars) src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, node_schedule) kernel.call_kernel(V.graph.wrapper_code, kernel_name) self.scheduler.free_buffers() def define_kernel(self, src_code, node_schedule): wrapper = V.graph.wrapper_code if src_code in wrapper.kernels: kernel_name = wrapper.kernels[src_code] else: fused_name = ( get_fused_kernel_name(node_schedule) if config.triton.descriptive_kernel_names else "" ) kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()]) wrapper.kernels[src_code] = kernel_name subs_name = kernel_name if config.triton.ordered_kernel_names else "triton_" src_code = src_code.replace("KERNEL_NAME", subs_name) # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. src_code = src_code.replace("#pragma CMT", "#") wrapper.define_kernel(kernel_name, src_code) return kernel_name def codegen_template(self, template_node, epilogue_nodes): """ Codegen a triton template """ _, (numel, rnumel) = template_node.group assert rnumel == 1 kernel, render = template_node.node.make_kernel_render(template_node.node) with kernel: for node in [template_node, *epilogue_nodes]: node.mark_run() render() # warmup run to get the args right for node in epilogue_nodes: node.codegen(kernel.split_and_set_ranges(node.get_ranges())) src_code = kernel.codegen_template_wrapper(render()) kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes]) kernel.call_kernel(V.graph.wrapper_code, kernel_name) self.scheduler.free_buffers() def codegen_sync(self): V.graph.wrapper_code.writeline("torch.cuda.synchronize()") @staticmethod @functools.lru_cache(32) def candidate_tilings(node): ranges, reduction_ranges = node.get_ranges() if len(ranges) <= 1: return () rw = node.pointwise_read_writes() assert len(rw.range_vars) == len(ranges) deps = [ dep for dep in itertools.chain(rw.reads, rw.writes) if dep.name not in V.graph.removed_buffers ] write_names = {dep.name for dep in rw.writes} tilings = [] for dep in deps: strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) assert len(strides) == len(ranges) try: split = strides.index(1) + 1 if split == len(ranges): continue if all(s == 0 for s in strides[split:]): # if this is a broadcasted tensor and all dimensions after split are broadcast, # this is not a real split continue except ValueError: continue tiled_groups = ( V.graph.sizevars.simplify(sympy_product(ranges[:split])), V.graph.sizevars.simplify(sympy_product(ranges[split:])), ) # score by number of elements score = V.graph.sizevars.size_hint( sympy_product( size for size, stride in zip(ranges, strides) if stride != 0 ) ) if dep.name in write_names: # ngimel said contiguous writes is more important than reads score *= 2 if CandidateTiling.is_good_size(tiled_groups[0]): score *= 2 if CandidateTiling.is_good_size(tiled_groups[1]): score *= 2 if ( V.graph.sizevars.size_hint( score - sympy_product(itertools.chain(ranges, reduction_ranges)) ) >= 0 ): tilings.append(CandidateTiling(tiled_groups, score, dep.name)) return tilings @classmethod def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): """ Heuristics to decide how to tile kernels. Currently, we tile based on stride-1 dimensions. Returns: `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` """ if reduction_numel != 1 or config.triton.max_tiles <= 1: # TODO(jansel): should we tile reductions? return (numel, reduction_numel) seen_names = set() candidate_tiles = collections.Counter() for node in EnableReduction.filter(node_schedule): for tiling in cls.candidate_tilings(node): if tiling.name in seen_names: continue seen_names.add(tiling.name) candidate_tiles[tiling.tiling] += tiling.score ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] if config.triton.max_tiles >= 3: # Add one 3D tiling choice for i in range(1, len(ranked_tilings)): a0, a1 = ranked_tilings[0] b0, b1 = ranked_tilings[i] if V.graph.sizevars.size_hint(a1 - b1) == 0: continue if V.graph.sizevars.size_hint(a1 - b1) < 0: # swap so a0 is bigger a0, a1 = ranked_tilings[i] b0, b1 = ranked_tilings[0] assert V.graph.sizevars.size_hint(a1 - b1) > 0 if V.graph.sizevars.maybe_guard_multiple_of(a1, b1): tiling = (a0, ir.FloorDiv(a1, b1), b1) ranked_tilings = [tiling] + ranked_tilings break # only 1 choice for now for tiled_groups in ranked_tilings: new_groups = (*tiled_groups, reduction_numel) if all( TritonKernel.is_compatible(new_groups, node.get_ranges()) for node in node_schedule if isinstance(node, scheduler.SchedulerNode) ): return new_groups return (numel, reduction_numel) def flush(self): pass @dataclasses.dataclass class CandidateTiling: tiling: List[sympy.Expr] score: int # higher is better name: str = None @staticmethod def is_good_size(s): """Somewhat arbitrary heuristic used to boost scores for some sizes""" s = V.graph.sizevars.size_hint(s) return s >= 32 and (s % 32 == 0) class DisableReduction: """ Marker to invoke `kernel.disable_reduction()`. This closes a reduction loop and allows for pointwise ops to occur on the output of a reduction. """ class EnableReduction: """ Marker to end a DisableReduction block. """ @staticmethod def filter(node_schedule): """ Get the nodes from node_schedule skipping those in a DisableReduction block. """ disabled = False for node in node_schedule: if node in (EnableReduction, DisableReduction): # Don't tile stuff outside the main reduction loop disabled = node is DisableReduction elif disabled: pass else: yield node class CantSplit(Exception): pass