autotune.py

import builtins
import copy
import functools
import hashlib
import json
import logging
import operator
import os.path
import re
import threading
from typing import List

import torch

from torch._dynamo.utils import dynamo_timed

from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import conditional_product, has_triton
from .conv_perf_model import (
    early_config_prune as conv_early_config_prune,
    estimate_conv_time,
)

log = logging.getLogger(__name__)

if has_triton():
    import triton
    from triton import cdiv, Config, next_power_of_2
    from triton.runtime.jit import get_cuda_stream, KernelInterface
else:
    cdiv = None
    Config = object
    get_cuda_stream = None
    KernelInterface = object
    next_power_of_2 = None
    triton = None


class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
        super().__init__()
        self.fn = fn
        self.meta = meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.launchers = []
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.meta.get("device", 0)),
            )

    def precompile(self, warm_cache_only_with_cc=None):
        with self.lock:
            if self.launchers:
                return
            self.launchers = [
                self._precompile_config(c, warm_cache_only_with_cc)
                for c in self.configs
            ]
            self.configs = None

    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int):
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.meta)
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages

        if warm_cache_only_with_cc:
            triton.compile(
                self.fn,
                warm_cache_only=True,
                cc=warm_cache_only_with_cc,
                **compile_meta,
            )
            return

        # load binary to the correct device
        with torch.cuda.device(compile_meta["device"]):
            # need to initialize context
            torch.cuda.synchronize(torch.cuda.current_device())
            binary = triton.compile(
                self.fn,
                **compile_meta,
            )

        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        def_args = list(self.fn.arg_names)
        while def_args and def_args[-1] in cfg.kwargs:
            def_args.pop()

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "torch": torch,
            "set_device": torch.cuda.set_device,
            "current_device": torch.cuda.current_device,
        }
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
                              stream, bin.cu_function, None, None, None,
                              {', '.join(call_args)})
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        return launcher
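
    # Illustrative sketch (hypothetical kernel and config; the names are only
    # for illustration): for arg_names ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"],
    # with XBLOCK the only constexpr and cfg.kwargs == {"XBLOCK": 1024}, the
    # exec() above generates roughly:
    #
    #     def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
    #         if callable(grid):
    #             grid_0, grid_1, grid_2 = grid(grid_meta)
    #         else:
    #             grid_0, grid_1, grid_2 = grid
    #         bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
    #                       stream, bin.cu_function, None, None, None,
    #                       in_ptr0, out_ptr0, xnumel)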

    def bench(self, launcher, *args, grid):
        """Measure the performance of a given launcher"""
        stream = get_cuda_stream(torch.cuda.current_device())

        def kernel_call():
            if launcher.config.pre_hook is not None:
                launcher.config.pre_hook(
                    {**dict(zip(self.fn.arg_names, args)), **launcher.config.kwargs}
                )
            launcher(
                *args,
                grid=grid,
                stream=stream,
            )

        from triton.testing import do_bench

        return do_bench(kernel_call, rep=40, fast_flush=True)

    @dynamo_timed
    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        from ..compile_fx import clone_preserve_strides

        # clone inplace buffers to avoid autotune contaminating them if
        # the kernel does in-place stores. avoid cloning other buffers because
        # it leads to increased memory use
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        timings = {
            launcher: self.bench(launcher, *cloned_args, **kwargs)
            for launcher in self.launchers
        }
        self.launchers = [builtins.min(timings, key=timings.get)]
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config)

    def run(self, *args, grid, stream):
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                self.precompile()
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid)

        (launcher,) = self.launchers
        if launcher.config.pre_hook is not None:
            launcher.config.pre_hook(
                {**dict(zip(self.fn.arg_names, args)), **launcher.config.kwargs}
            )
        try:
            result = launcher(
                *args,
                grid=grid,
                stream=stream,
            )
        except TypeError as e:
            if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
                raise RuntimeError(
                    """Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
                ) from e
            else:
                raise e

        return result
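
# Typical driving sequence for a CachingAutotuner (illustrative; `kernel`,
# `call_args`, `xnumel`, and `stream` are hypothetical names for an instance
# produced by cached_autotune() below and its runtime arguments):
#
#     kernel.precompile()  # build one launcher per candidate config
#     kernel.run(*call_args, grid=grid(xnumel), stream=stream)
#
# A run() call that still sees multiple launchers triggers
# autotune_to_one_config(), which benchmarks every launcher, keeps only the
# fastest, and persists its config via save_cache_hook.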


def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode(
                "utf-8"
            )
        )
    return hasher.hexdigest()


def load_cached_autotuning(
    cache_filename: str, configs_hash: str, configs: List[Config]
):
    """
    Read a cached autotuning result from disk
    """
    if not os.path.exists(cache_filename):
        return None

    with open(cache_filename, "r") as fd:
        best_config = json.loads(fd.read())
    if best_config.get("configs_hash") != configs_hash:
        return None

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]
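
# The .best_config file read above is written by cached_autotune()'s
# save_cache_hook below: a flat JSON object holding the winning config's
# kwargs plus the configs hash, e.g. (illustrative values only):
#
#     {"XBLOCK": 64, "YBLOCK": 64, "configs_hash": "3f2a..."}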


def cached_autotune(
    configs: List[Config],
    meta,
    filename=None,
):
    """
    A copy of triton.autotune that calls our subclass.  Our subclass
    has additional debugging, error handling, and on-disk caching.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename

    # on disk caching logic
    if filename is not None and len(configs) > 1:
        cache_filename = os.path.splitext(filename)[0] + ".best_config"
        configs_hash = hash_configs(configs)
        best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
        if best_config:
            configs = [best_config]

        def save_cache_hook(cfg):
            with open(cache_filename, "w") as fd:
                fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))

    else:
        save_cache_hook = None

    mutated_arg_names = meta.pop("mutated_arg_names", ())

    def decorator(fn):
        return CachingAutotuner(
            fn,
            meta=meta,
            configs=configs,
            save_cache_hook=save_cache_hook,
            mutated_arg_names=mutated_arg_names,
        )

    return decorator


def unique_configs(configs: List[Config]):
    """Remove duplicate configurations"""
    seen = set()
    pruned_configs = []
    for cfg in configs:
        key = tuple(cfg.kwargs.items())
        if key not in seen:
            seen.add(key)
            pruned_configs.append(cfg)
    return pruned_configs


def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.
    """
    # Ideally we want to read this from some device config
    maxGridSize = [2147483647, 65535, 65535]

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < size_hints[0] and (
        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < size_hints[1]
        and (
            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < size_hints[2]
        and (
            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
        )
    ):
        z *= 2

    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    num_warps = next_power_of_2(min(max(conditional_product(x, y, z) // 256, 1), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
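
# Worked example (illustrative): triton_config([1048576], 1024) keeps
# XBLOCK=1024, since conditional_product(x) already equals the target and the
# grid fits within maxGridSize, and picks
# num_warps = next_power_of_2(min(max(1024 // 256, 1), 8)) == 4.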


def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.
    """
    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    r = min(r, size_hints[1])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints[1] and conditional_product(x, r) < target:
        r *= 2

    cfg = {"XBLOCK": x, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 2), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
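
# Worked example (illustrative): triton_config_reduction([8192, 8192], 128, 8)
# (the outer_config used by reduction() below) keeps XBLOCK=128 and RBLOCK=8,
# with num_warps = next_power_of_2(min(max(1024 // 128, 2), 8)) == 8.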


def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. Size_hints is a tuple of numels in
    each tile dimension and will be rounded up to the nearest power of 2.
    """
    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    y = min(y, size_hints[1])
    r = min(r, size_hints[2])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints[2] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints[1] and conditional_product(x, y, r) < target:
        y *= 2

    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)


def pointwise(size_hints, meta, tile_hint=None, filename=None):
    """
    Construct @triton.heuristics() based on size_hints.
    """
    numel = functools.reduce(operator.mul, size_hints)
    bs = max(256, min(numel // 128, 1024))

    if len(size_hints) == 1:
        return cached_autotune([triton_config(size_hints, bs)], meta=meta)
    if len(size_hints) == 2:
        if (
            not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
        ) and not config.max_autotune:
            return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
        return cached_autotune(
            [
                triton_config(size_hints, 32, 32),
                triton_config(size_hints, 64, 64),  # ~8% better for fp16
                triton_config(size_hints, 256, 16),
                triton_config(size_hints, 16, 256),
                triton_config(size_hints, bs, 1),
                triton_config(size_hints, 1, bs),
            ],
            meta=meta,
            filename=filename,
        )
    if len(size_hints) == 3:
        if not config.triton.autotune_pointwise:
            return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
        return cached_autotune(
            [
                triton_config(size_hints, 16, 16, 16),
                triton_config(size_hints, 64, 8, 8),
                triton_config(size_hints, 8, 64, 8),
                triton_config(size_hints, 8, 8, 64),
                triton_config(size_hints, bs, 1, 1),
                triton_config(size_hints, 1, bs, 1),
                triton_config(size_hints, 1, 1, bs),
            ],
            meta=meta,
            filename=filename,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")
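
# Sketch of how the decorators built here are applied (hypothetical generated
# kernel; the kernel name and arguments are illustrative):
#
#     @pointwise(size_hints=[1048576], meta={...}, filename=__file__)
#     @triton.jit
#     def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
#         ...
#
# cached_autotune() then wraps the jitted function in a CachingAutotuner
# instead of Triton's own Autotuner.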


def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    """args to @triton.heuristics()"""
    assert meta is not None
    rnumel = size_hints[-1]
    if len(size_hints) == 2:
        contiguous_config = triton_config_reduction(
            size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048), num_stages=1
        )
        outer_config = triton_config_reduction(size_hints, 128, 8)
        tiny_config = triton_config_reduction(
            size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
        )
        if config.max_autotune:
            pass  # skip all these cases
        elif reduction_hint == ReductionHint.INNER:
            return cached_autotune([contiguous_config], meta=meta)
        elif reduction_hint == ReductionHint.OUTER:
            return cached_autotune([outer_config], meta=meta)
        elif reduction_hint == ReductionHint.OUTER_TINY:
            return cached_autotune([tiny_config], meta=meta)
        if not config.triton.autotune_pointwise:
            return cached_autotune(
                [triton_config_reduction(size_hints, 32, 128)], meta=meta
            )
        return cached_autotune(
            [
                contiguous_config,
                outer_config,
                tiny_config,
                triton_config_reduction(size_hints, 64, 64),
                triton_config_reduction(size_hints, 8, 512),
            ],
            meta=meta,
            filename=filename,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")


def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    xnumel, rnumel = size_hints

    configs = [
        triton_config_reduction(size_hints, xblock, rnumel)
        for xblock in (1, 8, 32, 128)
        if rnumel * xblock <= 4096 and xblock <= xnumel
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]

    return cached_autotune(
        configs,
        meta=meta,
        filename=filename,
    )
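
# Worked example (illustrative): for size_hints == [512, 512] the list
# comprehension above keeps xblock in (1, 8) only (since 32 * 512 > 4096),
# giving candidate configs {"XBLOCK": 1, "RBLOCK": 512} and
# {"XBLOCK": 8, "RBLOCK": 512}.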


def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
    )


def conv_heuristics():
    configs = [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2
        # ),
    ]
    key = [
        "BATCH",
        "IN_C",
        "IN_H",
        "IN_W",
        "KERNEL_N",
        "KERNEL_H",
        "KERNEL_W",
        "OUT_H",
        "OUT_W",
        # parameters of conv
        "stride_h",
        "stride_w",
        "padding_h",
        "padding_w",
        "dilation_h",
        "dilation_w",
        "output_padding_h",
        "output_padding_w",
        "groups",
    ]
    prune_configs_by = {
        "early_config_prune": conv_early_config_prune,
        "perf_model": estimate_conv_time,
        "top_k": 10,
    }
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)


def grid(xnumel, ynumel=None, znumel=None):
    """Helper function to compute triton grids"""

    def get_grid_dim(numel, block_name, block):
        if numel is None:
            return 1
        label = block_name[0]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1"
                f"({label.lower()}numel=={numel}, {block_name}={block})."
            )
        return cdiv(numel, block)

    def grid_fn(meta):
        return (
            get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)),
            get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)),
            get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)),
        )

    return grid_fn
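
# Usage example (illustrative): grid(1024) returns a closure that the launcher
# calls with the chosen config's kwargs, e.g.
#
#     grid(1024)({"XBLOCK": 128}) == (8, 1, 1)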