triton.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760
  1. import collections
  2. import contextlib
  3. import dataclasses
  4. import functools
  5. import itertools
  6. import logging
  7. import math
  8. import operator
  9. from typing import Dict, List, Set
  10. import sympy
  11. import torch
  12. from ..._dynamo import config as dynamo_config
  13. from .. import config, ir, scheduler
  14. from ..ir import ReductionHint
  15. from ..optimize_indexing import indexing_dtype_strength_reduction
  16. from ..utils import (
  17. get_fused_kernel_name,
  18. instance_descriptor,
  19. sympy_product,
  20. sympy_subs,
  21. sympy_symbol,
  22. )
  23. from ..virtualized import ops, V
  24. from .common import (
  25. CSEVariable,
  26. DeferredLine,
  27. free_symbol_startswith,
  28. IndentedBuffer,
  29. index_prevent_reordering,
  30. Kernel,
  31. OpOverrides,
  32. PythonPrinter,
  33. SizeArg,
  34. TensorArg,
  35. )
# Module-level logger, named after this module per the logging convention.
log = logging.getLogger(__name__)
  37. def signature_of(arg):
  38. from triton.runtime.jit import JITFunction
  39. if isinstance(arg, TensorArg):
  40. tye = JITFunction._type_of(arg.dtype)
  41. if V.graph.is_unspec_arg(arg.buffer):
  42. # had unwrapped 0d tensor as scalar
  43. new_tye = tye.lstrip("*")
  44. if new_tye in ["fp16", "bf16"]:
  45. return "fp32"
  46. else:
  47. return new_tye
  48. else:
  49. return tye
  50. if isinstance(arg, SizeArg):
  51. return JITFunction._key_of(V.graph.sizevars.size_hint(arg.expr))
  52. raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
  53. def config_of(args):
  54. from ..compile_fx import ALIGNMENT
  55. def is_aligned(x):
  56. if isinstance(x, TensorArg):
  57. return x.buffer not in V.graph.unaligned_buffers
  58. if isinstance(x, SizeArg):
  59. return V.graph.sizevars.maybe_guard_multiple_of(x.expr, ALIGNMENT)
  60. raise NotImplementedError(f"unhandled {type(x)}: {x}")
  61. divisible_by_16 = [i for i, arg in enumerate(args) if is_aligned(arg)]
  62. return instance_descriptor(tuple(divisible_by_16), ())
  63. class TritonPrinter(PythonPrinter):
  64. def _print_floor(self, expr):
  65. assert len(expr.args) == 1
  66. return f"tl.libdevice.floor({self.paren(self._print(expr.args[0]))})"
# Expression printers used to render sympy indexing expressions:
# texpr emits Triton source, pexpr emits plain Python source (host side).
texpr = TritonPrinter().doprint
pexpr = PythonPrinter().doprint
  69. def triton_compute_type(dtype):
  70. triton_type_name = str(dtype).split(".")[-1]
  71. if triton_type_name == "bool":
  72. triton_type_name = "int1"
  73. if triton_type_name in ("float16", "bfloat16"):
  74. # float16 math is done in float32 inside the kernel
  75. triton_type_name = "float32"
  76. return f"tl.{triton_type_name}"
  77. def triton_constant(value):
  78. if value == float("inf"):
  79. return 'float("inf")'
  80. elif value == float("-inf"):
  81. return 'float("-inf")'
  82. elif math.isnan(value):
  83. return 'float("nan")'
  84. return repr(value)
  85. class TritonCSEVariable(CSEVariable):
  86. def __init__(self, name):
  87. super().__init__(name)
  88. # We'll use this to track which masks the variable needs when used for indirect indexing
  89. self.mask_vars: Set[str] = set()
  90. def update_on_args(self, name, args, kwargs):
  91. # When making a variable that is going to be used in indirect indexing
  92. # if a where clause is used it should mean that the result is always a
  93. # valid index, so you shouldn't include any of the dependent variables
  94. # in the resulting load mask
  95. if name == "where":
  96. return
  97. for arg in args:
  98. if isinstance(arg, TritonCSEVariable):
  99. self.mask_vars.update(arg.mask_vars)
class TritonOverrides(OpOverrides):
    """Map element-wise ops to Triton"""

    @staticmethod
    def to_dtype(x, dtype: torch.dtype):
        # Emit a cast of triton value `x` to the compute type for `dtype`.
        if dtype == torch.bool:
            return f"({x} != 0)"
        elif dtype == torch.uint8:
            # to work around llvm uint conversion semantics
            # that produces 0's for negative values
            return f"{x}.to(tl.int8).to(tl.uint8)"
        return f"{x}.to({triton_compute_type(dtype)})"

    @staticmethod
    def constant(value, dtype):
        # Round-trip the value through the python type matching `dtype`
        # (bool/int/float) before rendering it as a literal.
        type_ = torch._prims_common.dtype_to_type(dtype)
        return triton_constant(type_(value))

    @staticmethod
    def abs(x):
        return f"tl.abs({x})"

    @staticmethod
    def libdevice_abs(x):
        return f"tl.libdevice.abs({x})"

    @staticmethod
    def exp(x):
        return f"tl.exp({x})"

    @staticmethod
    def libdevice_exp(x):
        return f"tl.libdevice.exp({x})"

    @staticmethod
    def exp2(x):
        return f"tl.libdevice.exp2({x})"

    @staticmethod
    def expm1(x):
        return f"tl.libdevice.expm1({x})"

    @staticmethod
    def sqrt(x):
        return f"tl.sqrt({x})"

    @staticmethod
    def libdevice_sqrt(x):
        return f"tl.libdevice.sqrt({x})"

    @staticmethod
    def relu(x):
        # relu(x) == maximum(0, x); reuses the NaN-propagating maximum below
        return ops.maximum("0", x)

    @staticmethod
    def minimum(a, b):
        # `{a} != {a}` is the NaN test: if `a` is NaN, propagate it
        return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))"

    @staticmethod
    def maximum(a, b):
        # `{a} != {a}` is the NaN test: if `a` is NaN, propagate it
        return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))"

    @staticmethod
    def where(a, b, c):
        return f"tl.where({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"tl.cos({x})"

    @staticmethod
    def libdevice_cos(x):
        return f"tl.libdevice.cos({x})"

    @staticmethod
    def sin(x):
        return f"tl.sin({x})"

    @staticmethod
    def libdevice_sin(x):
        return f"tl.libdevice.sin({x})"

    @staticmethod
    def index_expr(expr, dtype):
        # NOTE(review): `dtype` is ignored here; rendering (and mask
        # bookkeeping) is delegated to the kernel's indexing() helper.
        return V.kernel.indexing(expr)[0]

    @staticmethod
    def masked(mask, body, other):
        # Evaluate `body` with `mask` applied to every load inside it,
        # then substitute the constant `other` for masked-off lanes.
        with V.kernel.mask_loads(mask) as new_mask:
            result = body()
        return ops.where(new_mask, result, triton_constant(other))

    @staticmethod
    def lgamma(x):
        return f"tl.libdevice.lgamma({x})"

    @staticmethod
    def erf(x):
        return f"tl.libdevice.erf({x})"

    @staticmethod
    def cosh(x):
        return f"tl.libdevice.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"tl.libdevice.sinh({x})"

    @staticmethod
    def acos(x):
        return f"tl.libdevice.acos({x})"

    @staticmethod
    def acosh(x):
        return f"tl.libdevice.acosh({x})"

    @staticmethod
    def asin(x):
        return f"tl.libdevice.asin({x})"

    @staticmethod
    def asinh(x):
        return f"tl.libdevice.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"tl.libdevice.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"tl.libdevice.atan({x})"

    @staticmethod
    def atanh(x):
        return f"tl.libdevice.atanh({x})"

    @staticmethod
    def copysign(x, y):
        return f"tl.libdevice.copysign({x}, {y})"

    @staticmethod
    def erfc(x):
        return f"tl.libdevice.erfc({x})"

    @staticmethod
    def hypot(x, y):
        return f"tl.libdevice.hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"tl.libdevice.log10({x})"

    @staticmethod
    def nextafter(x, y):
        return f"tl.libdevice.nextafter({x}, {y})"

    @staticmethod
    def logical_and(a, b):
        # bitwise & on int1 values acts as logical and
        return f"{a} & {b}"

    @staticmethod
    def logical_or(a, b):
        # bitwise | on int1 values acts as logical or
        return f"{a} | {b}"

    @staticmethod
    def rand(seed, offset, _):  # _ here to keep the contract identical to CPU rand op
        return f"tl.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset, _):  # _ here to keep the contract identical to CPU randn op
        return f"tl.randn({seed}, {offset})"

    @staticmethod
    def rsqrt(x):
        return f"tl.libdevice.rsqrt({x})"

    @staticmethod
    def log1p(x):
        return f"tl.libdevice.log1p({x})"

    @staticmethod
    def tan(x):
        return f"tl.libdevice.tan({x})"

    @staticmethod
    def tanh(x):
        return f"tl.libdevice.tanh({x})"

    @staticmethod
    def sigmoid(x):
        return f"tl.sigmoid({x})"

    @staticmethod
    def libdevice_sigmoid(x):
        # no libdevice sigmoid; compose it from exp
        return f"1/(1 + tl.libdevice.exp(-({x})))"

    @staticmethod
    def signbit(x):
        # XX: This is wrong for the value -0.0 in floating point
        return f"tl.libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"

    @staticmethod
    def fmod(a, b):
        return f"tl.libdevice.fmod({a}, {b})"

    @staticmethod
    def pow(a, b):
        return f"tl.libdevice.pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"tl.log({x})"

    @staticmethod
    def libdevice_log(x):
        return f"tl.libdevice.log({x})"

    @staticmethod
    def isinf(x):
        return f"tl.libdevice.isinf({x})"

    @staticmethod
    def isnan(x):
        return f"tl.libdevice.isnan({x})"

    @staticmethod
    def round(x):
        # nearbyint rounds half-to-even, matching torch.round
        return f"tl.libdevice.nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"tl.libdevice.floor({x})"

    @staticmethod
    def floordiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Similar to div_floor_kernel_cuda in pytorch core.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        quot = f"{a} // {b}"
        rem = f"{a} % {b}"
        # when signs differ and the remainder is nonzero, truncation rounded
        # the wrong way; correct by subtracting one
        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"

    @staticmethod
    def trunc(x):
        return f"tl.libdevice.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        return f"{a} // {b}"

    @staticmethod
    def ceil(x):
        return f"tl.libdevice.ceil({x})"
  296. @dataclasses.dataclass
  297. class IterationRanges:
  298. """
  299. Each range tree represents multiple sets of iteration indexing
  300. in a single tiled dimension in the output kernel.
  301. If you have two loops ranges one (4, 3, 2) and another (4, 6),
  302. then the range tree will be:
  303. 4 (i0)
  304. 3 (i1) 6 (i3)
  305. 2 (i2)
  306. Where i0 is shared between both loops, but then the split into
  307. different indexing vars. All loop ranges must iterate over
  308. the same number of elements.
  309. """
  310. def __init__(
  311. self,
  312. name: str,
  313. var_list: List[sympy.Symbol],
  314. var_ranges: Dict[sympy.Symbol, sympy.Expr],
  315. numel: sympy.Expr,
  316. prefix: str,
  317. *,
  318. kernel: "Kernel",
  319. divisor=sympy.Integer(1),
  320. length=sympy.Integer(1),
  321. ):
  322. super().__init__()
  323. self.name = name
  324. self.var_list = var_list
  325. self.var_ranges = var_ranges
  326. self.numel = numel
  327. self.prefix = prefix
  328. self.divisor = divisor
  329. self.length = length
  330. self.kernel = kernel
  331. def is_loop(self):
  332. return self.prefix == "r" and not self.kernel.persistent_reduction
class IterationRangesRoot(IterationRanges):
    """
    Root of a range tree: one tiled dimension of the kernel (prefix
    "x", "y", "z" or "r") whose total size is `numel`.  Child
    IterationRangesEntry nodes are splits of this dimension.
    """

    def __init__(
        self,
        name: str,
        numel: sympy.Expr,
        prefix: str,
        index: int,
        kernel: "Kernel",
        pid_cache=None,
    ):
        if pid_cache is None:
            pid_cache = {}
        super().__init__(
            name=name,
            var_list=[],
            var_ranges={},
            numel=numel,
            prefix=prefix,
            kernel=kernel,
        )
        # position of this tree among the kernel's range trees (also the
        # program_id axis used in codegen_header)
        self.index = index
        # Store all the nodes in one flat list
        self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
        # This is for re-ordering program ID in triton mm template
        # pid_cache["tl.program_id(0)"] = pid_m
        self.pid_cache: Dict[str, str] = pid_cache

    def cache_clear(self):
        # reset memoized codegen on every child so definitions get re-emitted
        for node in self.nodes.values():
            node.cache_clear()

    def lookup(self, divisor, length):
        """
        Lookup a given RangeTreeEntry, creating it if needed
        """
        if V.graph.sizevars.maybe_guard_equals(divisor * length, self.numel):
            # divisor * length spans the whole dimension, so ModularIndexing's
            # wrap-around would be a no-op; plain floor-division suffices
            expr = ir.FloorDiv(sympy_symbol(f"{self.prefix}index"), divisor)
        else:
            expr = ir.ModularIndexing(
                sympy_symbol(f"{self.prefix}index"), divisor, length
            )
        if expr not in self.nodes:
            node = IterationRangesEntry(
                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
                divisor,
                length,
                expr,
                self,
            )
            V.kernel.range_tree_nodes[node.symbol()] = node
            self.var_list.append(node.symbol())
            self.var_ranges[node.symbol()] = length
            self.nodes[expr] = node
        return self.nodes[expr]

    def construct_entries(self, lengths: List[sympy.Expr]):
        # Build entries right-to-left so the innermost length gets divisor 1.
        divisor = sympy.Integer(1)
        itervars = []
        for length in reversed(lengths):
            itervars.append(self.lookup(divisor, length))
            divisor = divisor * length
        return list(reversed(itervars))

    def construct(self, lengths: List[sympy.Expr]):
        # Same as construct_entries(), but returning the sympy symbols.
        return [e.symbol() for e in self.construct_entries(lengths)]

    def vars_and_sizes(self, index: sympy.Expr):
        """Figure out vars from this tree used in index"""
        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
        nodes = [n for n in nodes if n and n.prefix == self.prefix]
        # order splits from innermost (smallest divisor) outwards
        nodes.sort(key=lambda x: V.graph.sizevars.size_hint(x.divisor))
        divisor = sympy.Integer(1)
        index_vars = []
        sizes = []

        def add(node):
            nonlocal divisor
            index_vars.append(node.symbol())
            sizes.append(node.length)
            divisor = divisor * node.length

        for node in nodes:
            if not V.graph.sizevars.maybe_guard_equals(node.divisor, divisor):
                # fill in unused index var
                add(self.lookup(divisor, ir.FloorDiv(node.divisor, divisor)))
                divisor = node.divisor
            add(node)
        if not V.graph.sizevars.maybe_guard_equals(self.numel, divisor):
            # fill in unused index var
            add(self.lookup(divisor, ir.FloorDiv(self.numel, divisor)))
        # return in outermost-first order
        return list(reversed(index_vars)), list(reversed(sizes))

    def ranges_code(self):
        # `size` is a shaping suffix produced by the kernel (e.g. for
        # broadcasting across tiled dims) — contents not visible here
        size = self.kernel.indexing_size_str(self.index, self.prefix)
        return f"tl.arange(0, {self.prefix.upper()}BLOCK){size}"

    def pid_cache_lookup(self, key):
        # allow templates to remap program ids (see pid_cache in __init__)
        if key in self.pid_cache:
            return self.pid_cache[key]
        return key

    def codegen_header(self, code):
        # Emit the `{x}index = ...` / `{x}mask = ...` prologue for this dim.
        x = self.prefix
        if self.is_loop():
            code.writeline(f"{self.name} = {x}offset + {x}base")
        elif x == "r" and self.kernel.persistent_reduction:
            # no need to "roffset = "
            code.writeline(
                f"{self.name} = {self.ranges_code()}",
            )
        else:
            pid = self.pid_cache_lookup(f"tl.program_id({self.index})")
            code.writelines(
                [
                    f"{x}offset = {pid} * {x.upper()}BLOCK",
                    f"{self.name} = {x}offset + {self.ranges_code()}",
                ]
            )
        code.writeline(f"{x}mask = {self.name} < {x}numel")
  442. class IterationRangesEntry(IterationRanges):
  443. def __init__(
  444. self,
  445. name: str,
  446. divisor: sympy.Expr,
  447. length: sympy.Expr,
  448. expr: sympy.Expr,
  449. parent: IterationRanges,
  450. ):
  451. super().__init__(
  452. name=name,
  453. numel=parent.numel / length,
  454. var_list=parent.var_list,
  455. var_ranges=parent.var_ranges,
  456. prefix=parent.prefix,
  457. divisor=divisor,
  458. length=length,
  459. kernel=parent.kernel,
  460. )
  461. self.parent = parent
  462. self.codegen = functools.lru_cache(None)(self._codegen)
  463. self.expr = expr
  464. def set_name(self, name):
  465. self.codegen = lambda: name
  466. self.codegen.cache_clear = lambda: None
  467. self.name = name
  468. def cache_clear(self):
  469. self.codegen.cache_clear()
  470. def writeline(self, line):
  471. if self.is_loop():
  472. V.kernel.indexing_code.writeline(line)
  473. else:
  474. # lift non-reduction stores outside loop
  475. V.kernel.body.writeline(line)
  476. def _codegen(self):
  477. self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr)))
  478. return self.name
  479. def precomputed_args(self):
  480. # for dynamic shapes, find parts of indexing expressions that have to be precomputed
  481. precomputed_args = []
  482. if isinstance(self.expr, sympy.Symbol):
  483. return precomputed_args
  484. assert isinstance(self.expr, (ir.FloorDiv, ir.ModularIndexing)), type(self.expr)
  485. for arg in self.expr.args[1:]:
  486. if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
  487. symbols = arg.free_symbols
  488. if len(symbols) > 0 and all(s.name.startswith("s") for s in symbols):
  489. precomputed_args.append(arg)
  490. return precomputed_args
  491. def symbol(self):
  492. return sympy_symbol(self.name)
  493. def __hash__(self):
  494. return hash(self.name)
  495. def __eq__(self, other):
  496. return self.name == other.name
  497. class TritonKernel(Kernel):
  498. overrides = TritonOverrides
  499. sexpr = pexpr
    def __init__(
        self,
        *groups,
        mutations=None,
        pid_cache=None,
        reduction_hint=ReductionHint.DEFAULT,
    ):
        # `groups` are the iteration-space sizes, one per tiled dimension,
        # with the reduction numel last.
        if pid_cache is None:
            pid_cache = {}
        super().__init__()
        self.numels = [V.graph.sizevars.simplify(s) for s in groups]
        self.mutations = mutations
        self.range_trees = []
        self.range_tree_nodes = {}
        self.iter_vars_count = itertools.count()
        # a trailing numel of 1 means there is no reduction dimension
        self.inside_reduction = self.numels[-1] != 1
        self._load_mask = None
        self.body = IndentedBuffer()
        self.indexing_code = IndentedBuffer()
        self.suffix = IndentedBuffer()
        self.outside_loop_vars = set()
        self.reduction_hint = reduction_hint
        # NOTE: must run after inside_reduction/reduction_hint/numels are set,
        # and before initialize_range_tree (which consults is_loop()).
        self.persistent_reduction = self.should_use_persistent_reduction()
        self.initialize_range_tree(pid_cache)

        # define this in a closure to make cache local to object
        # (the closure holds `self`, so the cache lives as long as the kernel)
        @functools.lru_cache(None)
        def simplify_indexing(index: sympy.Expr):
            index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
            for tree in self.range_trees:
                index = self.combine_contiguous_dims(index, tree)
            return index

        self.simplify_indexing = simplify_indexing
  532. def should_use_persistent_reduction(self):
  533. """
  534. Heuristic to set self.persistent_reduction and add guards
  535. if needed.
  536. """
  537. if not (self.inside_reduction and config.triton.persistent_reductions):
  538. return False
  539. threshold = {
  540. ReductionHint.INNER: 1024,
  541. }.get(self.reduction_hint, 64)
  542. hint = V.graph.sizevars.size_hint(self.numels[-1])
  543. if hint > threshold:
  544. return False
  545. from triton import next_power_of_2
  546. # will need to recompile if we cross a larger power of 2 boundary
  547. V.graph.sizevars.guard_leq(self.numels[-1], next_power_of_2(hint))
  548. return True
  549. def initialize_range_tree(self, pid_cache):
  550. names = ["xindex", "yindex", "zindex"][: len(self.numels) - 1] + ["rindex"]
  551. for i in range(len(self.numels)):
  552. self.range_trees.append(
  553. IterationRangesRoot(
  554. names[i], self.numels[i], names[i][0], i, self, pid_cache
  555. )
  556. )
  557. for tree in self.range_trees:
  558. # reduction indexing goes inside a loop
  559. if not tree.is_loop():
  560. tree.codegen_header(self.body)
  561. if self.inside_reduction and self.range_trees[-1].is_loop():
  562. # workaround for this issue:
  563. # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
  564. self.body.writeline(f"rbase = {self.range_trees[-1].ranges_code()}")
  565. def disable_reduction(self):
  566. @contextlib.contextmanager
  567. def ctx():
  568. if self.numels[-1] == 1:
  569. assert not self.inside_reduction
  570. yield
  571. return
  572. if not self.persistent_reduction:
  573. # calling codegen_body() will flush all the pending buffers
  574. # and write out a reduction loop
  575. self.codegen_body()
  576. self.inside_reduction = False
  577. yield
  578. if not self.persistent_reduction:
  579. # flush out any code before opening the next loop
  580. self.codegen_body()
  581. self.inside_reduction = True
  582. return ctx()
  583. def set_ranges(self, *lengths):
  584. assert len(lengths) == len(self.range_trees)
  585. return [
  586. ranges.construct(length)
  587. for length, ranges in zip(lengths, self.range_trees)
  588. ]
    @staticmethod
    def _split_iteration_ranges(
        groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]
    ):
        # Resplit `lengths` so their products line up with `groups`.
        # Returns (new_ranges, return_getters_groups): the per-group range
        # lists, and per-input-length getters that recover each original
        # index from the flat list of new iteration variables.
        # Raises CantSplit when no divisibility-compatible split exists.
        sv = V.graph.sizevars
        new_ranges = [[] for _ in groups]
        remaining = [sv.simplify(g) for g in groups]
        # counts the flat position of each emitted range variable
        var_count = itertools.count()

        def add_range(i, expr):
            # consume `expr` out of group i; requires divisibility
            expr = sv.simplify(expr)
            if not sv.maybe_guard_multiple_of(remaining[i], expr):
                raise CantSplit()
            # guard on the last item out
            sv.maybe_guard_equals(remaining[i], expr)
            remaining[i] = ir.FloorDiv(remaining[i], expr)
            new_ranges[i].append(expr)
            return next(var_count)

        def make_combined(size, idx1, idx2):
            # recombine a length that was split across two ranges:
            # original_index = size * outer + inner
            def getter(flat_vars):
                return size * flat_vars[idx1] + flat_vars[idx2]

            return getter

        return_getters_groups = []
        current_group = 0
        for length_group in lengths:
            return_getters = []
            for size in length_group:
                if sv.maybe_guard_equals(size, 1):
                    # size-1 dims contribute a constant zero index
                    return_getters.append(lambda _: sympy.Integer(0))
                    continue
                while (
                    current_group < len(remaining)
                    and sv.size_hint(remaining[current_group]) == 1
                ):
                    # scroll to next group with remaining elements
                    current_group += 1
                if sv.size_hint(size) > sv.size_hint(remaining[current_group]):
                    # need to break size in two
                    if not sv.maybe_guard_multiple_of(size, remaining[current_group]):
                        raise CantSplit()
                    size1 = remaining[current_group]
                    size2 = ir.FloorDiv(size, remaining[current_group])
                    return_getters.append(
                        make_combined(
                            size2,
                            add_range(current_group, size1),
                            add_range(current_group + 1, size2),
                        )
                    )
                else:
                    return_getters.append(
                        operator.itemgetter(add_range(current_group, size))
                    )
            return_getters_groups.append(return_getters)
        # every group must be fully consumed for the split to be valid
        assert all(
            V.graph.sizevars.size_hint(s) == 1 for s in remaining
        ), f"failed to set ranges {remaining} {lengths}"
        return new_ranges, return_getters_groups
  646. @classmethod
  647. def is_compatible(cls, groups: List[sympy.Expr], lengths: List[List[sympy.Expr]]):
  648. try:
  649. cls._split_iteration_ranges(groups, lengths)
  650. return True
  651. except CantSplit:
  652. return False
  653. def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
  654. """
  655. We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).
  656. To do this we need to split up the iteration space of i0 into something like:
  657. for i1 in s0:
  658. for i2 in s1:
  659. i0 = i1*s1 + i2
  660. ....
  661. This function matches and resplits lengths to the groups of
  662. this kernel to enable tiled + non-tiled fusions.
  663. """
  664. groups = [rt.numel for rt in self.range_trees]
  665. if not self.inside_reduction:
  666. groups[-1] = sympy.Integer(1)
  667. if len(lengths) == len(self.range_trees) and all(
  668. V.graph.sizevars.simplify(sympy_product(x) - g) == 0
  669. for x, g in zip(lengths, groups)
  670. ):
  671. return self.set_ranges(*lengths)
  672. new_ranges, return_getters_groups = self._split_iteration_ranges(
  673. groups, lengths
  674. )
  675. itervars = list(itertools.chain(*self.set_ranges(*new_ranges)))
  676. return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
  677. def is_indirect_indexing(self, index: sympy.Expr):
  678. # tmpX means indirect indexing
  679. return free_symbol_startswith(index, "tmp")
    def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
        """
        More aggressive simplification to merge contiguous dims
        """
        # trivial expressions have nothing to merge
        if isinstance(index, (sympy.Integer, sympy.Symbol)):
            return index
        index_vars, sizes = tree.vars_and_sizes(index)
        if len(sizes) <= 1:
            return index
        # ask the sizevar engine for a smaller, order-preserving loop nest
        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
        )
        if new_sizes == sizes:
            # nothing merged
            return index
        # rebuild index vars for the merged sizes and rewrite `index` in terms
        # of them
        new_index_vars = tree.construct(new_sizes)
        new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
        return new_index
    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
        override_mask=None,
    ):
        """
        Compute the index and mask to pass to tl.load() or tl.store().

        Returns (index_str, mask_vars, mask_str) where index_str is triton
        source for the offsets, mask_vars is the set of mask variable names
        used, and mask_str is their "&"-joined form (or "None").
        """
        index = self.simplify_indexing(index)
        index_vars = index.free_symbols
        index_str = texpr(self.rename_indexing(self.codegen_indexing(index)))
        # collect the masks implied by each variable appearing in the index
        mask_vars: Set[str] = set()
        for var in index_vars:
            if override_mask:
                # caller supplies the mask; per-var masks are irrelevant
                pass
            elif var.name.startswith("tmp"):
                # indirect indexing: inherit the masks the loaded value needs
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)
            elif var.name.startswith("s"):
                # shape symbols are uniform scalars, no mask needed
                pass
            else:
                # var is one of xN, yN or rN
                assert var.name[0] in "xyr", var.name
                mask_vars.add(f"{var.name[0]}mask")
        # dense indexing broadcasts the index to the full block shape
        need_dense = (
            config.triton.dense_indexing
            or dense_indexing
            or self._load_mask is not None
        ) and index != 0
        have_dense = True
        have_loop_vars = False
        dense_mask_vars = set()
        for tree in self.range_trees:
            if tree.prefix == "r" and not self.inside_reduction:
                continue
            if index_vars.intersection(tree.var_list):
                have_loop_vars = True
                have_dense = False
            dense_mask_vars.add(f"{tree.prefix}mask")
        if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
            # broadcast up to the dense block shape via `+ tl.zeros(...)`
            if copy_shape:
                index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
            else:
                index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
            if isinstance(index, sympy.Integer):
                # constant index: no mask needed at all
                return index_str, set(), "None"
            else:
                mask_vars = dense_mask_vars
        elif not have_loop_vars and copy_shape:
            # uniform index but caller wants a specific shape: broadcast to it
            mask_vars = dense_mask_vars
            index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
        if override_mask:
            mask_vars = {override_mask}
        if self._load_mask:
            # an enclosing ops.masked() region adds its own mask
            mask_vars.add(self._load_mask)
        self.filter_masks(mask_vars)
        mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
        return index_str, mask_vars, mask_str
  759. def filter_masks(self, mask_vars):
  760. for tree in self.range_trees:
  761. # Masks are superfluous if we only have one element
  762. if V.graph.sizevars.maybe_guard_equals(tree.numel, 1):
  763. mask_vars.discard(f"{tree.prefix}mask")
  764. def var_ranges(self):
  765. return dict(
  766. itertools.chain.from_iterable(
  767. tree.var_ranges.items() for tree in self.range_trees
  768. )
  769. )
def codegen_indexing(self, expr: sympy.Expr):
    """
    Simplify ``expr`` using the known variable ranges and emit code for every
    range-tree symbol it references.  Returns the simplified expression.
    """
    expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
    # sorted for deterministic codegen order
    for sym in sorted(expr.free_symbols, key=str):
        if sym in self.range_tree_nodes:
            # if indexing expression is complicated, we precompute it on the host side
            # and send the result as a kernel argument
            replacements = {}
            for ps in self.range_tree_nodes[sym].precomputed_args():
                replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
            if len(replacements) > 0:
                self.range_tree_nodes[sym].expr = sympy_subs(
                    self.range_tree_nodes[sym].expr, replacements
                )
            self.range_tree_nodes[sym].codegen()
    return expr
  785. @contextlib.contextmanager
  786. def mask_loads(self, mask):
  787. """Context manager to add an additional mask to tl.load/store"""
  788. prior = self._load_mask
  789. if prior:
  790. mask = self.cse.generate(self.compute, f"{mask} & {prior}")
  791. self._load_mask = mask
  792. with self.swap_buffers(self.compute, self.compute):
  793. # TODO(jansel): do we need a reshape here?
  794. yield mask
  795. self._load_mask = prior
def load(self, name: str, index: sympy.Expr):
    """
    Emit a ``tl.load`` of buffer ``name`` at ``index`` and return the CSE
    variable holding the loaded value.
    """
    var = self.args.input(name)
    indirect_indexing = self.is_indirect_indexing(index)
    original_index = index
    index, mask_vars, mask = self.indexing(index)
    if "rmask" in mask and not self.persistent_reduction:
        # This eviction policy heuristic is untested.
        # ptillet suggested we should try only doing this for
        # the first N-1 loops and not for the final loop.
        ep = ", eviction_policy='evict_last'"
    else:
        ep = ""
    # "other" below is a workaround for https://github.com/openai/triton/issues/737
    # for bool, even though it's likely subject to the same bug, setting `other` leads
    # to LLVM errors so we are skipping it for now
    if ("tmp" in mask or "rmask" in mask) and V.graph.get_dtype(name) != torch.bool:
        other = ", other=0"
    else:
        other = ""
    append_broadcast = None
    if V.graph.is_unspec_arg(name):
        # unspec args are passed as scalars; just reference the argument
        line = var
    else:
        if isinstance(original_index, sympy.Integer):
            # constant index: load the scalar and broadcast it afterwards
            dense_size = self.dense_size_str()
            line = f"tl.load({var} + ({original_index}))"
            append_broadcast = dense_size
        else:
            line = f"tl.load({var} + ({index}), {mask}{ep}{other})"
        if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16):
            # compute in fp32 for low-precision inputs
            line += ".to(tl.float32)"
    if (
        self.inside_reduction
        and not self.persistent_reduction
        and "rmask" not in mask
        and "tmp" not in mask
        and not indirect_indexing
    ):
        # can lift a common load outside of reduction loop
        # One exception is when this is an indirect_load.
        result_var = self.cse.generate(
            self.body, line, append_broadcast=append_broadcast
        )
    else:
        result_var = self.cse.generate(
            self.loads, line, append_broadcast=append_broadcast
        )
    result_var.mask_vars = mask_vars
    if not self.inside_reduction or "rmask" not in mask:
        # value remains valid outside the reduction loop
        self.outside_loop_vars.add(result_var)
    return result_var
  847. def store(self, name, index, value, mode=None):
  848. var = self.args.output(name)
  849. index, mask_vars, mask = self.indexing(index, dense_indexing=True)
  850. if mode is None:
  851. line = f"tl.store({var} + ({index}), {value}, {mask})"
  852. elif mode == "atomic_add":
  853. line = f"tl.atomic_add({var} + ({index}), {value}, {mask})"
  854. else:
  855. raise NotImplementedError(f"store mode={mode}")
  856. self.stores.writeline(name, line)
  857. if not self.inside_reduction:
  858. self.outside_loop_vars.add(value)
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
    """
    Emit code reducing ``value`` with ``reduction_type`` and storing the
    result into buffer ``name`` at ``index``.

    Persistent reductions reduce in one shot; looped reductions build an
    accumulator updated each iteration and reduced in the kernel suffix.
    Identical (src_dtype, reduction_type, value) reductions are deduplicated
    through ``self.cse.reduction_cache``.
    """
    assert self.inside_reduction
    default = triton_constant(ir.Reduction.default_value(reduction_type, src_dtype))
    masks = {f"{tree.prefix}mask" for tree in self.range_trees}
    self.filter_masks(masks)
    masks = sorted(masks)
    if self._load_mask:
        masks.append(self._load_mask)
    # subscript that drops the reduction (last) dim from a reduced value
    sizes = [":" for _ in self.range_trees]
    sizes[-1] = "None"
    reduction_range_prefix = self.range_trees[-1].prefix
    # subscript that keeps only the reduction dim
    reduction_sizes = ["None" for _ in self.range_trees]
    reduction_sizes[-1] = ":"
    if reduction_type == "any":
        # "any" is implemented as max over the masked values
        reduction_type = "max"
    dim = len(self.range_trees) - 1
    result_var = self.cse.newvar()
    result_var.mask_vars = {var for var in masks if var[0] != "r"}
    if self.persistent_reduction:
        # whole reduction fits in one block: mask then reduce directly
        cond = " & ".join(masks)
        masked_value = self.cse.generate(
            self.compute, f"tl.where({cond}, {value}, {default})"
        )
        result_var = self.cse.generate(
            self.compute,
            f"tl.{reduction_type}({masked_value}, {dim})[{', '.join(sizes)}]",
        )
    elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
        self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
        accumulator = f"_{result_var}"
        default_value = f" + {default}" if default != 0 else ""
        self.body.writeline(
            f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
        )
        accumulator_index = None
        if reduction_type in {"argmax", "argmin"}:
            # second accumulator tracks the winning index
            accumulator_index = f"_{result_var}_index"
            self.body.writeline(
                f"{accumulator_index} = tl.zeros({self.dense_size_str()}, tl.int64)"
            )
        updated = value
        if reduction_type in {"min", "argmin"}:
            masks.append(f"({accumulator} > {value})")
        elif reduction_type in {"max", "argmax"}:
            masks.append(f"({accumulator} < {value})")
        elif reduction_type == "sum":
            updated = f"{accumulator} + {value}"
        else:
            raise NotImplementedError(f"reduction_type {reduction_type}")
        cond = " & ".join(masks)
        if accumulator_index:
            # argmax or argmin
            self.compute.writeline(
                f"{accumulator_index} = tl.where({cond}, {reduction_range_prefix}index, {accumulator_index})",
            )
        self.compute.writeline(
            f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
        )
        if accumulator_index:
            # argmax, argmin: recover the index of the reduced value
            self.suffix.writelines(
                [
                    f"{accumulator_index}_reduce = "
                    f"tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}].to(tl.int32)",
                    f"{accumulator_index}_mask = tl.arange(0, {reduction_range_prefix.upper()}BLOCK)"
                    f"[{', '.join(reduction_sizes)}] == {accumulator_index}_reduce",
                    f"{result_var} = tl.sum("
                    f"tl.where({accumulator_index}_mask, {accumulator_index}, 0), {dim})[{', '.join(sizes)}]",
                ]
            )
        else:
            self.suffix.writeline(
                f"{result_var} = tl.{reduction_type}({accumulator}, {dim})[{', '.join(sizes)}]"
            )
    else:
        # duplicate reduction: reuse the cached accumulator's result
        var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
        self.suffix.writeline(f"{result_var} = {var_name}")
        result_var.mask_vars = var_name.mask_vars
    # the store of the result happens outside the reduction loop, so index
    # it as if we were not inside the reduction
    self.inside_reduction = False
    index, mask_vars, mask = self.indexing(index)
    assert "rmask" not in index
    self.inside_reduction = True
    self.outside_loop_vars.add(result_var)
    self.cse.store_cache[name] = result_var
    if name not in V.graph.removed_buffers:
        var = self.args.output(name)
        self.suffix.writeline(
            DeferredLine(name, f"tl.store({var} + {index}, {result_var}, {mask})")
        )
def codegen_body(self):
    """
    Concat output code from index_code, loads, compute, stores,
    suffix into self.body.

    For pointwise kernels, this is called just once at the end.
    For reduction kernels, this generates a loop over the reduction
    axis.
    """
    if not (
        self.indexing_code
        or self.loads
        or self.stores
        or self.compute
        or self.suffix
    ):
        # nothing pending to flush
        return
    if self.inside_reduction and not self.persistent_reduction:
        self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
        with self.body.indent():
            # last range tree is always reduction
            self.range_trees[-1].codegen_header(self.body)
            self.body.splice(self.indexing_code)
            self.body.splice(self.loads)
            self.body.splice(self.compute)
            self.body.splice(self.stores)
        # invalidate any caches that came from inside the reduction loop
        self.cse.invalidate(self.outside_loop_vars)
        self.range_trees[-1].cache_clear()
    else:
        self.body.splice(self.indexing_code)
        self.body.splice(self.loads)
        self.body.splice(self.compute)
        self.body.splice(self.stores)
    # suffix (e.g. reduction results) always lands after the loop
    self.body.splice(self.suffix)
    self.indexing_code.clear()
    self.loads.clear()
    self.compute.clear()
    self.stores.clear()
    self.suffix.clear()
def codegen_kernel(self, name=None):
    """
    Render the full Triton kernel source for this kernel.

    When ``name`` is None a standalone module (with imports and an
    ``async_compile.triton`` wrapper) is returned and the function is named
    ``KERNEL_NAME`` for later substitution; otherwise only the kernel
    definition named ``name`` is returned.
    """
    from triton import next_power_of_2

    code = IndentedBuffer()
    size_hints = [
        next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels
    ]
    if self.persistent_reduction:
        assert self.inside_reduction
        heuristics = "persistent_reduction"
    elif self.inside_reduction:
        heuristics = "reduction"
    else:
        # pointwise kernels drop the trailing (reduction) size hint
        size_hints.pop()
        heuristics = "pointwise"
    if name is None:
        code.splice(
            f"""
                import triton
                import triton.language as tl
                from torch._inductor.ir import ReductionHint
                from torch._inductor.ir import TileHint
                from torch._inductor.triton_ops.autotune import {heuristics}
                from torch._inductor.utils import instance_descriptor
            """
        )

    argdefs, _, signature = self.args.python_argdefs()
    # maps actual expression to SizeArg if its in sizevars replacements
    for i, arg in enumerate(signature):
        if (
            isinstance(arg, SizeArg)
            and arg.expr in V.graph.sizevars.inv_precomputed_replacements
        ):
            signature[i] = SizeArg(
                arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr]
            )

    # collect kernel-argument names that are mutated, for the autotuner meta
    mutated_args = set()
    for mutation in self.mutations:
        if mutation in self.args.input_buffers:
            mutated_args.add(self.args.input_buffers[mutation])
        if mutation in self.args.inplace_buffers:
            mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
        if mutation in self.args.output_buffers:
            mutated_args.add(self.args.output_buffers[mutation])
    mutated_args = sorted(mutated_args)

    triton_meta = {
        "signature": dict(enumerate(map(signature_of, signature))),
        "device": V.graph.scheduler.current_device.index,
        "constants": {},
        "mutated_arg_names": mutated_args,
    }

    # append one numel argument per active range tree
    for tree in self.range_trees:
        if tree.prefix != "r" or self.inside_reduction:
            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
            signature.append(sizearg)
            triton_meta["signature"][len(argdefs)] = signature_of(sizearg)
            argdefs.append(f"{tree.prefix}numel")
            # constexpr version causes issues, see
            # https://github.com/pytorch/torchdynamo/pull/1362
            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
            #     tree.numel
            # )
            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
    triton_meta["configs"] = [config_of(signature)]

    for tree in self.range_trees:
        if tree.prefix != "r" or self.inside_reduction:
            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")

    if self.inside_reduction:
        reduction_hint = self.reduction_hint
        heuristics_line = f"""
            @{heuristics}(
                size_hints={size_hints!r},
                reduction_hint={reduction_hint},
                filename=__file__,
                meta={triton_meta!r}
            )
            @triton.jit
        """
    else:
        tile_hint = ""
        if len(size_hints) == 2:
            if len(signature) == 4:  # input, output and 2 args
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
        heuristics_line = f"""
            @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r})
            @triton.jit
        """
    code.splice(heuristics_line)
    code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):")
    self.codegen_body()
    with code.indent():
        if not dynamo_config.dynamic_shapes:
            self.codegen_static_numels(code)
        for old, new in self.args.aliases():
            code.writeline(f"{old} = {new}")
        code.splice(self.body)

    if name is not None:
        return code.getvalue()

    wrapper = IndentedBuffer()
    wrapper.writeline("async_compile.triton('''")
    wrapper.splice(code.getvalue(), strip=True)
    wrapper.writeline("''')")
    return wrapper.getvalue()
  1091. def codegen_template_wrapper(self, src_code):
  1092. wrapper = IndentedBuffer()
  1093. wrapper.writeline("async_compile.triton('''")
  1094. wrapper.splice(src_code, strip=True)
  1095. wrapper.writeline("''')")
  1096. return wrapper.getvalue()
  1097. def codegen_static_numels(self, code):
  1098. """
  1099. We get a small speedup from hard coding numels if they are static.
  1100. """
  1101. for tree in self.range_trees:
  1102. if tree.prefix != "r" or self.inside_reduction:
  1103. if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer):
  1104. code.writeline(
  1105. f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}"
  1106. )
  1107. elif not dynamo_config.dynamic_shapes:
  1108. code.writeline(
  1109. f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)} # dynamic_shapes=False"
  1110. )
  1111. def indexing_size_str(self, i=None, x=None):
  1112. sizes = ["None"] * (len(self.range_trees) - int(self.numels[-1] == 1))
  1113. if i is not None:
  1114. sizes[i] = ":"
  1115. return f"[{', '.join(sizes)}]"
  1116. def dense_size_str(self):
  1117. sizes = []
  1118. for tree in self.range_trees:
  1119. if tree.prefix != "r" or self.inside_reduction:
  1120. sizes.append(f"{tree.prefix.upper()}BLOCK")
  1121. elif tree.prefix == "r" and tree.numel != 1:
  1122. sizes.append("1")
  1123. return f"[{', '.join(sizes)}]"
  1124. def call_kernel(self, code, name: str):
  1125. _, call_args, _ = self.args.python_argdefs()
  1126. # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
  1127. for i in range(len(call_args)):
  1128. if V.graph.is_unspec_arg(call_args[i]):
  1129. call_args[i] = call_args[i] + ".item()"
  1130. grid = []
  1131. # TODO(jansel): if there are constants, we shouldn't bother passing them as args
  1132. for tree in self.range_trees:
  1133. if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
  1134. expr = pexpr(tree.numel)
  1135. else:
  1136. expr = f"{name}_{tree.prefix}numel"
  1137. code.writeline(f"{expr} = {pexpr(tree.numel)}")
  1138. if tree.prefix != "r" or self.inside_reduction:
  1139. call_args.append(expr)
  1140. if tree.prefix != "r":
  1141. grid.append(expr)
  1142. call_args = ", ".join(call_args)
  1143. stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
  1144. code.writeline(
  1145. f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})"
  1146. )
def create_cse_var(self, *args, **kwargs):
    """Construct the Triton-specific CSE variable type for this kernel."""
    return TritonCSEVariable(*args, **kwargs)
  1149. class TritonScheduling:
  1150. def __init__(self, scheduler):
  1151. self.scheduler = scheduler
  1152. def group_fn(self, sizes):
  1153. return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
  1154. def can_fuse(self, node1, node2):
  1155. """
  1156. Hook called by Scheduler to determine if the Triton backend
  1157. can fuse node1 and node2. These nodes might already be
  1158. FusedSchedulerNodes.
  1159. """
  1160. _, (numel1, rnumel1) = node1.group
  1161. _, (numel2, rnumel2) = node2.group
  1162. if node1.is_reduction() and node2.is_reduction():
  1163. return numel1 == numel2 and rnumel1 == rnumel2
  1164. if not node1.is_reduction() and not node2.is_reduction():
  1165. if not (numel1 == numel2 and rnumel1 == rnumel2):
  1166. return False
  1167. if node1.is_template():
  1168. return True # skip checks for compatible tiling
  1169. # check for a bad combined tiling
  1170. tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
  1171. tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
  1172. tiling3 = self.select_tiling(
  1173. node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
  1174. )
  1175. if config.triton.tiling_prevents_pointwise_fusion:
  1176. if len(tiling1) > 2:
  1177. if len(tiling2) > 2:
  1178. return tiling1 == tiling2 == tiling3
  1179. else:
  1180. return tiling1 == tiling3
  1181. elif len(tiling2) > 2:
  1182. return tiling2 == tiling3
  1183. return True
  1184. if not node1.is_reduction() and node2.is_reduction():
  1185. assert rnumel1 == 1 and rnumel2 != 1
  1186. if numel1 == numel2 * rnumel2:
  1187. if not all(
  1188. TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
  1189. for n in node1.get_nodes()
  1190. ):
  1191. return False
  1192. if (
  1193. config.triton.tiling_prevents_reduction_fusion
  1194. and not node1.is_template()
  1195. ):
  1196. return self.select_tiling(node1.get_nodes(), numel1) in (
  1197. (numel1, 1),
  1198. (numel2, rnumel2, 1),
  1199. )
  1200. return True
  1201. return numel1 == numel2
  1202. assert node1.is_reduction() and not node2.is_reduction()
  1203. # swap args to hit the case above
  1204. return self.can_fuse_horizontal(node2, node1)
  1205. can_fuse_vertical = can_fuse
  1206. can_fuse_horizontal = can_fuse
  1207. def codegen_nodes(self, nodes):
  1208. """
  1209. Given a set of pre-fused nodes, generate a Triton kernel.
  1210. """
  1211. _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
  1212. node_schedule = []
  1213. current_loop_writes = set()
  1214. is_current_reductions = set()
  1215. done = set()
  1216. def fits_in_main_body(n):
  1217. _, (node_numel, node_rnumel) = n.group
  1218. return (node_numel == numel and node_rnumel == rnumel) or (
  1219. node_numel == numel * rnumel and node_rnumel == 1
  1220. )
  1221. def fits_outside_reduction(n):
  1222. _, (node_numel, node_rnumel) = n.group
  1223. return node_numel == numel and node_rnumel == 1 and rnumel != 1
  1224. @contextlib.contextmanager
  1225. def end_current_reduction_loop():
  1226. if current_loop_writes:
  1227. # flush out any other runnable nodes to reduce number of loops
  1228. for other_node in nodes[index + 1 :]:
  1229. if (
  1230. node not in done
  1231. and fits_in_main_body(other_node)
  1232. and not (
  1233. current_loop_writes & other_node.recursive_predecessors
  1234. )
  1235. ):
  1236. done.add(node)
  1237. current_loop_writes.add(node.get_name())
  1238. is_current_reductions.add(node.is_reduction())
  1239. node_schedule.append(node)
  1240. if node_schedule and node_schedule[-1] is EnableReduction:
  1241. node_schedule.pop()
  1242. else:
  1243. node_schedule.append(DisableReduction)
  1244. yield
  1245. node_schedule.append(EnableReduction)
  1246. current_loop_writes.clear()
  1247. is_current_reductions.clear()
  1248. for index, node in enumerate(nodes):
  1249. if node in done:
  1250. continue
  1251. done.add(node)
  1252. def requires_closing_previous_reduction(node, node_schedule):
  1253. if rnumel == 1:
  1254. return False
  1255. if not current_loop_writes & node.recursive_predecessors:
  1256. return False
  1257. assert node_schedule and not isinstance(
  1258. node_schedule[-1], (EnableReduction, DisableReduction)
  1259. )
  1260. return True in is_current_reductions
  1261. if fits_in_main_body(node):
  1262. if requires_closing_previous_reduction(node, node_schedule):
  1263. with end_current_reduction_loop():
  1264. pass # need to start a new reduction loop
  1265. current_loop_writes.add(node.get_name())
  1266. is_current_reductions.add(node.is_reduction())
  1267. node_schedule.append(node)
  1268. elif fits_outside_reduction(node):
  1269. with end_current_reduction_loop():
  1270. node_schedule.append(node)
  1271. else:
  1272. raise NotImplementedError(
  1273. f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
  1274. )
  1275. if dynamo_config.output_code:
  1276. log.info("schedule: %s", node_schedule)
  1277. return self.codegen_node_schedule(node_schedule, numel, rnumel)
  1278. @staticmethod
  1279. def reduction_hint(node):
  1280. assert node.is_reduction()
  1281. if all(
  1282. dep.is_contiguous()
  1283. for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
  1284. ):
  1285. return ReductionHint.INNER
  1286. else:
  1287. return node.node.data.reduction_hint
  1288. def codegen_node_schedule(self, node_schedule, numel, reduction_numel):
  1289. tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
  1290. reductions = list(
  1291. filter(
  1292. lambda n: n not in (EnableReduction, DisableReduction)
  1293. and n.is_reduction(),
  1294. node_schedule,
  1295. )
  1296. )
  1297. if len(reductions) > 0:
  1298. hints = [self.reduction_hint(n) for n in reductions]
  1299. if hints.count(hints[0]) == len(hints):
  1300. reduction_hint_val = hints[0]
  1301. else:
  1302. reduction_hint_val = ReductionHint.DEFAULT
  1303. else:
  1304. reduction_hint_val = ReductionHint.DEFAULT
  1305. mutations = set()
  1306. for node in node_schedule:
  1307. if hasattr(node, "get_mutations"):
  1308. mutations.update(node.get_mutations())
  1309. with TritonKernel(
  1310. *tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations
  1311. ) as kernel:
  1312. stack = contextlib.ExitStack()
  1313. for node in node_schedule:
  1314. if node not in (EnableReduction, DisableReduction):
  1315. node.mark_run()
  1316. for node in node_schedule:
  1317. if node is DisableReduction:
  1318. stack.enter_context(kernel.disable_reduction())
  1319. elif node is EnableReduction:
  1320. stack.close()
  1321. else:
  1322. # TODO - mostly works but needs a couple fixes
  1323. if not dynamo_config.dynamic_shapes:
  1324. # TODO - use split ranges ?
  1325. indexing_dtype_strength_reduction(node._body)
  1326. index_vars = kernel.split_and_set_ranges(node.get_ranges())
  1327. node.codegen(index_vars)
  1328. src_code = kernel.codegen_kernel()
  1329. kernel_name = self.define_kernel(src_code, node_schedule)
  1330. kernel.call_kernel(V.graph.wrapper_code, kernel_name)
  1331. self.scheduler.free_buffers()
  1332. def define_kernel(self, src_code, node_schedule):
  1333. wrapper = V.graph.wrapper_code
  1334. if src_code in wrapper.kernels:
  1335. kernel_name = wrapper.kernels[src_code]
  1336. else:
  1337. fused_name = (
  1338. get_fused_kernel_name(node_schedule)
  1339. if config.triton.descriptive_kernel_names
  1340. else ""
  1341. )
  1342. kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()])
  1343. wrapper.kernels[src_code] = kernel_name
  1344. subs_name = kernel_name if config.triton.ordered_kernel_names else "triton_"
  1345. src_code = src_code.replace("KERNEL_NAME", subs_name)
  1346. # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
  1347. # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
  1348. src_code = src_code.replace("#pragma CMT", "#")
  1349. wrapper.define_kernel(kernel_name, src_code)
  1350. return kernel_name
  1351. def codegen_template(self, template_node, epilogue_nodes):
  1352. """
  1353. Codegen a triton template
  1354. """
  1355. _, (numel, rnumel) = template_node.group
  1356. assert rnumel == 1
  1357. kernel, render = template_node.node.make_kernel_render(template_node.node)
  1358. with kernel:
  1359. for node in [template_node, *epilogue_nodes]:
  1360. node.mark_run()
  1361. render() # warmup run to get the args right
  1362. for node in epilogue_nodes:
  1363. node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
  1364. src_code = kernel.codegen_template_wrapper(render())
  1365. kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
  1366. kernel.call_kernel(V.graph.wrapper_code, kernel_name)
  1367. self.scheduler.free_buffers()
  1368. def codegen_sync(self):
  1369. V.graph.wrapper_code.writeline("torch.cuda.synchronize()")
  1370. @staticmethod
  1371. @functools.lru_cache(32)
  1372. def candidate_tilings(node):
  1373. ranges, reduction_ranges = node.get_ranges()
  1374. if len(ranges) <= 1:
  1375. return ()
  1376. rw = node.pointwise_read_writes()
  1377. assert len(rw.range_vars) == len(ranges)
  1378. deps = [
  1379. dep
  1380. for dep in itertools.chain(rw.reads, rw.writes)
  1381. if dep.name not in V.graph.removed_buffers
  1382. ]
  1383. write_names = {dep.name for dep in rw.writes}
  1384. tilings = []
  1385. for dep in deps:
  1386. strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
  1387. assert len(strides) == len(ranges)
  1388. try:
  1389. split = strides.index(1) + 1
  1390. if split == len(ranges):
  1391. continue
  1392. if all(s == 0 for s in strides[split:]):
  1393. # if this is a broadcasted tensor and all dimensions after split are broadcast,
  1394. # this is not a real split
  1395. continue
  1396. except ValueError:
  1397. continue
  1398. tiled_groups = (
  1399. V.graph.sizevars.simplify(sympy_product(ranges[:split])),
  1400. V.graph.sizevars.simplify(sympy_product(ranges[split:])),
  1401. )
  1402. # score by number of elements
  1403. score = V.graph.sizevars.size_hint(
  1404. sympy_product(
  1405. size for size, stride in zip(ranges, strides) if stride != 0
  1406. )
  1407. )
  1408. if dep.name in write_names:
  1409. # ngimel said contiguous writes is more important than reads
  1410. score *= 2
  1411. if CandidateTiling.is_good_size(tiled_groups[0]):
  1412. score *= 2
  1413. if CandidateTiling.is_good_size(tiled_groups[1]):
  1414. score *= 2
  1415. if (
  1416. V.graph.sizevars.size_hint(
  1417. score - sympy_product(itertools.chain(ranges, reduction_ranges))
  1418. )
  1419. >= 0
  1420. ):
  1421. tilings.append(CandidateTiling(tiled_groups, score, dep.name))
  1422. return tilings
  1423. @classmethod
  1424. def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
  1425. """
  1426. Heuristics to decide how to tile kernels.
  1427. Currently, we tile based on stride-1 dimensions.
  1428. Returns:
  1429. `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
  1430. """
  1431. if reduction_numel != 1 or config.triton.max_tiles <= 1:
  1432. # TODO(jansel): should we tile reductions?
  1433. return (numel, reduction_numel)
  1434. seen_names = set()
  1435. candidate_tiles = collections.Counter()
  1436. for node in EnableReduction.filter(node_schedule):
  1437. for tiling in cls.candidate_tilings(node):
  1438. if tiling.name in seen_names:
  1439. continue
  1440. seen_names.add(tiling.name)
  1441. candidate_tiles[tiling.tiling] += tiling.score
  1442. ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]
  1443. if config.triton.max_tiles >= 3:
  1444. # Add one 3D tiling choice
  1445. for i in range(1, len(ranked_tilings)):
  1446. a0, a1 = ranked_tilings[0]
  1447. b0, b1 = ranked_tilings[i]
  1448. if V.graph.sizevars.size_hint(a1 - b1) == 0:
  1449. continue
  1450. if V.graph.sizevars.size_hint(a1 - b1) < 0:
  1451. # swap so a0 is bigger
  1452. a0, a1 = ranked_tilings[i]
  1453. b0, b1 = ranked_tilings[0]
  1454. assert V.graph.sizevars.size_hint(a1 - b1) > 0
  1455. if V.graph.sizevars.maybe_guard_multiple_of(a1, b1):
  1456. tiling = (a0, ir.FloorDiv(a1, b1), b1)
  1457. ranked_tilings = [tiling] + ranked_tilings
  1458. break # only 1 choice for now
  1459. for tiled_groups in ranked_tilings:
  1460. new_groups = (*tiled_groups, reduction_numel)
  1461. if all(
  1462. TritonKernel.is_compatible(new_groups, node.get_ranges())
  1463. for node in node_schedule
  1464. if isinstance(node, scheduler.SchedulerNode)
  1465. ):
  1466. return new_groups
  1467. return (numel, reduction_numel)
  1468. def flush(self):
  1469. pass
@dataclasses.dataclass
class CandidateTiling:
    """A scored proposal for splitting an iteration space into two tiles."""

    tiling: List[sympy.Expr]  # the proposed (tile1, tile2) sizes
    score: int  # higher is better
    name: str = None  # buffer the tiling was derived from, if any

    @staticmethod
    def is_good_size(s):
        """Somewhat arbitrary heuristic used to boost scores for some sizes"""
        s = V.graph.sizevars.size_hint(s)
        return s >= 32 and (s % 32 == 0)
class DisableReduction:
    """
    Marker to invoke `kernel.disable_reduction()`. This closes a
    reduction loop and allows for pointwise ops to occur on the output
    of a reduction.

    Used as a sentinel in node schedules and compared by identity
    (`node is DisableReduction`).
    """
  1486. class EnableReduction:
  1487. """
  1488. Marker to end a DisableReduction block.
  1489. """
  1490. @staticmethod
  1491. def filter(node_schedule):
  1492. """
  1493. Get the nodes from node_schedule skipping those in a
  1494. DisableReduction block.
  1495. """
  1496. disabled = False
  1497. for node in node_schedule:
  1498. if node in (EnableReduction, DisableReduction):
  1499. # Don't tile stuff outside the main reduction loop
  1500. disabled = node is DisableReduction
  1501. elif disabled:
  1502. pass
  1503. else:
  1504. yield node
class CantSplit(Exception):
    """Raised when an iteration range cannot be split as requested."""

    pass