common.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661
  1. import contextlib
  2. import itertools
  3. import logging
  4. import re
  5. import typing
  6. from collections import namedtuple
  7. from itertools import chain
  8. import sympy
  9. from sympy.printing.printer import Printer
  10. from .. import metrics
  11. from ..utils import (
  12. DeferredLineBase,
  13. free_symbol_startswith,
  14. IndentedBuffer,
  15. sympy_dot,
  16. sympy_subs,
  17. sympy_symbol,
  18. unique,
  19. )
  20. from ..virtualized import ops, V
  21. log = logging.getLogger(__name__)
  22. TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
  23. SizeArg = namedtuple("SizeArg", ["name", "expr"])
  24. def index_prevent_reordering(index: typing.List[sympy.Expr], index_vars, sizes):
  25. from ..ir import FlexibleLayout
  26. # added contiguous index prevents reordering
  27. return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
  28. class ExprPrinter(Printer):
  29. @staticmethod
  30. def paren(string):
  31. if (
  32. isinstance(string, CSEVariable)
  33. or re.match(r"^[a-z0-9_.]+$", string, re.I)
  34. or re.match(r"^\([^)]*\)$", string, re.I)
  35. or string == ""
  36. ):
  37. return string
  38. return f"({string})"
  39. def _print_Pow(self, expr):
  40. # Pow() confuses triton
  41. base, exp = expr.args
  42. base = self._print(base)
  43. assert exp.is_integer
  44. exp = int(exp)
  45. if exp > 0:
  46. return "*".join([self.paren(base)] * exp)
  47. elif exp < 0:
  48. return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
  49. else: # exp == 0
  50. return "1"
  51. def _print_Mul(self, expr):
  52. return "*".join(map(self.paren, map(self._print, expr.args)))
  53. def _print_Add(self, expr):
  54. return " + ".join(map(self.paren, map(self._print, expr.args)))
  55. def _print_Mod(self, expr):
  56. return " % ".join(map(self.paren, map(self._print, expr.args)))
  57. def _print_CleanDiv(self, expr):
  58. return self._print_FloorDiv(expr)
  59. class PythonPrinter(ExprPrinter):
  60. def _print_ModularIndexing(self, expr):
  61. x, div, mod = expr.args
  62. x = self.paren(self.doprint(x))
  63. div = self.paren(self.doprint(div))
  64. mod = self.paren(self.doprint(mod))
  65. if div != "1":
  66. x = f"({x} // {div})"
  67. return f"{x} % {mod}"
  68. def _print_FloorDiv(self, expr):
  69. x, div = expr.args
  70. x = self.paren(self.doprint(x))
  71. div = self.paren(self.doprint(div))
  72. return f"({x} // {div})"
  73. def _print_floor(self, expr):
  74. assert len(expr.args) == 1
  75. return f"math.floor({self.paren(self._print(expr.args[0]))})"
  76. class OpOverrides:
  77. def __init__(self, parent):
  78. super().__init__()
  79. self._parent = parent
  80. def __getattr__(self, item):
  81. return getattr(self._parent, item)
  82. @staticmethod
  83. def identity(value):
  84. # used to trigger cse
  85. return value
  86. @staticmethod
  87. def constant(value, dtype):
  88. return repr(value)
  89. @staticmethod
  90. def reciprocal(x):
  91. return ops.div("1", x)
  92. @staticmethod
  93. def square(x):
  94. return ops.mul(x, x)
  95. @staticmethod
  96. def sign(x):
  97. left = ops.where(ops.lt("0", x), "1", "0")
  98. right = ops.where(ops.lt(x, "0"), "1", "0")
  99. return ops.sub(left, right)
  100. @staticmethod
  101. def bitwise_not(x):
  102. return f"~{ExprPrinter.paren(x)}"
  103. @staticmethod
  104. def logical_not(a):
  105. return f"{ExprPrinter.paren(a)} == 0"
  106. @staticmethod
  107. def bitwise_and(x, y):
  108. return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
  109. @staticmethod
  110. def bitwise_or(x, y):
  111. return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
  112. @staticmethod
  113. def bitwise_xor(x, y):
  114. return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
  115. @staticmethod
  116. def bitwise_left_shift(x, y):
  117. return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
  118. # TODO(fdrocha): this is currently not being used anywhere,
  119. # pending on moving triton pin past 972b761
  120. @staticmethod
  121. def bitwise_right_shift(x, y):
  122. return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
  123. @staticmethod
  124. def remainder(a, b):
  125. r = ops.mod(a, b)
  126. return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)
  127. class DeferredLine(DeferredLineBase):
  128. """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
  129. def __init__(self, name, line):
  130. super().__init__(line)
  131. self.name = name
  132. def __call__(self):
  133. if (
  134. self.name not in V.graph.removed_buffers
  135. and self.name not in V.graph.inplaced_to_remove
  136. ):
  137. return self.line
  138. return None
  139. def _new_line(self, line):
  140. return DeferredLine(self.name, line)
  141. class DeferredIndentedBuffer(IndentedBuffer):
  142. def __init__(self, initial_indent=0):
  143. super().__init__(initial_indent)
  144. def writeline(self, name, line):
  145. if name is None:
  146. return super().writeline(line)
  147. assert "buf" in name
  148. return super().writeline(DeferredLine(name, line))
  149. def writelines(self, name, lines):
  150. for line in lines:
  151. self.writeline(name, line)
  152. class BracesBuffer(IndentedBuffer):
  153. def indent(self, offset=1):
  154. @contextlib.contextmanager
  155. def ctx():
  156. for _ in range(offset):
  157. self.writeline("{")
  158. self._indent += 1
  159. for _ in range(-offset):
  160. self._indent -= 1
  161. self.writeline("}")
  162. yield
  163. for _ in range(-offset):
  164. self.writeline("{")
  165. self._indent += 1
  166. for _ in range(offset):
  167. self._indent -= 1
  168. self.writeline("}")
  169. return ctx()
  170. class InplacedBuffer(typing.NamedTuple):
  171. inner_name: str
  172. other_names: typing.List[str]
  173. class KernelArgs:
  174. @staticmethod
  175. def _lookup(prefix, odict, name):
  176. assert isinstance(name, (str, sympy.Symbol))
  177. if name not in odict:
  178. odict[name] = f"{prefix}{len(odict)}"
  179. return odict[name]
  180. def __init__(self, sizevars=None):
  181. self.input_buffers = dict()
  182. self.output_buffers = dict()
  183. self.inplace_buffers = dict()
  184. self.sizevars = sizevars or dict()
  185. def __repr__(self):
  186. return "KernelArgs({})".format(
  187. ", ".join(
  188. map(
  189. repr,
  190. [
  191. self.input_buffers,
  192. self.output_buffers,
  193. self.inplace_buffers,
  194. self.sizevars,
  195. ],
  196. )
  197. )
  198. )
  199. def input(self, name):
  200. if V.graph.scheduler:
  201. name = V.graph.scheduler.mutation_real_name.get(name, name)
  202. assert name not in V.graph.removed_buffers, name
  203. if name in self.output_buffers:
  204. return self.output_buffers[name]
  205. if name in self.inplace_buffers:
  206. return self.inplace_buffers[name].inner_name
  207. if name.startswith("seed"):
  208. return self._lookup("seed", self.input_buffers, name)
  209. return self._lookup("in_ptr", self.input_buffers, name)
  210. def output(self, name):
  211. if V.graph.scheduler:
  212. name = V.graph.scheduler.mutation_real_name.get(name, name)
  213. assert name not in V.graph.removed_buffers, name
  214. if name in self.inplace_buffers:
  215. return self.inplace_buffers[name].inner_name
  216. return self._lookup("out_ptr", self.output_buffers, name)
  217. def make_inplace(self, input_name, output_name):
  218. assert output_name not in self.inplace_buffers
  219. if input_name in self.inplace_buffers:
  220. buf = self.inplace_buffers[input_name]
  221. buf.other_names.append(output_name)
  222. self.inplace_buffers[output_name] = buf
  223. else:
  224. buf = InplacedBuffer(
  225. f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
  226. [input_name, output_name],
  227. )
  228. self.inplace_buffers[input_name] = buf
  229. self.inplace_buffers[output_name] = buf
  230. def size(self, name):
  231. if str(name) == "seed":
  232. self.sizevars["seed"] = "seed"
  233. return "seed"
  234. return self._lookup("ks", self.sizevars, name)
  235. def call_names(self):
  236. return chain(
  237. self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
  238. )
  239. def wrap_ptr_arg(self, buf, dtype):
  240. return f"c_void_p({buf}.data_ptr())"
  241. def wrap_size_arg(self, size):
  242. return f"c_long({size})"
  243. def cpp_argdefs(self):
  244. from .cpp import DTYPE_TO_CPP, INDEX_TYPE
  245. # TODO(jansel): replace this with data from scheduler
  246. buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
  247. buffer_types.update(
  248. {name: val.get_dtype() for name, val in V.graph.graph_inputs.items()}
  249. )
  250. buffer_types.update(
  251. {name: val.dtype for name, val in V.graph.constants.items()}
  252. )
  253. call_args = []
  254. arg_defs = []
  255. arg_types = []
  256. for inplaced in unique(self.inplace_buffers.values()):
  257. outer = inplaced.other_names[-1]
  258. inner = inplaced.inner_name
  259. dtype = buffer_types[outer]
  260. cpp_dtype = DTYPE_TO_CPP[dtype]
  261. arg_defs.append(f"{cpp_dtype}* __restrict__ {inner}")
  262. call_args.append(self.wrap_ptr_arg(outer, dtype))
  263. arg_types.append(f"{cpp_dtype}*")
  264. for outer, inner in self.input_buffers.items():
  265. if outer in self.inplace_buffers:
  266. continue
  267. dtype = buffer_types[outer]
  268. cpp_dtype = DTYPE_TO_CPP[dtype]
  269. arg_defs.append(f"const {cpp_dtype}* __restrict__ {inner}")
  270. call_args.append(self.wrap_ptr_arg(outer, dtype))
  271. arg_types.append(f"const {cpp_dtype}*")
  272. for outer, inner in self.output_buffers.items():
  273. if outer in self.inplace_buffers or inner == "REMOVED":
  274. continue
  275. dtype = buffer_types[outer]
  276. cpp_dtype = DTYPE_TO_CPP[dtype]
  277. arg_defs.append(f"{cpp_dtype}* __restrict__ {inner}")
  278. call_args.append(self.wrap_ptr_arg(outer, dtype))
  279. arg_types.append(f"{cpp_dtype}*")
  280. for outer, inner in self.sizevars.items():
  281. arg_defs.append(f"const {INDEX_TYPE} {inner}")
  282. call_args.append(self.wrap_size_arg(outer))
  283. arg_types.append(f"const {INDEX_TYPE}")
  284. return arg_defs, call_args, arg_types
  285. def python_argdefs(self):
  286. arg_defs = []
  287. call_args = []
  288. precompile_args = []
  289. for inplaced in unique(self.inplace_buffers.values()):
  290. arg_defs.append(inplaced.inner_name)
  291. call_args.append(inplaced.other_names[-1])
  292. precompile_args.append(
  293. TensorArg(
  294. inplaced.inner_name,
  295. inplaced.other_names[-1],
  296. V.graph.get_dtype(inplaced.other_names[-1]),
  297. )
  298. )
  299. for outer, inner in chain(
  300. self.input_buffers.items(), self.output_buffers.items()
  301. ):
  302. if outer in self.inplace_buffers or inner == "REMOVED":
  303. continue
  304. arg_defs.append(inner)
  305. call_args.append(outer)
  306. precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
  307. for outer, inner in self.sizevars.items():
  308. arg_defs.append(inner)
  309. call_args.append(str(outer))
  310. precompile_args.append(SizeArg(inner, outer))
  311. return arg_defs, call_args, precompile_args
  312. def aliases(self):
  313. for inplaced in unique(self.inplace_buffers.values()):
  314. for other in inplaced.other_names:
  315. if other in V.graph.inplaced_to_remove:
  316. continue
  317. if other in self.input_buffers:
  318. yield self.input_buffers[other], inplaced.inner_name
  319. if other in self.output_buffers:
  320. yield self.output_buffers[other], inplaced.inner_name
  321. def is_removed(self, name):
  322. def _is_removed(name, buffers):
  323. return name not in buffers or buffers[name] == "REMOVED"
  324. return _is_removed(name, self.output_buffers) and _is_removed(
  325. name, self.inplace_buffers
  326. )
  327. class CSEVariable:
  328. """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
  329. The backends can inherit from this class and overload the "create_cse_var" Kernel to do that.
  330. The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py."""
  331. def __init__(self, name):
  332. self.name = name
  333. def __str__(self):
  334. return self.name
  335. def __hash__(self) -> int:
  336. return hash(self.name)
  337. def __eq__(self, other) -> bool:
  338. return type(other) == type(self) and other.name == self.name
  339. def update_on_args(self, name, args, kwargs):
  340. pass
  341. class CppWrapperKernelArgs(KernelArgs):
  342. def wrap_ptr_arg(self, buf, dtype):
  343. from .cpp import DTYPE_TO_CPP
  344. return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
  345. def wrap_size_arg(self, size):
  346. return f"{size}"
  347. class CSE:
  348. """Common subexpression elimination"""
  349. def __init__(
  350. self,
  351. prefix="",
  352. suffix="",
  353. name_prefix="tmp",
  354. iter_buffers=None,
  355. store_cache=None,
  356. reduction_cache=None,
  357. varname_map=None,
  358. ):
  359. self.prefix = prefix
  360. self.suffix = suffix
  361. self.cache = {}
  362. self.name_prefix = name_prefix
  363. self.store_cache = store_cache or {}
  364. self.reduction_cache = reduction_cache or {}
  365. self.iter_buffer_ids = iter_buffers or itertools.count()
  366. self.invalidated_stores = set()
  367. self.varname_map = varname_map or {}
  368. def invalidate(self, keep_vars: typing.Set[str]):
  369. for name, tmp in list(self.store_cache.items()):
  370. if tmp not in keep_vars:
  371. del self.store_cache[name]
  372. self.invalidated_stores.add(name)
  373. self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
  374. def clone(self):
  375. # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
  376. return CSE(
  377. prefix=self.prefix,
  378. suffix=self.suffix,
  379. name_prefix=self.name_prefix,
  380. iter_buffers=self.iter_buffer_ids,
  381. store_cache=self.store_cache,
  382. varname_map=self.varname_map,
  383. )
  384. def generate(
  385. self,
  386. buffer: IndentedBuffer,
  387. expr: typing.Union[str, CSEVariable],
  388. write=True,
  389. append_broadcast=None,
  390. ) -> CSEVariable:
  391. assert isinstance(expr, (str, CSEVariable)), type(expr)
  392. if isinstance(expr, CSEVariable):
  393. return expr
  394. cache_key = expr
  395. if append_broadcast:
  396. assert isinstance(append_broadcast, str)
  397. cache_key = expr + append_broadcast
  398. if cache_key not in self.cache:
  399. var = self.newvar()
  400. self.cache[cache_key] = var
  401. if write:
  402. if V.kernel.current_node:
  403. V.kernel.current_node.codegen_originating_info(
  404. buffer, only_once=True
  405. )
  406. if append_broadcast:
  407. var_suffix = "_load"
  408. else:
  409. var_suffix = ""
  410. buffer.writeline(
  411. f"{self.prefix}{var}{var_suffix} = {expr}{self.suffix}"
  412. )
  413. if append_broadcast:
  414. buffer.writeline(
  415. f"{self.prefix}{var} = tl.broadcast_to({var}{var_suffix}, {append_broadcast})"
  416. )
  417. return self.cache[cache_key]
  418. def newvar(self) -> CSEVariable:
  419. var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
  420. var = V.kernel.create_cse_var(var_name)
  421. self.varname_map[var_name] = var
  422. return var
  423. class CodeGen:
  424. def __init__(self):
  425. super().__init__()
  426. self.exit_stack = contextlib.ExitStack()
  427. def __enter__(self):
  428. self.exit_stack.__enter__()
  429. return self
  430. def __exit__(self, exc_type, exc_val, exc_tb):
  431. self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
  432. class Kernel(CodeGen):
  433. newvar_prefix = ""
  434. suffix = ""
  435. overrides = None
  436. load_format = None
  437. store_format = None
  438. def __init__(self, args=None):
  439. super().__init__()
  440. metrics.generated_kernel_count += 1
  441. self.args = args or KernelArgs()
  442. self.loads = IndentedBuffer()
  443. self.compute = IndentedBuffer()
  444. self.stores = DeferredIndentedBuffer()
  445. self.cse = CSE(self.newvar_prefix, self.suffix)
  446. self.must_keep_buffers = set()
  447. self.current_node = None
  448. self.store_buffer_names = set()
  449. @contextlib.contextmanager
  450. def set_current_node(self, node):
  451. prior = self.current_node
  452. self.current_node = node
  453. yield
  454. self.current_node = prior
  455. @contextlib.contextmanager
  456. def swap_buffers(self, lb, cb=None, sb=None):
  457. if cb is None:
  458. cb = lb
  459. loads = self.loads
  460. compute = self.compute
  461. stores = self.stores
  462. cse = self.cse
  463. self.loads = lb
  464. self.compute = cb
  465. self.stores = sb
  466. self.cse = cse.clone()
  467. yield
  468. self.loads = loads
  469. self.compute = compute
  470. self.stores = stores
  471. self.cse = cse
  472. def load(self, name: str, index: sympy.Expr):
  473. raise NotImplementedError()
  474. def indirect_load(self, name: str, index: sympy.Expr):
  475. """A load the depends on an index we have read"""
  476. prior = self.loads
  477. try:
  478. # put the load in the compute section as it might have deps
  479. self.loads = self.compute
  480. return self.load(name, index)
  481. finally:
  482. self.loads = prior
  483. def store(self, name, index, value, mode=None):
  484. raise NotImplementedError()
  485. def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
  486. raise NotImplementedError()
  487. def __enter__(self):
  488. class CSEProxy:
  489. self.name = "CSEProxy"
  490. @staticmethod
  491. def __getattr__(name):
  492. def inner(*args, **kwargs):
  493. csevar = self.cse.generate(
  494. self.compute, getattr(parent_handler, name)(*args, **kwargs)
  495. )
  496. csevar.update_on_args(name, args, kwargs)
  497. return csevar
  498. return inner
  499. @staticmethod
  500. def indirect_indexing(index_var):
  501. return sympy_symbol(str(index_var))
  502. @staticmethod
  503. def load(name: str, index: sympy.Expr):
  504. if name in self.cse.invalidated_stores:
  505. # A load from an invalidated store requires us to
  506. # keep the actual buffer around
  507. V.kernel.must_keep_buffers.add(name)
  508. if free_symbol_startswith(index, "tmp"):
  509. return self.indirect_load(name, index)
  510. store_cache = self.cse.store_cache
  511. if name in store_cache:
  512. return store_cache[name]
  513. return self.load(name, index)
  514. @staticmethod
  515. def store(name, index, value, mode=None):
  516. self.store_buffer_names.add(name)
  517. if mode is None:
  518. self.cse.store_cache[name] = value
  519. if self.current_node:
  520. for other_name in self.current_node.get_mutations():
  521. self.cse.store_cache[other_name] = value
  522. if name not in V.graph.removed_buffers:
  523. return self.store(name, index, value, mode=mode)
  524. @staticmethod
  525. def reduction(name, dtype, src_dtype, reduction_type, index, value):
  526. self.store_buffer_names.add(name)
  527. return self.reduction(
  528. name, dtype, src_dtype, reduction_type, index, value
  529. )
  530. super().__enter__()
  531. parent_handler = self.overrides(V.get_ops_handler())
  532. self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
  533. self.exit_stack.enter_context(V.set_kernel_handler(self))
  534. return self
  535. def __exit__(self, exc_type, exc_val, exc_tb):
  536. if V.graph.scheduler:
  537. V.graph.scheduler.remove_kernel_local_buffers()
  538. super().__exit__(exc_type, exc_val, exc_tb)
  539. def rename_indexing(self, index) -> sympy.Expr:
  540. if isinstance(index, (list, tuple)):
  541. return [self.rename_indexing(x) for x in index]
  542. index = V.graph.sizevars.simplify(index)
  543. sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
  544. replacements = {
  545. x: self.args.size(x)
  546. for x in sorted_symbols
  547. if x.name.startswith("s") or x.name.startswith("ps")
  548. }
  549. return sympy_subs(index, replacements)
  550. def create_cse_var(self, *args, **kwargs):
  551. return CSEVariable(*args, **kwargs)