12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760 |
- import collections
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import math
- import operator
- from typing import Dict, List, Set
- import sympy
- import torch
- from ..._dynamo import config as dynamo_config
- from .. import config, ir, scheduler
- from ..ir import ReductionHint
- from ..optimize_indexing import indexing_dtype_strength_reduction
- from ..utils import (
- get_fused_kernel_name,
- instance_descriptor,
- sympy_product,
- sympy_subs,
- sympy_symbol,
- )
- from ..virtualized import ops, V
- from .common import (
- CSEVariable,
- DeferredLine,
- free_symbol_startswith,
- IndentedBuffer,
- index_prevent_reordering,
- Kernel,
- OpOverrides,
- PythonPrinter,
- SizeArg,
- TensorArg,
- )
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def signature_of(arg):
    """
    Return the Triton signature type string for one kernel argument.

    * TensorArg: the Triton type of its dtype as reported by
      ``JITFunction._type_of``. If the graph treats the buffer as an
      unspecialized argument (a 0-d tensor unwrapped to a scalar), the leading
      ``*`` pointer marker is stripped, and "fp16"/"bf16" are reported as
      "fp32".
    * SizeArg: ``JITFunction._key_of`` applied to the size hint of its
      expression.

    Raises:
        NotImplementedError: for any other argument kind.
    """
    from triton.runtime.jit import JITFunction

    if isinstance(arg, TensorArg):
        pointer_type = JITFunction._type_of(arg.dtype)
        if not V.graph.is_unspec_arg(arg.buffer):
            return pointer_type
        # had unwrapped 0d tensor as scalar: drop the pointer marker
        scalar_type = pointer_type.lstrip("*")
        return "fp32" if scalar_type in ("fp16", "bf16") else scalar_type

    if isinstance(arg, SizeArg):
        return JITFunction._key_of(V.graph.sizevars.size_hint(arg.expr))

    raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
def config_of(args):
    """
    Build the Triton ``instance_descriptor`` whose first field lists the
    positions of the kernel arguments that pass the alignment check.

    A TensorArg passes unless its buffer is in ``V.graph.unaligned_buffers``.
    A SizeArg passes when the size-var system reports (possibly by adding a
    guard) that its expression is a multiple of ``ALIGNMENT``.

    Raises:
        NotImplementedError: for any other argument kind.
    """
    from ..compile_fx import ALIGNMENT

    def _passes_alignment(arg):
        if isinstance(arg, TensorArg):
            return arg.buffer not in V.graph.unaligned_buffers
        if isinstance(arg, SizeArg):
            return V.graph.sizevars.maybe_guard_multiple_of(arg.expr, ALIGNMENT)
        raise NotImplementedError(f"unhandled {type(arg)}: {arg}")

    # positions are checked in argument order, so any guards are added in order
    aligned_positions = tuple(
        position for position, arg in enumerate(args) if _passes_alignment(arg)
    )
    return instance_descriptor(aligned_positions, ())
class TritonPrinter(PythonPrinter):
    """PythonPrinter that renders sympy ``floor`` through Triton's libdevice."""

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        operand = expr.args[0]
        rendered = self.paren(self._print(operand))
        return f"tl.libdevice.floor({rendered})"
# Callables that render a sympy expression as source text:
# `texpr` uses TritonPrinter (Triton kernel syntax), `pexpr` the plain
# PythonPrinter.
texpr = TritonPrinter().doprint
pexpr = PythonPrinter().doprint
def triton_compute_type(dtype):
    """
    Return the Triton type (as source text such as ``"tl.float32"``) used for
    arithmetic on values of torch ``dtype``.

    The name is taken from the last dotted component of ``str(dtype)``.
    ``bool`` maps to ``tl.int1``. ``float16`` and ``bfloat16`` map to
    ``tl.float32``, because half-precision math is done in float32 inside the
    kernel.
    """
    short_name = str(dtype).split(".")[-1]
    renames = {
        "bool": "int1",
        # float16 math is done in float32 inside the kernel
        "float16": "float32",
        "bfloat16": "float32",
    }
    return "tl." + renames.get(short_name, short_name)
def triton_constant(value):
    """
    Render a Python scalar as Triton source text.

    Positive/negative infinity and NaN have no literal spelling, so they are
    emitted as ``float("inf")``, ``float("-inf")`` and ``float("nan")``.
    Every other value is emitted with ``repr``.
    """
    for special, spelling in (
        (float("inf"), 'float("inf")'),
        (float("-inf"), 'float("-inf")'),
    ):
        if value == special:
            return spelling
    if math.isnan(value):
        return 'float("nan")'
    return repr(value)
class TritonCSEVariable(CSEVariable):
    """
    CSE variable that also records the mask variables (such as ``"xmask"``)
    that must be applied when it is later used for indirect indexing.
    """

    def __init__(self, name):
        super().__init__(name)
        # masks needed whenever this variable feeds an indirect index
        self.mask_vars: Set[str] = set()

    def update_on_args(self, name, args, kwargs):
        """
        Inherit the mask requirements of every TritonCSEVariable among the
        positional ``args`` of the op ``name`` that produced this variable.

        A ``where`` result is assumed to always be a valid index, so none of
        its inputs' masks are propagated into the resulting load mask.
        ``kwargs`` are not inspected.
        """
        if name == "where":
            return
        inherited = [a.mask_vars for a in args if isinstance(a, TritonCSEVariable)]
        self.mask_vars.update(*inherited)
class TritonOverrides(OpOverrides):
    """
    Map element-wise ops to Triton.

    Every method is a static method. It receives operands whose string form is
    Triton source (CSE variable names or literals) and returns a string
    holding the Triton expression for the op. ``libdevice_*`` variants spell
    the same op through ``tl.libdevice`` instead of the ``tl`` builtin.
    """

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        """Cast ``x`` to ``dtype``."""
        if dtype == torch.bool:
            return f"({x} != 0)"
        if dtype == torch.uint8:
            # route through int8 to work around llvm's uint conversion, which
            # produces 0 for negative values
            return f"{x}.to(tl.int8).to(tl.uint8)"
        return f"{x}.to({triton_compute_type(dtype)})"

    @staticmethod
    def constant(value, dtype):
        """Render ``value`` after coercing it to the Python type of ``dtype``."""
        python_type = torch._prims_common.dtype_to_type(dtype)
        return triton_constant(python_type(value))

    @staticmethod
    def abs(x):
        return f"tl.abs({x})"

    @staticmethod
    def libdevice_abs(x):
        return f"tl.libdevice.abs({x})"

    @staticmethod
    def exp(x):
        return f"tl.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"tl.libdevice.exp({x})"

    @staticmethod
    def exp2(x):
        return f"tl.libdevice.exp2({x})"

    @staticmethod
    def expm1(x):
        return f"tl.libdevice.expm1({x})"

    @staticmethod
    def sqrt(x):
        return f"tl.sqrt({x})"

    @staticmethod
    def libdevice_sqrt(x):
        return f"tl.libdevice.sqrt({x})"

    @staticmethod
    def relu(x):
        return ops.maximum("0", x)

    @staticmethod
    def minimum(a, b):
        # `a != a` is true only for NaN, so a NaN in `a` is propagated; a NaN
        # in `b` fails `a < b` and is selected by the inner where
        return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))"

    @staticmethod
    def maximum(a, b):
        # NaN-propagating, mirroring minimum
        return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))"

    @staticmethod
    def where(a, b, c):
        return f"tl.where({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"tl.cos({x})"

    @staticmethod
    def libdevice_cos(x):
        return f"tl.libdevice.cos({x})"

    @staticmethod
    def sin(x):
        return f"tl.sin({x})"

    @staticmethod
    def libdevice_sin(x):
        return f"tl.libdevice.sin({x})"

    @staticmethod
    def index_expr(expr, dtype):
        """Render ``expr`` through the current kernel's indexing (index string only)."""
        index_str, _mask_vars, _mask_str = V.kernel.indexing(expr)
        return index_str

    @staticmethod
    def masked(mask, body, other):
        """Evaluate ``body()`` with loads masked by ``mask``; masked-off lanes get ``other``."""
        with V.kernel.mask_loads(mask) as combined_mask:
            value = body()
        return ops.where(combined_mask, value, triton_constant(other))

    @staticmethod
    def lgamma(x):
        return f"tl.libdevice.lgamma({x})"

    @staticmethod
    def erf(x):
        return f"tl.libdevice.erf({x})"

    @staticmethod
    def cosh(x):
        return f"tl.libdevice.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"tl.libdevice.sinh({x})"

    @staticmethod
    def acos(x):
        return f"tl.libdevice.acos({x})"

    @staticmethod
    def acosh(x):
        return f"tl.libdevice.acosh({x})"

    @staticmethod
    def asin(x):
        return f"tl.libdevice.asin({x})"

    @staticmethod
    def asinh(x):
        return f"tl.libdevice.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"tl.libdevice.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"tl.libdevice.atan({x})"

    @staticmethod
    def atanh(x):
        return f"tl.libdevice.atanh({x})"

    @staticmethod
    def copysign(x, y):
        return f"tl.libdevice.copysign({x}, {y})"

    @staticmethod
    def erfc(x):
        return f"tl.libdevice.erfc({x})"

    @staticmethod
    def hypot(x, y):
        return f"tl.libdevice.hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"tl.libdevice.log10({x})"

    @staticmethod
    def nextafter(x, y):
        return f"tl.libdevice.nextafter({x}, {y})"

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand op
        return f"tl.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
        return f"tl.randn({seed}, {offset})"

    @staticmethod
    def rsqrt(x):
        return f"tl.libdevice.rsqrt({x})"

    @staticmethod
    def log1p(x):
        return f"tl.libdevice.log1p({x})"

    @staticmethod
    def tan(x):
        return f"tl.libdevice.tan({x})"

    @staticmethod
    def tanh(x):
        return f"tl.libdevice.tanh({x})"

    @staticmethod
    def sigmoid(x):
        return f"tl.sigmoid({x})"

    @staticmethod
    def libdevice_sigmoid(x):
        return f"1/(1 + tl.libdevice.exp(-({x})))"

    @staticmethod
    def signbit(x):
        # NOTE: this is wrong for the value -0.0 in floating point
        return f"tl.libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"

    @staticmethod
    def fmod(a, b):
        return f"tl.libdevice.fmod({a}, {b})"

    @staticmethod
    def pow(a, b):
        return f"tl.libdevice.pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"tl.log({x})"

    @staticmethod
    def libdevice_log(x):
        return f"tl.libdevice.log({x})"

    @staticmethod
    def isinf(x):
        return f"tl.libdevice.isinf({x})"

    @staticmethod
    def isnan(x):
        return f"tl.libdevice.isnan({x})"

    @staticmethod
    def round(x):
        return f"tl.libdevice.nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"tl.libdevice.floor({x})"

    @staticmethod
    def floordiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Similar to div_floor_kernel_cuda in pytorch core.
        # `//` in triton truncates instead of flooring, so when the operands
        # have opposite signs and the division is inexact, subtract one.
        quotient = f"{a} // {b}"
        remainder = f"{a} % {b}"
        signs_differ = f"({a} < 0) != ({b} < 0)"
        adjusted = f"tl.where({remainder} != 0, {quotient} - 1, {quotient})"
        return f"tl.where({signs_differ}, {adjusted}, {quotient})"

    @staticmethod
    def trunc(x):
        return f"tl.libdevice.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # `//` in triton already behaves as truncdiv, not floordiv.
        return f"{a} // {b}"

    @staticmethod
    def ceil(x):
        return f"tl.libdevice.ceil({x})"
@dataclasses.dataclass
class IterationRanges:
    """
    Each range tree represents multiple sets of iteration indexing
    in a single tiled dimension in the output kernel.

    If you have two loop ranges, one (4, 3, 2) and another (4, 6),
    then the range tree will be:
            4 (i0)
        3 (i1)  6 (i3)
        2 (i2)
    Here i0 is shared between both loops, but then they split into
    different indexing vars. All loop ranges must iterate over
    the same number of elements.
    """

    # NOTE(review): this class declares no annotated fields, so the dataclass
    # decorator adds no fields and leaves the explicit __init__ below alone.
    # It does generate an __eq__ under which any two instances of the same
    # class compare equal, and it sets __hash__ to None for classes that do
    # not define their own (e.g. IterationRangesRoot). Confirm this is
    # intended before relying on equality or hashing of range trees.

    def __init__(
        self,
        name: str,
        var_list: List[sympy.Symbol],
        var_ranges: Dict[sympy.Symbol, sympy.Expr],
        numel: sympy.Expr,
        prefix: str,
        *,
        kernel: "Kernel",
        divisor=sympy.Integer(1),
        length=sympy.Integer(1),
    ):
        super().__init__()
        self.kernel = kernel
        self.name = name
        self.prefix = prefix
        self.numel = numel
        # shared with the root: every entry of a tree appends into these
        self.var_list = var_list
        self.var_ranges = var_ranges
        self.divisor = divisor
        self.length = length

    def is_loop(self):
        """True for the reduction ("r") tree when it is not a persistent reduction."""
        if self.prefix != "r":
            return False
        return not self.kernel.persistent_reduction
class IterationRangesRoot(IterationRanges):
    """
    Root of the range tree for one tiled dimension of a kernel.

    The prefix is the first letter of the tree's name: "x"/"y"/"z" for
    pointwise dimensions and "r" for the reduction dimension. The root owns
    every IterationRangesEntry created for its dimension and emits the header
    code that defines ``self.name`` (e.g. ``xindex``) and ``<prefix>mask``.
    """

    def __init__(
        self,
        name: str,
        numel: sympy.Expr,
        prefix: str,
        index: int,
        kernel: "Kernel",
        pid_cache=None,
    ):
        super().__init__(
            name=name,
            var_list=[],
            var_ranges={},
            numel=numel,
            prefix=prefix,
            kernel=kernel,
        )
        # position of this tree among the kernel's range trees; also the axis
        # passed to tl.program_id in codegen_header
        self.index = index
        # every entry of this tree in one flat mapping, keyed by its index expr
        self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
        # lets the triton mm template re-order program IDs, e.g.
        # pid_cache["tl.program_id(0)"] = pid_m
        self.pid_cache: Dict[str, str] = {} if pid_cache is None else pid_cache

    def cache_clear(self):
        """Clear the memoized codegen of every entry in this tree."""
        for entry in self.nodes.values():
            entry.cache_clear()

    def lookup(self, divisor, length):
        """
        Return the IterationRangesEntry for ``<prefix>index`` split by
        ``divisor`` and ``length``, creating and registering it if needed.

        The entry's expression is ``FloorDiv(<prefix>index, divisor)`` when
        ``divisor * length`` equals this tree's numel (possibly via a guard),
        and ``ModularIndexing(<prefix>index, divisor, length)`` otherwise.
        """
        index_symbol = sympy_symbol(f"{self.prefix}index")
        if V.graph.sizevars.maybe_guard_equals(divisor * length, self.numel):
            expr = ir.FloorDiv(index_symbol, divisor)
        else:
            expr = ir.ModularIndexing(index_symbol, divisor, length)

        entry = self.nodes.get(expr)
        if entry is None:
            entry = IterationRangesEntry(
                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
                divisor,
                length,
                expr,
                self,
            )
            symbol = entry.symbol()
            V.kernel.range_tree_nodes[symbol] = entry
            self.var_list.append(symbol)
            self.var_ranges[symbol] = length
            self.nodes[expr] = entry
        return entry

    def construct_entries(self, lengths: List[sympy.Expr]):
        """
        Return one entry per element of ``lengths``, outermost first.

        The last length is innermost (divisor 1); each earlier length's
        divisor is the product of all lengths after it.
        """
        running_divisor = sympy.Integer(1)
        entries = []
        for length in reversed(lengths):
            entries.append(self.lookup(running_divisor, length))
            running_divisor = running_divisor * length
        entries.reverse()
        return entries

    def construct(self, lengths: List[sympy.Expr]):
        """Like construct_entries, but return the entries' symbols."""
        return [entry.symbol() for entry in self.construct_entries(lengths)]

    def vars_and_sizes(self, index: sympy.Expr):
        """
        Figure out which variables from this tree are used in ``index``.

        Returns ``(index_vars, sizes)`` ordered outermost first. Gaps between
        the used variables, and any remainder up to this tree's numel, are
        filled in with additional (otherwise unused) index vars.
        """
        candidates = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
        used = [n for n in candidates if n and n.prefix == self.prefix]
        used.sort(key=lambda n: V.graph.sizevars.size_hint(n.divisor))

        divisor = sympy.Integer(1)
        index_vars = []
        sizes = []

        def take(entry):
            nonlocal divisor
            index_vars.append(entry.symbol())
            sizes.append(entry.length)
            divisor = divisor * entry.length

        for entry in used:
            if not V.graph.sizevars.maybe_guard_equals(entry.divisor, divisor):
                # fill in unused index var
                take(self.lookup(divisor, ir.FloorDiv(entry.divisor, divisor)))
                divisor = entry.divisor
            take(entry)
        if not V.graph.sizevars.maybe_guard_equals(self.numel, divisor):
            # fill in unused index var
            take(self.lookup(divisor, ir.FloorDiv(self.numel, divisor)))

        return index_vars[::-1], sizes[::-1]

    def ranges_code(self):
        """Triton source ``tl.arange(0, <PREFIX>BLOCK)`` plus the kernel's indexing-size suffix."""
        suffix = self.kernel.indexing_size_str(self.index, self.prefix)
        block = f"{self.prefix.upper()}BLOCK"
        return f"tl.arange(0, {block}){suffix}"

    def pid_cache_lookup(self, key):
        """Return the cached replacement for ``key``, or ``key`` itself."""
        return self.pid_cache.get(key, key)

    def codegen_header(self, code):
        """Write the lines defining this tree's index variable and ``<prefix>mask`` into ``code``."""
        prefix = self.prefix
        if self.is_loop():
            # reduction loop: `<prefix>offset` presumably comes from the
            # enclosing loop; `rbase` is written by the kernel
            code.writeline(f"{self.name} = {prefix}offset + {prefix}base")
        elif prefix == "r" and self.kernel.persistent_reduction:
            # no need to "roffset = "
            code.writeline(f"{self.name} = {self.ranges_code()}")
        else:
            pid = self.pid_cache_lookup(f"tl.program_id({self.index})")
            offset_line = f"{prefix}offset = {pid} * {prefix.upper()}BLOCK"
            index_line = f"{self.name} = {prefix}offset + {self.ranges_code()}"
            code.writelines([offset_line, index_line])
        code.writeline(f"{prefix}mask = {self.name} < {prefix}numel")
class IterationRangesEntry(IterationRanges):
    """
    One index variable of a range tree.

    Its value is ``expr``, normally a FloorDiv or ModularIndexing of the
    tree's ``<prefix>index`` (see precomputed_args). It shares ``var_list`` and
    ``var_ranges`` with its parent root, and its defining line is emitted
    lazily by the memoized ``codegen()``.
    """

    def __init__(
        self,
        name: str,
        divisor: sympy.Expr,
        length: sympy.Expr,
        expr: sympy.Expr,
        parent: IterationRanges,
    ):
        super().__init__(
            name=name,
            numel=parent.numel / length,
            var_list=parent.var_list,
            var_ranges=parent.var_ranges,
            prefix=parent.prefix,
            divisor=divisor,
            length=length,
            kernel=parent.kernel,
        )
        self.parent = parent
        # memoized so the defining line is written once until cache_clear()
        self.codegen = functools.lru_cache(None)(self._codegen)
        self.expr = expr

    def set_name(self, name):
        """Make this entry refer to ``name``; codegen() then just returns it and writes nothing."""
        self.codegen = lambda: name
        self.codegen.cache_clear = lambda: None
        self.name = name

    def cache_clear(self):
        """Forget the memoized codegen so the defining line is written again on next use."""
        self.codegen.cache_clear()

    def writeline(self, line):
        """Write ``line`` to the kernel's indexing code for loop trees, else to its body."""
        if self.is_loop():
            V.kernel.indexing_code.writeline(line)
        else:
            # lift non-reduction stores outside loop
            V.kernel.body.writeline(line)

    def _codegen(self):
        self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr)))
        return self.name

    def precomputed_args(self):
        """
        For dynamic shapes, return the parts of the indexing expression that
        have to be precomputed: the non-trivial (not a bare Integer/Symbol)
        divisor/length arguments of ``expr`` whose free symbols all start with
        "s".
        """
        if isinstance(self.expr, sympy.Symbol):
            return []
        assert isinstance(self.expr, (ir.FloorDiv, ir.ModularIndexing)), type(self.expr)
        precomputed = []
        for arg in self.expr.args[1:]:
            if isinstance(arg, (sympy.Integer, sympy.Symbol)):
                continue
            symbols = arg.free_symbols
            if symbols and all(s.name.startswith("s") for s in symbols):
                precomputed.append(arg)
        return precomputed

    def symbol(self):
        return sympy_symbol(self.name)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        # Entries are identified by name. Comparing against any other type
        # used to raise AttributeError on `other.name`; defer to Python's
        # default handling instead.
        if not isinstance(other, IterationRangesEntry):
            return NotImplemented
        return self.name == other.name
- class TritonKernel(Kernel):
- overrides = TritonOverrides
- sexpr = pexpr
def __init__(
    self,
    *groups,
    mutations=None,
    pid_cache=None,
    reduction_hint=ReductionHint.DEFAULT,
):
    """
    Set up a Triton kernel over the iteration sizes in ``groups``.

    The leading groups become the x/y/z range trees and the last becomes the
    reduction ("r") tree. A last group of 1 means the kernel starts outside
    any reduction.
    """
    super().__init__()
    self.numels = [V.graph.sizevars.simplify(s) for s in groups]
    self.mutations = mutations
    self.reduction_hint = reduction_hint

    # range trees plus a flat symbol -> IterationRangesEntry index over them
    self.range_trees = []
    self.range_tree_nodes = {}
    self.iter_vars_count = itertools.count()

    self.inside_reduction = self.numels[-1] != 1
    self._load_mask = None

    # output buffers
    self.body = IndentedBuffer()
    self.indexing_code = IndentedBuffer()
    self.suffix = IndentedBuffer()
    self.outside_loop_vars = set()

    # must run after numels / inside_reduction / reduction_hint are set
    self.persistent_reduction = self.should_use_persistent_reduction()
    self.initialize_range_tree({} if pid_cache is None else pid_cache)

    # define this in a closure to make cache local to object
    @functools.lru_cache(None)
    def simplify_indexing(index: sympy.Expr):
        index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
        for tree in self.range_trees:
            index = self.combine_contiguous_dims(index, tree)
        return index

    self.simplify_indexing = simplify_indexing
def should_use_persistent_reduction(self):
    """
    Heuristic to set self.persistent_reduction and add guards if needed.

    Returns False unless the kernel is inside a reduction and
    ``config.triton.persistent_reductions`` is enabled. It also returns False
    when the size hint of the reduction numel exceeds 1024 (INNER hint) or 64
    (any other hint). Otherwise it guards the reduction numel to be at most the
    next power of two of its hint and returns True.
    """
    if not self.inside_reduction or not config.triton.persistent_reductions:
        return False

    limit = 1024 if self.reduction_hint == ReductionHint.INNER else 64
    hint = V.graph.sizevars.size_hint(self.numels[-1])
    if hint <= limit:
        from triton import next_power_of_2

        # will need to recompile if we cross a larger power of 2 boundary
        V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
        return True
    return False
def initialize_range_tree(self, pid_cache):
    """
    Create one IterationRangesRoot per entry of self.numels, named
    xindex/yindex/zindex for the leading entries and rindex for the last.
    Then write the header code of every tree that is not a reduction loop
    into self.body.
    """
    names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1]
    names.append("rindex")
    for position, numel in enumerate(self.numels):
        tree_name = names[position]
        self.range_trees.append(
            IterationRangesRoot(tree_name, numel, tree_name[0], position, self, pid_cache)
        )

    for tree in self.range_trees:
        if tree.is_loop():
            # reduction indexing goes inside a loop
            continue
        tree.codegen_header(self.body)

    if self.inside_reduction and self.range_trees[-1].is_loop():
        # workaround for this issue:
        # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
        self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}")
def disable_reduction(self):
    """
    Return a context manager inside which ``self.inside_reduction`` is False.

    Without a reduction dimension (last numel == 1) it only yields. For a
    looped (non-persistent) reduction it calls codegen_body() on entry, which
    flushes pending buffers and writes out the reduction loop, and again on
    exit before the next loop is opened.
    """

    @contextlib.contextmanager
    def _outside_reduction():
        if self.numels[-1] == 1:
            assert not self.inside_reduction
            yield
            return

        if not self.persistent_reduction:
            # flush pending buffers and write out a reduction loop
            self.codegen_body()
        self.inside_reduction = False
        yield
        if not self.persistent_reduction:
            # flush out any code before opening the next loop
            self.codegen_body()
        self.inside_reduction = True

    return _outside_reduction()
def set_ranges(self, *lengths):
    """
    Build index variables for each range tree from the matching element of
    ``lengths`` (one per tree), returning one construct() result per tree.
    """
    assert len(lengths) == len(self.range_trees)
    per_tree = []
    for tree, tree_lengths in zip(self.range_trees, lengths):
        per_tree.append(tree.construct(tree_lengths))
    return per_tree
@staticmethod
def _split_iteration_ranges(
    groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]
):
    """
    Re-split loop ``lengths`` so they line up with the kernel ``groups``.

    Returns ``(new_ranges, return_getters_groups)``. ``new_ranges[i]`` lists
    the sizes assigned to group i. ``return_getters_groups`` mirrors
    ``lengths``: for each original size it holds a callable mapping the flat
    list of new index variables to the index expression for that size.
    Raises CantSplit when a size cannot be mapped onto the groups.
    """
    sizevars = V.graph.sizevars
    new_ranges = [[] for _ in groups]
    remaining = [sizevars.simplify(g) for g in groups]
    var_count = itertools.count()

    def add_range(i, expr):
        # claim `expr` elements of group i; return the new var's flat position
        expr = sizevars.simplify(expr)
        if not sizevars.maybe_guard_multiple_of(remaining[i], expr):
            raise CantSplit()
        # guard on the last item out
        sizevars.maybe_guard_equals(remaining[i], expr)
        remaining[i] = ir.FloorDiv(remaining[i], expr)
        new_ranges[i].append(expr)
        return next(var_count)

    def make_combined(size, idx1, idx2):
        # getter for a size split across two groups: size * outer + inner
        def getter(flat_vars):
            return size * flat_vars[idx1] + flat_vars[idx2]

        return getter

    return_getters_groups = []
    current_group = 0
    for length_group in lengths:
        return_getters = []
        for size in length_group:
            if sizevars.maybe_guard_equals(size, 1):
                return_getters.append(lambda _: sympy.Integer(0))
                continue

            # scroll to next group with remaining elements
            while (
                current_group < len(remaining)
                and sizevars.size_hint(remaining[current_group]) == 1
            ):
                current_group += 1

            if sizevars.size_hint(size) <= sizevars.size_hint(remaining[current_group]):
                return_getters.append(
                    operator.itemgetter(add_range(current_group, size))
                )
                continue

            # need to break size in two
            if not sizevars.maybe_guard_multiple_of(size, remaining[current_group]):
                raise CantSplit()
            size1 = remaining[current_group]
            size2 = ir.FloorDiv(size, remaining[current_group])
            return_getters.append(
                make_combined(
                    size2,
                    add_range(current_group, size1),
                    add_range(current_group + 1, size2),
                )
            )
        return_getters_groups.append(return_getters)

    assert all(
        sizevars.size_hint(s) == 1 for s in remaining
    ), f"failed to set ranges {remaining} {lengths}"
    return new_ranges, return_getters_groups
@classmethod
def is_compatible(cls, groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]):
    """Return whether ``lengths`` can be re-split onto ``groups`` (no CantSplit raised)."""
    try:
        cls._split_iteration_ranges(groups, lengths)
    except CantSplit:
        return False
    return True
def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
    """
    We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).

    To do this we need to split up the iteration space of i0 into something like:
        for i1 in s0:
          for i2 in s1:
            i0 = i1*s1 + i2
            ....

    This function matches and resplits lengths to the groups of
    this kernel to enable tiled + non-tiled fusions.
    """
    groups = [tree.numel for tree in self.range_trees]
    if not self.inside_reduction:
        groups[-1] = sympy.Integer(1)

    already_matching = len(lengths) == len(self.range_trees) and all(
        V.graph.sizevars.simplify(sympy_product(x) - g) == 0
        for x, g in zip(lengths, groups)
    )
    if already_matching:
        return self.set_ranges(*lengths)

    new_ranges, return_getters_groups = self._split_iteration_ranges(groups, lengths)
    flat_vars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
    return [[getter(flat_vars) for getter in getters] for getters in return_getters_groups]
def is_indirect_indexing(self, index: sympy.Expr):
    """
    Return whether ``index`` is an indirect index, i.e. has a free symbol
    named like a CSE temporary (``tmpN``).
    """
    temporary_prefix = "tmp"
    return free_symbol_startswith(index, temporary_prefix)
def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
    """
    More aggressive simplification to merge contiguous dims.

    Integers, bare symbols and indices using at most one variable of ``tree``
    are returned unchanged. Otherwise the tree's variables in ``index`` are
    handed to ``sizevars._simplify_loops``. If that merges any dimensions,
    the index is rewritten in terms of fresh variables for the merged sizes.
    """
    if isinstance(index, (sympy.Integer, sympy.Symbol)):
        return index

    index_vars, sizes = tree.vars_and_sizes(index)
    if len(sizes) <= 1:
        return index

    reorder_guard = index_prevent_reordering([index], index_vars, sizes)
    new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
        index_vars, sizes, reorder_guard
    )
    if new_sizes == sizes:
        return index

    replacement = dict(zip(index_vars, reindex(tree.construct(new_sizes))))
    return sympy_subs(index, replacement)
def indexing(
    self,
    index: sympy.Expr,
    *,
    copy_shape=None,
    dense_indexing=False,
    override_mask=None,
):
    """
    Compute the index and mask to pass to tl.load() or tl.store().

    Returns ``(index_str, mask_vars, mask_str)``:
      * index_str: Triton source for the (possibly broadcast) index;
      * mask_vars: the set of mask variable names that apply;
      * mask_str: those masks sorted and joined with " & ", or "None".

    A constant (sympy.Integer) index is always broadcast and returned
    unmasked. Otherwise, when dense indexing is needed (config, the
    ``dense_indexing`` flag, or an active load mask) and the index does not
    use a variable from every active range tree, the index is broadcast to the
    dense block shape (or ``copy_shape``'s shape) and masked with every active
    tree's mask. ``override_mask``, if given, replaces the computed masks.
    """
    index = self.simplify_indexing(index)
    index_vars = index.free_symbols
    index_str = texpr(self.rename_indexing(self.codegen_indexing(index)))

    mask_vars: Set[str] = set()
    if not override_mask:
        for var in index_vars:
            if var.name.startswith("tmp"):
                # indirect indexing: inherit the masks of the loaded value
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)
            elif not var.name.startswith("s"):
                # var is one of xN, yN or rN ("s*" symbols need no mask)
                assert var.name[0] in "xyr", var.name
                mask_vars.add(f"{var.name[0]}mask")

    need_dense = (
        config.triton.dense_indexing
        or dense_indexing
        or self._load_mask is not None
    ) and index != 0

    # have_dense: the index uses a variable from *every* active tree, so it
    #             already has the full block shape.
    # have_loop_vars: the index uses a variable from *at least one* tree.
    # These must differ; without the `else` below have_dense was not
    # tracking "every tree" at all.
    have_dense = True
    have_loop_vars = False
    dense_mask_vars = set()
    for tree in self.range_trees:
        if tree.prefix == "r" and not self.inside_reduction:
            continue
        if index_vars.intersection(tree.var_list):
            have_loop_vars = True
        else:
            have_dense = False
        dense_mask_vars.add(f"{tree.prefix}mask")

    if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
        # broadcast the index to the full block shape
        shape = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
        index_str = f"{index_str} + tl.zeros({shape}, tl.int32)"
        if isinstance(index, sympy.Integer):
            return index_str, set(), "None"
        mask_vars = dense_mask_vars
    elif not have_loop_vars and copy_shape:
        mask_vars = dense_mask_vars
        index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"

    if override_mask:
        mask_vars = {override_mask}
    if self._load_mask:
        mask_vars.add(self._load_mask)
    self.filter_masks(mask_vars)

    mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
    return index_str, mask_vars, mask_str
def filter_masks(self, mask_vars):
    """
    Remove, in place, the "<prefix>mask" of every range tree whose numel is 1:
    masks are superfluous when there is only one element.
    """
    single_element_prefixes = [
        tree.prefix
        for tree in self.range_trees
        if V.graph.sizevars.maybe_guard_equals(tree.numel, 1)
    ]
    for prefix in single_element_prefixes:
        mask_vars.discard(f"{prefix}mask")
def var_ranges(self):
    """Merge every range tree's var_ranges into one dict (later trees win on duplicates)."""
    merged = {}
    for tree in self.range_trees:
        merged.update(tree.var_ranges)
    return merged
def codegen_indexing(self, expr: sympy.Expr):
    """
    Simplify ``expr`` against the kernel's variable ranges and emit the
    defining line of every range-tree entry it mentions, in name order.
    Returns the simplified expression.
    """
    expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
    for sym in sorted(expr.free_symbols, key=str):
        entry = self.range_tree_nodes.get(sym)
        if entry is None:
            continue
        # if indexing expression is complicated, we precompute it on the host side
        # and send the result as a kernel argument
        replacements = {
            arg: V.graph.sizevars.lookup_precomputed_size(arg)
            for arg in entry.precomputed_args()
        }
        if replacements:
            entry.expr = sympy_subs(entry.expr, replacements)
        entry.codegen()
    return expr
@contextlib.contextmanager
def mask_loads(self, mask):
    """
    Context manager to add an additional mask to tl.load/store.

    Inside the block, ``self._load_mask`` is ``mask``, or ``mask & prior``
    (generated as a CSE variable in ``self.compute``) when a mask is already
    active. The combined mask is yielded. The prior mask is restored on exit,
    including when the body raises, so a failed body cannot leak its mask
    into later code generation.
    """
    prior = self._load_mask
    if prior:
        mask = self.cse.generate(self.compute, f"{mask} & {prior}")

    self._load_mask = mask
    try:
        with self.swap_buffers(self.compute, self.compute):
            # TODO(jansel): do we need a reshape here?
            yield mask
    finally:
        self._load_mask = prior
def load(self, name: str, index: sympy.Expr):
    """
    Emit a load of input buffer ``name`` at ``index`` and return the CSE
    variable holding the value.

    Unspecialized (scalar) arguments are used directly. float16/bfloat16
    loads are converted to float32. When inside a looped reduction, a load
    that needs neither ``rmask`` nor a ``tmp`` mask and is not indirect is
    hoisted out of the loop into ``self.body``.
    """
    var = self.args.input(name)
    indirect_indexing = self.is_indirect_indexing(index)
    original_index = index
    index, mask_vars, mask = self.indexing(index)

    # This eviction policy heuristic is untested.
    # ptillet suggested we should try only doing this for
    # the first N-1 loops and not for the final loop.
    ep = (
        ", eviction_policy='evict_last'"
        if "rmask" in mask and not self.persistent_reduction
        else ""
    )

    # "other" below is a workaround for https://github.com/openai/triton/issues/737
    # for bool, even though it's likely subject to the same bug, setting `other` leads
    # to LLVM errors so we are skipping it for now
    needs_other = ("tmp" in mask or "rmask" in mask) and V.graph.get_dtype(name) != torch.bool
    other = ", other=0" if needs_other else ""

    append_broadcast = None
    if V.graph.is_unspec_arg(name):
        line = var
    else:
        if isinstance(original_index, sympy.Integer):
            append_broadcast = self.dense_size_str()
            line = f"tl.load({var} + ({original_index}))"
        else:
            line = f"tl.load({var} + ({index}), {mask}{ep}{other})"
        if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
            line += ".to(tl.float32)"

    # can lift a common load outside of reduction loop
    # One exception is when this is an indirect_load.
    hoist = (
        self.inside_reduction
        and not self.persistent_reduction
        and "rmask" not in mask
        and "tmp" not in mask
        and not indirect_indexing
    )
    target_buffer = self.body if hoist else self.loads
    result_var = self.cse.generate(
        target_buffer, line, append_broadcast=append_broadcast
    )
    result_var.mask_vars = mask_vars

    if not self.inside_reduction or "rmask" not in mask:
        self.outside_loop_vars.add(result_var)
    return result_var
def store(self, name, index, value, mode=None):
    """Write ``value`` to output buffer ``name`` at ``index``.

    ``mode=None`` emits ``tl.store``; ``mode="atomic_add"`` emits
    ``tl.atomic_add``. Any other mode raises NotImplementedError.

    The line is written to ``self.stores`` under ``name``. Outside a
    reduction, ``value`` is also recorded in ``self.outside_loop_vars``.
    """
    var = self.args.output(name)
    index, mask_vars, mask = self.indexing(index, dense_indexing=True)

    if mode is None:
        triton_fn = "tl.store"
    elif mode == "atomic_add":
        triton_fn = "tl.atomic_add"
    else:
        raise NotImplementedError(f"store mode={mode}")

    self.stores.writeline(name, f"{triton_fn}({var} + ({index}), {value}, {mask})")

    if not self.inside_reduction:
        self.outside_loop_vars.add(value)
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
    """Emit Triton code that reduces ``value`` along the reduction axis and
    stores the result to buffer ``name``.

    Two code shapes are produced:

    * persistent reduction (no reduction loop): ``value`` is masked with
      ``tl.where`` and reduced directly with ``tl.<reduction_type>`` in
      ``self.compute``.
    * looped reduction: an accumulator ``_<result>`` is initialised in
      ``self.body`` (before the loop) and updated with ``tl.where`` in
      ``self.compute`` on each iteration. It is reduced along the last axis
      in ``self.suffix`` after the loop.

    Identical ``(src_dtype, reduction_type, value)`` looped reductions are
    deduplicated via ``self.cse.reduction_cache``.

    Unless ``name`` is in ``V.graph.removed_buffers``, the result is written
    with a ``DeferredLine`` ``tl.store`` into ``self.suffix``.

    ``dtype`` is not used in this method. Raises NotImplementedError for
    unsupported looped ``reduction_type`` values.
    """
    assert self.inside_reduction
    # identity value: fills masked-out lanes and seeds the accumulator
    default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype))
    # one "<prefix>mask" per range tree; filter_masks prunes the set in place
    masks = {f"{tree.prefix}mask" for tree in self.range_trees}
    self.filter_masks(masks)
    masks = sorted(masks)
    if self._load_mask:
        masks.append(self._load_mask)
    # `[:, ..., None]` re-expands the reduced (last) axis after reduction
    sizes = [":" for _ in self.range_trees]
    sizes[-1] = "None"
    reduction_range_prefix = self.range_trees[-1].prefix
    # `[None, ..., :]` lays an r-range vector along the last axis
    reduction_sizes = ["None" for _ in self.range_trees]
    reduction_sizes[-1] = ":"

    if reduction_type == "any":
        # any() is emitted as max()
        reduction_type = "max"

    dim = len(self.range_trees) - 1
    result_var = self.cse.newvar()
    # the result keeps only the non-reduction masks (names not starting with "r")
    result_var.mask_vars = {var for var in masks if var[0] != "r"}
    if self.persistent_reduction:
        cond = " & ".join(masks)
        masked_value = self.cse.generate(
            self.compute, f"tl.where({cond}, {value}, {default})"
        )
        # NOTE(review): result_var is rebound here, so the mask_vars assigned
        # just above are not carried over to this variable - confirm intended.
        result_var = self.cse.generate(
            self.compute,
            f"tl.{reduction_type}({masked_value}, {dim})[{', '.join(sizes)}]",
        )
    elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
        self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
        accumulator = f"_{result_var}"
        default_value = f" + {default}" if default != 0 else ""
        # accumulator is initialised once, before the reduction loop
        self.body.writeline(
            f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
        )
        accumulator_index = None
        if reduction_type in {"argmax", "argmin"}:
            # per-lane r-index of the best value seen so far
            accumulator_index = f"_{result_var}_index"
            self.body.writeline(
                f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
            )

        # min/max: only overwrite lanes where `value` is better;
        # sum: add into the accumulator
        updated = value
        if reduction_type in {"min", "argmin"}:
            masks.append(f"({accumulator} > {value})")
        elif reduction_type in {"max", "argmax"}:
            masks.append(f"({accumulator} < {value})")
        elif reduction_type == "sum":
            updated = f"{accumulator} + {value}"
        else:
            raise NotImplementedError(f"reduction_type {reduction_type}")

        cond = " & ".join(masks)

        if accumulator_index:
            # argmax or argmin
            self.compute.writeline(
                f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
            )
        self.compute.writeline(
            f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
        )

        if accumulator_index:
            # argmax, argmin: find the winning lane of the accumulator, build a
            # one-hot mask over the r-block for it, then sum-select the r-index
            # that lane recorded
            self.suffix.writelines(
                [
                    f"{accumulator_index}_reduce = "
                    f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
                    f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
                    f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
                    f"{result_var} = tl.sum("
                    f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
                ]
            )
        else:
            self.suffix.writeline(
                f"{result_var} = tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}]"
            )
    else:
        # same reduction already emitted in this kernel: alias its result
        var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
        self.suffix.writeline(f"{result_var} = {var_name}")
        result_var.mask_vars = var_name.mask_vars
    # compute the output index with the reduction dimension disabled
    self.inside_reduction = False
    index, mask_vars, mask = self.indexing(index)
    assert "rmask" not in index
    self.inside_reduction = True
    self.outside_loop_vars.add(result_var)
    # record result_var as the cached value of buffer `name`
    self.cse.store_cache[name] = result_var
    if name not in V.graph.removed_buffers:
        var = self.args.output(name)
        self.suffix.writeline(
            DeferredLine(name, f"tl.store({var} + {index}, {result_var}, {mask})")
        )
def codegen_body(self):
    """
    Concat output code from index_code, loads, compute, stores,
    suffix into self.body.

    For pointwise kernels, this is called just once at the end.
    For reduction kernels, this generates a loop over the reduction
    axis. The loop wraps indexing/loads/compute/stores, and the suffix is
    emitted after it.

    Returns early (no-op) when every staging buffer is empty. Otherwise all
    staging buffers are cleared afterwards.
    """
    in_loop_parts = (self.indexing_code, self.loads, self.compute, self.stores)
    if not any(
        (self.indexing_code, self.loads, self.stores, self.compute, self.suffix)
    ):
        return

    needs_loop = self.inside_reduction and not self.persistent_reduction
    if needs_loop:
        self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
        with self.body.indent():
            # the last range tree is always the reduction one
            self.range_trees[-1].codegen_header(self.body)
            for part in in_loop_parts:
                self.body.splice(part)
        # values computed inside the loop must not be reused after it
        self.cse.invalidate(self.outside_loop_vars)
        self.range_trees[-1].cache_clear()
    else:
        for part in in_loop_parts:
            self.body.splice(part)
    self.body.splice(self.suffix)

    for part in (*in_loop_parts, self.suffix):
        part.clear()
def codegen_kernel(self, name=None):
    """Render this kernel as Triton source code.

    Args:
        name: function name for the kernel.
            - When None, the function is named with the placeholder
              ``KERNEL_NAME``. The returned source is a full module (import
              header included) wrapped in ``async_compile.triton('''...''')``.
            - When given, only the decorated function source is returned.

    Returns:
        The generated source as a string.

    Calls ``self.codegen_body()``, which moves the pending code buffers into
    ``self.body`` and clears them.
    """
    from triton import next_power_of_2

    code = IndentedBuffer()

    # power-of-two size hint per range tree; pointwise kernels drop the last
    # (reduction) entry
    size_hints = [
        next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels
    ]
    if self.persistent_reduction:
        assert self.inside_reduction
        heuristics = "persistent_reduction"
    elif self.inside_reduction:
        heuristics = "reduction"
    else:
        size_hints.pop()
        heuristics = "pointwise"

    if name is None:
        # standalone module: emit the imports the generated kernel needs
        code.splice(
            f"""
                import triton
                import triton.language as tl
                from torch._inductor.ir import ReductionHint
                from torch._inductor.ir import TileHint
                from torch._inductor.triton_ops.autotune import {heuristics}
                from torch._inductor.utils import instance_descriptor
            """
        )

    argdefs, _, signature = self.args.python_argdefs()
    # maps actual expression to SizeArg if its in sizevars replacements
    for i, arg in enumerate(signature):
        if (
            isinstance(arg, SizeArg)
            and arg.expr in V.graph.sizevars.inv_precomputed_replacements
        ):
            signature[i] = SizeArg(
                arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr]
            )

    # names (as recorded in self.args) of the buffers this kernel mutates
    mutated_args = set()
    for mutation in self.mutations:
        if mutation in self.args.input_buffers:
            mutated_args.add(self.args.input_buffers[mutation])
        if mutation in self.args.inplace_buffers:
            mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
        if mutation in self.args.output_buffers:
            mutated_args.add(self.args.output_buffers[mutation])
    mutated_args = sorted(mutated_args)

    triton_meta = {
        "signature": dict(enumerate(map(signature_of, signature))),
        "device": V.graph.scheduler.current_device.index,
        "constants": {},
        "mutated_arg_names": mutated_args,
    }

    # one runtime `<prefix>numel` argument per range tree (the reduction
    # tree only when inside_reduction)
    for tree in self.range_trees:
        if tree.prefix != "r" or self.inside_reduction:
            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
            signature.append(sizearg)
            triton_meta["signature"][len(argdefs)] = signature_of(sizearg)
            argdefs.append(f"{tree.prefix}numel")
            # constexpr version causes issues, see
            # https://github.com/pytorch/torchdynamo/pull/1362
            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
            #     tree.numel
            # )
            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
    triton_meta["configs"] = [config_of(signature)]

    # matching `<PREFIX>BLOCK` tl.constexpr parameters
    for tree in self.range_trees:
        if tree.prefix != "r" or self.inside_reduction:
            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")

    if self.inside_reduction:
        reduction_hint = self.reduction_hint
        heuristics_line = f"""
            @{heuristics}(
                size_hints={size_hints!r},
                reduction_hint={reduction_hint},
                filename=__file__,
                meta={triton_meta!r}
            )
            @triton.jit
        """
    else:
        # 2D pointwise kernels get a tile hint; SQUARE for the 4-argument case
        tile_hint = ""
        if len(size_hints) == 2:
            if len(signature) == 4:  # input, output and 2 args
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
        heuristics_line = f"""
            @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r})
            @triton.jit
        """
    code.splice(heuristics_line)
    code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):")
    self.codegen_body()
    with code.indent():
        if not dynamo_config.dynamic_shapes:
            self.codegen_static_numels(code)
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")
        code.splice(self.body)

    if name is not None:
        return code.getvalue()

    wrapper = IndentedBuffer()
    wrapper.writeline("async_compile.triton('''")
    wrapper.splice(code.getvalue(), strip=True)
    wrapper.writeline("''')")
    return wrapper.getvalue()
def codegen_template_wrapper(self, src_code):
    """Wrap already-rendered template source in ``async_compile.triton('''...''')``.

    Returns the wrapped source string.
    """
    out = IndentedBuffer()
    opening, closing = "async_compile.triton('''", "''')"
    out.writeline(opening)
    out.splice(src_code, strip=True)
    out.writeline(closing)
    return out.getvalue()
def codegen_static_numels(self, code):
    """
    We get a small speedup from hard coding numels if they are static.

    For each range tree that is part of the kernel signature, this writes
    ``<prefix>numel = <size hint>`` into ``code``. That happens when the
    numel simplifies to an integer, or, failing that, whenever
    ``dynamo_config.dynamic_shapes`` is False (with a trailing comment).
    """
    for tree in self.range_trees:
        if tree.prefix == "r" and not self.inside_reduction:
            # the reduction numel is not a kernel argument in this case
            continue
        if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer):
            trailer = ""
        elif not dynamo_config.dynamic_shapes:
            trailer = " # dynamic_shapes=False"
        else:
            continue
        code.writeline(
            f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}{trailer}"
        )
def indexing_size_str(self, i=None, x=None):
    """Return a broadcasting suffix such as ``"[:, None]"``.

    There is one ``None`` entry per range tree, minus one when the last numel
    is 1. Position ``i`` (if given) is replaced by ``:``. ``x`` is unused.
    """
    ndims = len(self.range_trees)
    if self.numels[-1] == 1:
        ndims -= 1
    entries = ["None"] * ndims
    if i is not None:
        entries[i] = ":"
    return "[" + ", ".join(entries) + "]"
def dense_size_str(self):
    """Return the dense block shape, e.g. ``"[XBLOCK, RBLOCK]"``.

    Every active range tree contributes ``<PREFIX>BLOCK``. When the reduction
    is disabled, the reduction tree contributes ``1`` instead, or nothing at
    all if its numel is 1.
    """

    def _dim(tree):
        if tree.prefix != "r" or self.inside_reduction:
            return f"{tree.prefix.upper()}BLOCK"
        return None if tree.numel == 1 else "1"

    dims = [d for d in map(_dim, self.range_trees) if d is not None]
    return f"[{', '.join(dims)}]"
def call_kernel(self, code, name: str):
    """Write the host-side launch of kernel ``name`` into wrapper ``code``.

    Emits ``<name>.run(<args>, grid=grid(<non-reduction numels>), stream=...)``.
    Numels that are not a plain sympy Integer/Symbol are first bound to a
    ``<name>_<prefix>numel`` variable.
    """
    _, call_args, _ = self.args.python_argdefs()
    # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
    for pos, arg in enumerate(call_args):
        if V.graph.is_unspec_arg(arg):
            call_args[pos] = arg + ".item()"

    grid = []
    # TODO(jansel): if there are constants, we shouldn't bother passing them as args
    for tree in self.range_trees:
        if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
            numel_expr = pexpr(tree.numel)
        else:
            numel_expr = f"{name}_{tree.prefix}numel"
            code.writeline(f"{numel_expr} = {pexpr(tree.numel)}")

        is_reduction_tree = tree.prefix == "r"
        if not is_reduction_tree or self.inside_reduction:
            call_args.append(numel_expr)
        if not is_reduction_tree:
            grid.append(numel_expr)

    stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
    args_str = ", ".join(call_args)
    grid_str = ", ".join(grid)
    code.writeline(
        f"{name}.run({args_str}, grid=grid({grid_str}), stream={stream_name})"
    )
def create_cse_var(self, *args, **kwargs):
    """Factory hook: CSE variables of Triton kernels are ``TritonCSEVariable``s."""
    var_cls = TritonCSEVariable
    return var_cls(*args, **kwargs)
class TritonScheduling:
    """Triton backend hooks used by the Scheduler.

    Decides whether (possibly already fused) scheduler nodes can share one
    Triton kernel and picks a tiling for them. It then generates the kernel
    source, registers it with ``V.graph.wrapper_code`` and emits its launch.
    """

    def __init__(self, scheduler):
        # owning Scheduler; used to free buffers after each kernel is emitted
        self.scheduler = scheduler

    def group_fn(self, sizes):
        """Return a group key: the simplified product of each size list in ``sizes``."""
        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)

    def can_fuse(self, node1, node2):
        """
        Hook called by Scheduler to determine if the Triton backend
        can fuse node1 and node2. These nodes might already be
        FusedSchedulerNodes.

        ``node.group[1]`` is ``(numel, rnumel)``; non-reduction nodes have
        ``rnumel == 1``.
        """
        _, (numel1, rnumel1) = node1.group
        _, (numel2, rnumel2) = node2.group

        # reduction + reduction: identical iteration spaces required
        if node1.is_reduction() and node2.is_reduction():
            return numel1 == numel2 and rnumel1 == rnumel2

        # pointwise + pointwise
        if not node1.is_reduction() and not node2.is_reduction():
            if not (numel1 == numel2 and rnumel1 == rnumel2):
                return False

            if node1.is_template():
                return True  # skip checks for compatible tiling

            # check for a bad combined tiling
            tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
            tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
            tiling3 = self.select_tiling(
                node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
            )
            if config.triton.tiling_prevents_pointwise_fusion:
                # len > 2 means a tiled (multi-group) choice; fusion must not
                # change the tiling a node asked for
                if len(tiling1) > 2:
                    if len(tiling2) > 2:
                        return tiling1 == tiling2 == tiling3
                    else:
                        return tiling1 == tiling3
                elif len(tiling2) > 2:
                    return tiling2 == tiling3

            return True

        # pointwise node1 with reduction node2
        if not node1.is_reduction() and node2.is_reduction():
            assert rnumel1 == 1 and rnumel2 != 1
            if numel1 == numel2 * rnumel2:
                # node1 covers the whole (numel2 x rnumel2) space: its ranges
                # must split compatibly to run inside the reduction loop
                if not all(
                    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
                    for n in node1.get_nodes()
                ):
                    return False
                if (
                    config.triton.tiling_prevents_reduction_fusion
                    and not node1.is_template()
                ):
                    return self.select_tiling(node1.get_nodes(), numel1) in (
                        (numel1, 1),
                        (numel2, rnumel2, 1),
                    )
                return True

            # otherwise node1 must match the reduction's non-reduction size
            return numel1 == numel2

        assert node1.is_reduction() and not node2.is_reduction()
        # swap args to hit the case above
        return self.can_fuse_horizontal(node2, node1)

    # both fusion directions use the same predicate
    can_fuse_vertical = can_fuse
    can_fuse_horizontal = can_fuse

    def codegen_nodes(self, nodes):
        """
        Given a set of pre-fused nodes, generate a Triton kernel.

        Builds a ``node_schedule``: the nodes in order, interleaved with the
        ``DisableReduction`` / ``EnableReduction`` marker classes around
        nodes that must run outside the reduction loop. The schedule is then
        passed to ``codegen_node_schedule``.

        Raises NotImplementedError for a node whose group fits neither the
        main body nor the outside-reduction region.
        """
        # a reduction node (if any) defines the kernel's (numel, rnumel)
        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group

        node_schedule = []
        # names of nodes placed in the currently open reduction loop
        current_loop_writes = set()
        # is_reduction() values of nodes in the currently open loop
        is_current_reductions = set()
        # nodes already placed in node_schedule
        done = set()

        def fits_in_main_body(n):
            _, (node_numel, node_rnumel) = n.group
            return (node_numel == numel and node_rnumel == rnumel) or (
                node_numel == numel * rnumel and node_rnumel == 1
            )

        def fits_outside_reduction(n):
            _, (node_numel, node_rnumel) = n.group
            return node_numel == numel and node_rnumel == 1 and rnumel != 1

        @contextlib.contextmanager
        def end_current_reduction_loop():
            # Close the open reduction loop. Nodes appended while this context
            # is active end up between DisableReduction and EnableReduction.
            if current_loop_writes:
                # flush out any other runnable nodes to reduce number of loops
                # NOTE(review): this flush is a no-op as written. The body tests
                # and schedules `node` (the enclosing loop variable, already in
                # `done` before this context is entered) rather than
                # `other_node`, so `node not in done` is always False. Renaming
                # to `other_node` alone would be unsafe: it does not check
                # dependencies on not-yet-scheduled nodes between `index` and
                # `other_node` (or on `node` itself in the outside-reduction
                # path).
                for other_node in nodes[index + 1 :]:
                    if (
                        node not in done
                        and fits_in_main_body(other_node)
                        and not (
                            current_loop_writes & other_node.recursive_predecessors
                        )
                    ):
                        done.add(node)
                        current_loop_writes.add(node.get_name())
                        is_current_reductions.add(node.is_reduction())
                        node_schedule.append(node)

            if node_schedule and node_schedule[-1] is EnableReduction:
                # re-open the previous disabled region instead of emitting an
                # adjacent Enable/Disable pair
                node_schedule.pop()
            else:
                node_schedule.append(DisableReduction)
            yield
            node_schedule.append(EnableReduction)
            current_loop_writes.clear()
            is_current_reductions.clear()

        for index, node in enumerate(nodes):
            if node in done:
                continue
            done.add(node)

            def requires_closing_previous_reduction(node, node_schedule):
                # A node reading something written in the open loop must
                # start a new loop when that loop contains a reduction
                # (reduction results are only produced after the loop).
                if rnumel == 1:
                    return False
                if not current_loop_writes & node.recursive_predecessors:
                    return False
                # The markers are stored as class objects, not instances, so
                # compare by membership; isinstance() would never match them.
                assert node_schedule and node_schedule[-1] not in (
                    EnableReduction,
                    DisableReduction,
                )
                return True in is_current_reductions

            if fits_in_main_body(node):
                if requires_closing_previous_reduction(node, node_schedule):
                    with end_current_reduction_loop():
                        pass  # need to start a new reduction loop
                current_loop_writes.add(node.get_name())
                is_current_reductions.add(node.is_reduction())
                node_schedule.append(node)
            elif fits_outside_reduction(node):
                with end_current_reduction_loop():
                    node_schedule.append(node)
            else:
                raise NotImplementedError(
                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
                )

        if dynamo_config.output_code:
            log.info("schedule: %s", node_schedule)

        return self.codegen_node_schedule(node_schedule, numel, rnumel)

    @staticmethod
    def reduction_hint(node):
        """Return ReductionHint.INNER when every read/write dep of ``node`` is
        contiguous, else the hint stored on the underlying IR reduction."""
        assert node.is_reduction()
        if all(
            dep.is_contiguous()
            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
        ):
            return ReductionHint.INNER
        else:
            return node.node.data.reduction_hint

    def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
        """Generate, register and launch one Triton kernel for ``node_schedule``.

        ``node_schedule`` holds scheduler nodes interleaved with the
        DisableReduction/EnableReduction marker classes. Buffers are freed
        via the scheduler afterwards.
        """
        tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
        reductions = list(
            filter(
                lambda n: n not in (EnableReduction, DisableReduction)
                and n.is_reduction(),
                node_schedule,
            )
        )
        # use the reductions' shared hint; DEFAULT if they disagree or there are none
        if len(reductions) > 0:
            hints = [self.reduction_hint(n) for n in reductions]
            if hints.count(hints[0]) == len(hints):
                reduction_hint_val = hints[0]
            else:
                reduction_hint_val = ReductionHint.DEFAULT
        else:
            reduction_hint_val = ReductionHint.DEFAULT

        # union of buffers mutated by any node in the schedule
        mutations = set()
        for node in node_schedule:
            if hasattr(node, "get_mutations"):
                mutations.update(node.get_mutations())

        with TritonKernel(
            *tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations
        ) as kernel:
            # stack holds the kernel.disable_reduction() context between markers
            stack = contextlib.ExitStack()
            for node in node_schedule:
                if node not in (EnableReduction, DisableReduction):
                    node.mark_run()
            for node in node_schedule:
                if node is DisableReduction:
                    stack.enter_context(kernel.disable_reduction())
                elif node is EnableReduction:
                    stack.close()
                else:
                    # TODO - mostly works but needs a couple fixes
                    if not dynamo_config.dynamic_shapes:
                        # TODO - use split ranges ?
                        indexing_dtype_strength_reduction(node._body)
                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
                    node.codegen(index_vars)

        src_code = kernel.codegen_kernel()
        kernel_name = self.define_kernel(src_code, node_schedule)
        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
        self.scheduler.free_buffers()

    def define_kernel(self, src_code, node_schedule):
        """Register ``src_code`` with the wrapper and return the kernel's name.

        Identical source (keyed before the ``KERNEL_NAME`` placeholder is
        substituted) reuses the previously assigned name.
        """
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.kernels:
            kernel_name = wrapper.kernels[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule)
                if config.triton.descriptive_kernel_names
                else ""
            )
            kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()])
            wrapper.kernels[src_code] = kernel_name
            # without ordered_kernel_names, the source always uses "triton_"
            subs_name = kernel_name if config.triton.ordered_kernel_names else "triton_"
            src_code = src_code.replace("KERNEL_NAME", subs_name)
            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
            src_code = src_code.replace("#pragma CMT", "#")
            wrapper.define_kernel(kernel_name, src_code)
        return kernel_name

    def codegen_template(self, template_node, epilogue_nodes):
        """
        Codegen a triton template together with its pointwise epilogue nodes.
        """
        _, (_, rnumel) = template_node.group
        assert rnumel == 1
        kernel, render = template_node.node.make_kernel_render(template_node.node)
        with kernel:
            for node in [template_node, *epilogue_nodes]:
                node.mark_run()
            render()  # warmup run to get the args right
            for node in epilogue_nodes:
                node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

        src_code = kernel.codegen_template_wrapper(render())
        kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
        kernel.call_kernel(V.graph.wrapper_code, kernel_name)
        self.scheduler.free_buffers()

    def codegen_sync(self):
        """Emit a ``torch.cuda.synchronize()`` call into the wrapper code."""
        V.graph.wrapper_code.writeline("torch.cuda.synchronize()")

    @staticmethod
    @functools.lru_cache(32)
    def candidate_tilings(node):
        """Propose 2-way tilings of ``node``'s pointwise ranges.

        For each read/write (excluding removed buffers), the ranges are split
        right after the first stride-1 dimension. A dependency is skipped when:
        - it has no stride-1 dimension,
        - that dimension is the last one, or
        - every later dimension has stride 0 (broadcast).

        Scores start from the number of non-broadcast elements and are doubled
        for writes and for group sizes that ``CandidateTiling.is_good_size``
        accepts. Candidates scoring below the total element count are dropped.
        Results are cached per node (lru_cache, 32 entries).
        """
        ranges, reduction_ranges = node.get_ranges()
        if len(ranges) <= 1:
            return ()

        rw = node.pointwise_read_writes()
        assert len(rw.range_vars) == len(ranges)

        deps = [
            dep
            for dep in itertools.chain(rw.reads, rw.writes)
            if dep.name not in V.graph.removed_buffers
        ]
        write_names = {dep.name for dep in rw.writes}

        tilings = []

        for dep in deps:
            strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
            assert len(strides) == len(ranges)
            try:
                split = strides.index(1) + 1
                if split == len(ranges):
                    continue
                if all(s == 0 for s in strides[split:]):
                    # if this is a broadcasted tensor and all dimensions after split are broadcast,
                    # this is not a real split
                    continue

            except ValueError:
                continue
            tiled_groups = (
                V.graph.sizevars.simplify(sympy_product(ranges[:split])),
                V.graph.sizevars.simplify(sympy_product(ranges[split:])),
            )
            # score by number of elements
            score = V.graph.sizevars.size_hint(
                sympy_product(
                    size for size, stride in zip(ranges, strides) if stride != 0
                )
            )
            if dep.name in write_names:
                # ngimel said contiguous writes is more important than reads
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[0]):
                score *= 2
            if CandidateTiling.is_good_size(tiled_groups[1]):
                score *= 2

            if (
                V.graph.sizevars.size_hint(
                    score - sympy_product(itertools.chain(ranges, reduction_ranges))
                )
                >= 0
            ):
                tilings.append(CandidateTiling(tiled_groups, score, dep.name))
        return tilings

    @classmethod
    def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
        """
        Heuristics to decide how to tile kernels.
        Currently, we tile based on stride-1 dimensions.

        Returns:
            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`.
            With ``config.triton.max_tiles >= 3`` a 3D choice
            `(tile0, tile1, tile2, reduction_numel)` may be returned.
            `(numel, reduction_numel)` is returned when no tiling is used:
            reductions, ``max_tiles <= 1``, or no candidate compatible with
            every SchedulerNode.
        """
        if reduction_numel != 1 or config.triton.max_tiles <= 1:
            # TODO(jansel): should we tile reductions?
            return (numel, reduction_numel)

        # sum candidate scores per tiling, counting each buffer name once and
        # ignoring nodes inside DisableReduction regions
        seen_names = set()
        candidate_tiles = collections.Counter()
        for node in EnableReduction.filter(node_schedule):
            for tiling in cls.candidate_tilings(node):
                if tiling.name in seen_names:
                    continue
                seen_names.add(tiling.name)
                candidate_tiles[tiling.tiling] += tiling.score

        ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]

        if config.triton.max_tiles >= 3:
            # Add one 3D tiling choice
            for i in range(1, len(ranked_tilings)):
                a0, a1 = ranked_tilings[0]
                b0, b1 = ranked_tilings[i]
                if V.graph.sizevars.size_hint(a1 - b1) == 0:
                    continue
                if V.graph.sizevars.size_hint(a1 - b1) < 0:
                    # swap so a1 is the larger inner group
                    a0, a1 = ranked_tilings[i]
                    b0, b1 = ranked_tilings[0]
                assert V.graph.sizevars.size_hint(a1 - b1) > 0
                if V.graph.sizevars.maybe_guard_multiple_of(a1, b1):
                    tiling = (a0, ir.FloorDiv(a1, b1), b1)
                    ranked_tilings = [tiling] + ranked_tilings
                    break  # only 1 choice for now

        # first (best-ranked) tiling that every SchedulerNode can use
        for tiled_groups in ranked_tilings:
            new_groups = (*tiled_groups, reduction_numel)
            if all(
                TritonKernel.is_compatible(new_groups, node.get_ranges())
                for node in node_schedule
                if isinstance(node, scheduler.SchedulerNode)
            ):
                return new_groups

        return (numel, reduction_numel)

    def flush(self):
        """No-op: this backend keeps no buffered state to flush."""
        pass
@dataclasses.dataclass
class CandidateTiling:
    """One proposed 2-way split of a node's pointwise ranges and its score."""

    # sizes of the proposed split groups
    tiling: List[sympy.Expr]
    # heuristic preference; higher is better
    score: int
    # name of the buffer whose access pattern suggested this tiling
    name: str = None

    @staticmethod
    def is_good_size(s):
        """Somewhat arbitrary heuristic used to boost scores for some sizes.

        True when the size hint of ``s`` is a positive multiple of 32.
        """
        hint = V.graph.sizevars.size_hint(s)
        if hint < 32:
            return False
        return hint % 32 == 0
class DisableReduction:
    """
    Marker to invoke `kernel.disable_reduction()`. This closes a
    reduction loop and allows for pointwise ops to occur on the output
    of a reduction.

    The class object itself (not an instance) is placed in a node schedule.
    The disabled region ends at the next `EnableReduction` marker.
    """
class EnableReduction:
    """
    Marker to end a DisableReduction block.

    Like DisableReduction, the class object itself is stored in node
    schedules.
    """

    @staticmethod
    def filter(node_schedule):
        """
        Yield the nodes of ``node_schedule`` that lie outside every
        DisableReduction ... EnableReduction block; the markers themselves
        are never yielded.
        """
        skipping = False
        for entry in node_schedule:
            if entry in (EnableReduction, DisableReduction):
                # Don't tile stuff outside the main reduction loop
                skipping = entry is DisableReduction
                continue
            if not skipping:
                yield entry
class CantSplit(Exception):
    """Exception type for a range split that cannot be performed.

    NOTE(review): it is neither raised nor caught in this part of the file.
    Presumably it is used by TritonKernel's range-splitting logic; confirm.
    """
    pass
|