# select_algorithm.py

import builtins
import functools
import inspect
import itertools
import logging
import sys
import textwrap
from io import StringIO
from typing import Any, List
from unittest.mock import patch

import sympy

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import counters, identity

from . import ir
from .codecache import code_hash, DiskCache, PyCodeCache
from .codegen.common import IndentedBuffer
from .codegen.triton import config_of, signature_of, texpr, TritonKernel, TritonPrinter
from .utils import do_bench, sympy_dot, sympy_product
from .virtualized import V

log = logging.getLogger(__name__)

# correctness checks struggle with fp16/tf32
VERIFY = False  # dict(atol=1, rtol=0.05)
PRINT_AUTOTUNE = True

class KernelNamespace:
    pass


# these objects are imported from the generated wrapper code
template_kernels = KernelNamespace()
extern_kernels = KernelNamespace()
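
# Usage sketch (illustrative, not a fixed API): generated wrapper code imports
# these namespaces and calls registered kernels by attribute, e.g.
#   from torch._inductor.select_algorithm import extern_kernels
#   extern_kernels.mm(buf0, buf1, out=buf2)
# where `mm` would be registered via ExternKernelChoice below.
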
class TritonTemplateKernel(TritonKernel):
    def __init__(
        self,
        kernel_name,
        input_nodes,
        output_node,
        defines,
        num_stages,
        num_warps,
        grid_fn,
        meta,
        call_sizes,
        use_jit=True,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
    ):
        super().__init__(sympy_product(output_node.get_size()), sympy.Integer(1))
        self.input_nodes = input_nodes
        self.output_node = output_node
        self.named_input_nodes = {}
        self.defines = defines
        self.kernel_name = kernel_name
        self.template_mask = None
        self.use_jit = use_jit
        self.num_stages = num_stages
        self.num_warps = num_warps
        self.grid_fn = grid_fn
        self.meta = meta
        self.call_sizes = call_sizes
        # for templates with fixed epilogues
        self.prefix_args = prefix_args
        self.suffix_args = suffix_args
        self.epilogue_fn = epilogue_fn

    def jit_line(self):
        if self.use_jit:
            return "@triton.jit"
        argdefs, _, signature = self.args.python_argdefs()
        triton_meta = {
            "signature": dict(enumerate(map(signature_of, signature))),
            "device": V.graph.scheduler.current_device.index,
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        return (
            f"@template(num_stages={self.num_stages}, num_warps={self.num_warps}, meta={triton_meta!r})\n"
            + "@triton.jit"
        )
    def def_kernel(self, *argnames):
        """
        Hook called from template code to generate function def and
        needed args.
        """
        assert all(isinstance(x, str) for x in argnames)
        renames = IndentedBuffer(initial_indent=1)

        named_args = self.input_nodes[
            self.prefix_args : len(self.input_nodes) - self.suffix_args
        ]
        assert len(argnames) == len(named_args), (
            len(argnames),
            len(named_args),
            self.prefix_args,
            len(self.input_nodes),
        )

        for input_node in self.input_nodes[: self.prefix_args]:
            # get args in correct order
            self.args.input(input_node.get_name())

        for name, input_node in zip(argnames, named_args):
            arg_name = f"arg_{name}"
            self.named_input_nodes[name] = input_node
            self.args.input_buffers[input_node.get_name()] = arg_name
            if input_node.get_layout().offset == 0:
                renames.writeline(f"{name} = {arg_name}")
            else:
                offset = texpr(self.rename_indexing(input_node.get_layout().offset))
                renames.writeline(f"{name} = {arg_name} + {offset}")

        for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]:
            # get args in correct order
            self.args.input(input_node.get_name())

        arg_defs, *_ = self.args.python_argdefs()
        return "\n".join(
            [
                "import triton.language as tl",
                "import triton",
                "from torch._inductor.triton_ops.autotune import template",
                "from torch._inductor.utils import instance_descriptor",
                "",
                self.jit_line(),
                f"def {self.kernel_name}({', '.join(arg_defs)}):",
                self.defines,
                renames.getvalue(),
            ]
        )
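
    # Template-side usage sketch (illustrative jinja2 source, not a real
    # shipped template): a template body calls this hook as
    #   {{def_kernel("A", "B")}}
    # which expands to the imports, the @triton.jit line, the
    # `def <kernel_name>(...)` header, the constexpr defines, and any
    # `name = arg_name + offset` renames for inputs with nonzero offsets.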
    def size(self, name: str, index: int):
        """
        Hook called from template code to get the size of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_size()[index]
        return texpr(self.rename_indexing(val))

    def stride(self, name, index):
        """
        Hook called from template code to get the stride of an arg.
        Will add needed args to pass it in if it is dynamic.
        """
        assert isinstance(name, str)
        assert isinstance(index, int)
        val = self.named_input_nodes[name].get_stride()[index]
        return texpr(self.rename_indexing(val))
    def store_output(self, indices, val, mask):
        """
        Hook called from template code to store the final output
        (if the buffer hasn't been optimized away), then append any
        epilogue fusions.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(val, str)
        assert isinstance(mask, str)

        if self.template_mask is None:
            indices = list(map(TritonPrinter.paren, indices))
            index_symbols = [sympy.Symbol(x) for x in indices]
            lengths = [
                V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
            ]
            assert len(indices) == len(lengths)

            # glue to make generated code use same indexing from template
            for name, range_tree_entry in zip(
                indices, self.range_trees[0].construct_entries(lengths)
            ):
                range_tree_entry.set_name(name)
            contiguous_index = sympy_dot(
                ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
            )
            self.body.writeline("xindex = " + texpr(contiguous_index))
            self.range_trees[0].lookup(
                sympy.Integer(1), sympy_product(lengths)
            ).set_name("xindex")
            self.template_mask = mask
            self.template_indices = indices
            output_index = self.output_node.get_layout().make_indexer()(index_symbols)
            if output_index == contiguous_index:
                output_index = sympy.Symbol("xindex")

            epilogue_args = [val]
            for input_node in itertools.chain(
                self.input_nodes[: self.prefix_args],
                self.input_nodes[len(self.input_nodes) - self.suffix_args :],
            ):
                input_node.freeze_layout()
                epilogue_args.append(input_node.make_loader()(index_symbols))

            V.ops.store(
                self.output_node.get_name(),
                output_index,
                self.epilogue_fn(*epilogue_args),
            )
        assert self.template_mask == mask
        self.codegen_body()
        return textwrap.indent(self.body.getvalue(), "    ").strip()
    def make_load(self, name, indices, mask):
        """
        Optional helper called from template code to generate the code
        needed to load from a tensor.
        """
        assert isinstance(indices, (list, tuple))
        assert isinstance(name, str)
        assert isinstance(mask, str)
        stride = self.named_input_nodes[name].get_stride()
        indices = list(map(TritonPrinter.paren, indices))
        assert len(indices) == len(stride)
        index = " + ".join(
            f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
        )
        return f"tl.load({name} + ({index}), {mask})"
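
    # For example (illustrative names): given a 2D input "A" with strides
    # (stride_am, stride_ak), make_load("A", ["idx_m", "idx_k"], "mask")
    # returns roughly:
    #   tl.load(A + (stride_am * idx_m + stride_ak * idx_k), mask)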
    def template_env(self):
        """
        Generate the namespace visible in the template.
        """
        return {
            fn.__name__: fn
            for fn in [
                self.def_kernel,
                self.size,
                self.stride,
                self.store_output,
                self.make_load,
            ]
        }

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
    ):
        """
        Override the default indexing to use our custom mask and force
        dense indexing.
        """
        result, *mask = super().indexing(
            index,
            dense_indexing=False,
            copy_shape=copy_shape,
            override_mask=self.template_mask,
        )
        result += f" + tl.zeros({self.template_mask}.shape, tl.int32)"
        return (result, *mask)

    def initialize_range_tree(self, pid_cache):
        super().initialize_range_tree(pid_cache)
        # ignore default codegen
        self.body.clear()
        self.indexing_code.clear()

    def call_kernel(self, code, name: str):
        _, call_args, _ = self.args.python_argdefs()

        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
        call_args = ", ".join(call_args)

        stream_name = code.write_get_cuda_stream(V.graph.scheduler.current_device.index)
        V.graph.wrapper_code.add_import_once(f"import {self.grid_fn.__module__}")
        meta = V.graph.wrapper_code.add_meta_once(self.meta)

        grid_call = [texpr(V.graph.sizevars.simplify(s)) for s in self.call_sizes] + [
            meta
        ]
        grid_call = (
            f"{self.grid_fn.__module__}.{self.grid_fn.__name__}({', '.join(grid_call)})"
        )
        code.writeline(
            f"{name}.run({call_args}, grid={grid_call}, stream={stream_name})"
        )
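
# The wrapper line emitted by call_kernel looks roughly like (illustrative
# buffer, grid, and stream names):
#   triton_mm.run(arg0_1, arg1_1, buf0, grid=mm_grid(m, n, meta0), stream=stream0)
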
@functools.lru_cache(None)
def _jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class TritonTemplate:
    index_counter = itertools.count()
    all_templates = dict()

    @staticmethod
    def _template_from_string(source):
        env = _jinja2_env()
        if env is not None:
            return env.from_string(source)
        return None

    def __init__(self, name: str, grid: Any, source: str, debug=False):
        super().__init__()
        self.name = name
        self.grid = grid
        self.template = self._template_from_string(source)
        assert name not in self.all_templates, "duplicate template name"
        self.all_templates[name] = self
        self.debug = debug
    def generate(
        self,
        input_nodes,
        layout,
        num_stages,
        num_warps,
        prefix_args=0,
        suffix_args=0,
        epilogue_fn=identity,
        **kwargs,
    ):
        assert self.template, "requires jinja2"
        defines = StringIO()
        for name, val in kwargs.items():
            defines.write(f"    {name} : tl.constexpr = {val}\n")
        defines = defines.getvalue()

        fake_out = ir.Buffer("buf_out", layout)
        kernel_name = f"triton_{self.name}"

        kernel_options = dict(
            input_nodes=input_nodes,
            defines=defines,
            num_stages=num_stages,
            num_warps=num_warps,
            grid_fn=self.grid,
            meta=kwargs,
            call_sizes=layout.size,
            prefix_args=prefix_args,
            suffix_args=suffix_args,
            epilogue_fn=epilogue_fn,
        )

        with patch.object(
            V.graph, "get_dtype", self.fake_get_dtype(fake_out)
        ), TritonTemplateKernel(
            kernel_name=kernel_name,
            output_node=fake_out,
            use_jit=True,
            **kernel_options,
        ) as kernel:
            # need to call render twice to get all of the needed args right
            self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            code = self.template.render(
                **kernel.template_env(),
                **kwargs,
            )
            if self.debug:
                print("Generated Code:\n", code)
            mod = PyCodeCache.load(code)
            run = getattr(mod, kernel_name).run
            _, call_args, _ = kernel.args.python_argdefs()

        expected_args = [x.get_name() for x in input_nodes] + [fake_out.get_name()]
        assert list(call_args) == expected_args, (call_args, expected_args)
        extra_args = V.graph.sizevars.size_hints(
            map(sympy.expand, call_args[len(expected_args) :])
        )
        assert not extra_args, "TODO: dynamic shapes"

        def call(*args, out):
            return run(
                *args,
                out,
                *extra_args,
                grid=self.grid(*out.size(), kwargs),
                num_stages=num_stages,
                num_warps=num_warps,
            )

        call.key = mod.key
        call.__file__ = mod.__file__

        kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
        setattr(template_kernels, kernel_hash_name, call)

        def make_kernel_render(out_node):
            kernel = TritonTemplateKernel(
                kernel_name="KERNEL_NAME",
                output_node=out_node,
                use_jit=False,
                **kernel_options,
            )
            render = functools.partial(
                self.template.render,
                **kernel.template_env(),
                **kwargs,
            )
            return kernel, render

        return TritonTemplateCaller(
            kernel_hash_name, input_nodes, layout, make_kernel_render
        )
    @staticmethod
    def fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype
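
# Definition sketch (illustrative; real templates live in lowering files such
# as torch/_inductor/kernel/mm.py — the grid fn and template body below are
# placeholders, not the real mm template):
#
#   mm_template = TritonTemplate(
#       name="mm",
#       grid=mm_grid,
#       source=r"""
#   {{def_kernel("A", "B")}}
#       ...compute acc from A and B...
#       {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
#   """,
#   )
#   choice = mm_template.generate(
#       (mat1, mat2), layout, num_stages=4, num_warps=8,
#       BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,  # extra kwargs become tl.constexpr defines
#   )
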
class ExternKernelChoice:
    def __init__(self, kernel, cpp_kernel=None, *, name=None):
        super().__init__()
        name = name or kernel.__name__
        assert callable(kernel)
        assert not hasattr(extern_kernels, name), "duplicate extern kernel"
        self.name = name
        self.cpp_kernel = cpp_kernel
        setattr(extern_kernels, name, kernel)

    def to_callable(self):
        return getattr(extern_kernels, self.name)

    def call_name(self):
        return f"extern_kernels.{self.name}"

    @functools.lru_cache(None)
    def hash_key(self):
        fn = self.to_callable()
        parts = [
            self.name,
            getattr(fn, "__name__", ""),
            getattr(fn, "__module__", ""),
        ]
        try:
            parts.append(inspect.getsource(fn))
        except Exception:
            pass
        return code_hash("-".join(parts))

    def bind(self, input_nodes, layout, **kwargs):
        return ExternKernelCaller(self, input_nodes, layout, kwargs)
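
# Registration sketch (illustrative; mirrors how lowering code declares
# extern choices):
#   aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
#   choice = aten_mm.bind((mat1, mat2), layout)
# which also exposes the kernel as extern_kernels.mm in generated wrapper code.
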
class ChoiceCaller:
    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes


class TritonTemplateCaller(ChoiceCaller):
    def __init__(self, name, input_nodes, layout, make_kernel_render):
        super().__init__(name, input_nodes, layout)
        self.make_kernel_render = make_kernel_render

    def __str__(self):
        return f"TritonTemplateCaller({self.to_callable().__file__})"

    def call_name(self):
        return f"template_kernels.{self.name}"

    def to_callable(self):
        return getattr(template_kernels, self.name)

    def hash_key(self):
        return self.to_callable().key

    def output_node(self):
        return ir.TensorBox.create(
            ir.TemplateBuffer(
                layout=self.layout,
                inputs=self.input_nodes,
                make_kernel_render=self.make_kernel_render,
            )
        )


class ExternKernelCaller(ChoiceCaller):
    def __init__(self, choice: ExternKernelChoice, input_nodes, layout, kwargs=None):
        super().__init__(choice.name, input_nodes, layout)
        self.choice = choice
        self.kwargs = kwargs or {}

    def to_callable(self):
        fn = self.choice.to_callable()
        if self.kwargs:
            return functools.partial(fn, **self.kwargs)
        else:
            return fn

    def hash_key(self):
        return "/".join(
            [
                self.choice.hash_key(),
                repr(self.kwargs),
            ]
        )

    def output_node(self):
        return ir.TensorBox.create(
            ir.ExternKernelOut(
                layout=self.layout,
                inputs=self.input_nodes,
                kernel=self.choice.call_name(),
                cpp_kernel=self.choice.cpp_kernel,
                kwargs=self.kwargs,
            )
        )


class AlgorithmSelectorCache(DiskCache):
    def __call__(self, choices: List[ChoiceCaller], input_nodes, layout):
        if len(choices) == 1:
            return choices[0].output_node()

        def autotune():
            benchmark_fn = self.make_benchmark_fn(choices, input_nodes, layout)
            timings = {}
            for choice in choices:
                try:
                    timings[choice] = benchmark_fn(
                        choice.to_callable(), isinstance(choice, ExternKernelCaller)
                    )
                except RuntimeError as e:
                    if "invalid argument" in str(e):
                        msg = textwrap.dedent(
                            f"""
                            {e}

                            From choice {choices.index(choice)}: {choice}

                            This may mean this GPU is too small for max_autotune mode.
                            """
                        ).strip()
                        if VERIFY:
                            raise RuntimeError(msg)
                        else:
                            log.warning(msg)
                    else:
                        raise
                except AssertionError as e:
                    raise AssertionError(
                        f"Incorrect result from choice {choices.index(choice)} {choice}\n\n{e}"
                    )
            self.log_results(choices[0].name, input_nodes, timings)
            best_choice = builtins.min(timings, key=timings.__getitem__)
            return choices.index(best_choice)

        counters["inductor"]["select_algorithm_autotune"] += 1
        key = [x.hash_key() for x in choices] + [self.key_of(x) for x in input_nodes]
        return choices[self.lookup(key, autotune)].output_node()
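
    # Cache-key sketch: the DiskCache key combines each choice's hash_key()
    # with key_of() for every input, e.g. (values illustrative):
    #   ["a1b2...", "c3d4...", ("cuda", "torch.float16", 256, 512, 512, 1, 0)]
    # so cached timings are invalidated whenever kernel source, shapes,
    # strides, dtypes, or offsets change.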
    @classmethod
    def make_benchmark_fn(
        cls,
        choices,
        input_nodes,
        layout,
    ):
        example_inputs = [cls.benchmark_example_value(x) for x in input_nodes]
        example_inputs_extern = list(example_inputs)
        for i in range(len(example_inputs)):
            if input_nodes[i].get_layout().offset != 0:
                offset = V.graph.sizevars.size_hint(input_nodes[i].get_layout().offset)
                data = example_inputs_extern[i]
                example_inputs_extern[i] = torch.as_strided(
                    data, data.size(), data.stride(), offset
                )

        out = cls.benchmark_example_value(layout)
        out_extern = torch.as_strided(
            out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
        )
        if VERIFY:
            choices[0].to_callable()(*example_inputs_extern, out=out_extern)
            expected = out_extern.clone()

        def benchmark(algo, is_extern):
            out.zero_()
            if is_extern:
                result = do_bench(lambda: algo(*example_inputs_extern, out=out_extern))
            else:
                result = do_bench(lambda: algo(*example_inputs, out=out))
            if VERIFY:
                torch.testing.assert_close(out_extern, expected, **VERIFY)
            torch.cuda.synchronize()  # shake out any CUDA errors
            return result

        return benchmark
    @staticmethod
    def log_results(name, input_nodes, timings):
        if not PRINT_AUTOTUNE:
            return
        sizes = ", ".join(
            [
                "x".join(map(str, V.graph.sizevars.size_hints(n.get_size())))
                for n in input_nodes
            ]
        )
        top_k = sorted(timings, key=timings.__getitem__)[:10]
        best = top_k[0]
        best_time = timings[best][0]
        sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
        for choice in top_k:
            result = timings[choice]
            sys.stderr.write(
                f"  {choice.name} {result[0]:.4f}s {best_time/result[0]:.1%}\n"
            )
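
    # Sample output assembled from the format strings above (numbers and
    # choice names illustrative):
    #   AUTOTUNE mm(256x512, 512x1024)
    #     triton_mm_3 0.0123s 100.0%
    #     mm 0.0145s 84.8%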
    @staticmethod
    def benchmark_example_value(node):
        """
        Convert an ir.Buffer into a concrete torch.Tensor we can use for
        benchmarking.
        """
        if isinstance(node, ir.Layout):
            node = ir.Buffer("fake", node)
        return rand_strided(
            V.graph.sizevars.size_hints(node.get_size()),
            V.graph.sizevars.size_hints(node.get_stride()),
            device=node.get_device(),
            dtype=node.get_dtype(),
            extra_size=V.graph.sizevars.size_hint(node.get_layout().offset),
        )
    @staticmethod
    def key_of(node):
        """
        Extract the pieces of an ir.Buffer that we should invalidate cached
        autotuning results on.
        """
        sizevars = V.graph.sizevars
        return (
            node.get_device().type,
            str(node.get_dtype()),
            *sizevars.size_hints(node.get_size()),
            *sizevars.size_hints(node.get_stride()),
            sizevars.size_hint(node.get_layout().offset),
        )


autotune_select_algorithm = AlgorithmSelectorCache(__name__)
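
# Call-site sketch (illustrative; mirrors lowerings such as
# torch/_inductor/kernel/mm.py — `aten_mm` and `triton_choices` are
# assumed names, not defined here):
#   return autotune_select_algorithm(
#       [aten_mm.bind((mat1, mat2), layout)] + triton_choices,
#       [mat1, mat2],
#       layout,
#   )
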
def realize_inputs(*args):
    if len(args) == 1:
        return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0]))
    return [realize_inputs(x) for x in args]


# ensure lowering is imported so that `extern_kernels.*` is populated
from . import lowering  # noqa: F401