import builtins
import copy
import functools
import hashlib
import json
import logging
import operator
import os.path
import re
import threading
from typing import List

import torch

from torch._dynamo.utils import dynamo_timed

from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import conditional_product, has_triton
from .conv_perf_model import (
    early_config_prune as conv_early_config_prune,
    estimate_conv_time,
)

log = logging.getLogger(__name__)

if has_triton():
    import triton
    from triton import cdiv, Config, next_power_of_2
    from triton.runtime.jit import get_cuda_stream, KernelInterface
else:
    cdiv = None
    Config = object
    get_cuda_stream = None
    KernelInterface = object
    next_power_of_2 = None
    triton = None


class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
        super().__init__()
        self.fn = fn
        self.meta = meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.launchers = []
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.meta.get("device", 0)),
            )

    def precompile(self, warm_cache_only_with_cc=None):
        with self.lock:
            if self.launchers:
                return
            self.launchers = [
                self._precompile_config(c, warm_cache_only_with_cc)
                for c in self.configs
            ]
            self.configs = None

    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int):
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.meta)
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages

        if warm_cache_only_with_cc:
            triton.compile(
                self.fn,
                warm_cache_only=True,
                cc=warm_cache_only_with_cc,
                **compile_meta,
            )
            return

        # load binary to the correct device
        with torch.cuda.device(compile_meta["device"]):
            # need to initialize context
            torch.cuda.synchronize(torch.cuda.current_device())
            binary = triton.compile(
                self.fn,
                **compile_meta,
            )

        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        def_args = list(self.fn.arg_names)
        while def_args and def_args[-1] in cfg.kwargs:
            def_args.pop()

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "torch": torch,
            "set_device": torch.cuda.set_device,
            "current_device": torch.cuda.current_device,
        }
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
                              stream, bin.cu_function, None, None, None,
                              {', '.join(call_args)})
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        return launcher
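    # Illustrative sketch (not part of the original source): for a hypothetical
    # kernel with arg_names = ["in_ptr0", "out_ptr0", "xnumel", "XBLOCK"], where
    # XBLOCK is a constexpr supplied by the config, the exec() above generates
    # roughly:
    #
    #     def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
    #         if callable(grid):
    #             grid_0, grid_1, grid_2 = grid(grid_meta)
    #         else:
    #             grid_0, grid_1, grid_2 = grid
    #         bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
    #                       stream, bin.cu_function, None, None, None,
    #                       in_ptr0, out_ptr0, xnumel)
    #
    # Trailing arguments covered by cfg.kwargs (here XBLOCK) are dropped from
    # the signature, and constexpr arguments are dropped from the c_wrapper call.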
    def bench(self, launcher, *args, grid):
        """Measure the performance of a given launcher"""
        stream = get_cuda_stream(torch.cuda.current_device())

        def kernel_call():
            if launcher.config.pre_hook is not None:
                # zip() is not a mapping, so materialize it before ** unpacking
                launcher.config.pre_hook(
                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                )
            launcher(
                *args,
                grid=grid,
                stream=stream,
            )

        from triton.testing import do_bench

        return do_bench(kernel_call, rep=40, fast_flush=True)

    @dynamo_timed
    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        from ..compile_fx import clone_preserve_strides

        # clone inplace buffers to avoid autotune contaminating them if
        # the kernel does in-place stores. avoid cloning other buffers because
        # it leads to increased memory use
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        timings = {
            launcher: self.bench(launcher, *cloned_args, **kwargs)
            for launcher in self.launchers
        }
        self.launchers = [builtins.min(timings, key=timings.get)]
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config)

    def run(self, *args, grid, stream):
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                self.precompile()
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid)

        (launcher,) = self.launchers
        if launcher.config.pre_hook is not None:
            # zip() is not a mapping, so materialize it before ** unpacking
            launcher.config.pre_hook(
                {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
            )
        try:
            result = launcher(
                *args,
                grid=grid,
                stream=stream,
            )
        except TypeError as e:
            if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
                raise RuntimeError(
                    """Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
                ) from e
            else:
                raise e

        return result


def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode(
                "utf-8"
            )
        )
    return hasher.hexdigest()


def load_cached_autotuning(
    cache_filename: str, configs_hash: str, configs: List[Config]
):
    """
    Read a cached autotuning result from disk
    """
    if not os.path.exists(cache_filename):
        return None

    with open(cache_filename, "r") as fd:
        best_config = json.loads(fd.read())
    if best_config.get("configs_hash") != configs_hash:
        return None

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]
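# Illustrative sketch (not part of the original source): with hypothetical
# values, the ".best_config" file written by save_cache_hook below and read
# back by load_cached_autotuning above looks roughly like:
#
#     {"XBLOCK": 64, "YBLOCK": 64, "configs_hash": "<sha256 of the config list>"}
#
# i.e. the winning config's kwargs plus a hash that guards against stale caches.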
""" configs = unique_configs(configs) assert len(configs) == 1 or filename # on disk caching logic if filename is not None and len(configs) > 1: cache_filename = os.path.splitext(filename)[0] + ".best_config" configs_hash = hash_configs(configs) best_config = load_cached_autotuning(cache_filename, configs_hash, configs) if best_config: configs = [best_config] def save_cache_hook(cfg): with open(cache_filename, "w") as fd: fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash})) else: save_cache_hook = None mutated_arg_names = meta.pop("mutated_arg_names", ()) def decorator(fn): return CachingAutotuner( fn, meta=meta, configs=configs, save_cache_hook=save_cache_hook, mutated_arg_names=mutated_arg_names, ) return decorator def unique_configs(configs: List[Config]): """Remove duplicate configurations""" seen = set() pruned_configs = [] for cfg in configs: key = tuple(cfg.kwargs.items()) if key not in seen: seen.add(key) pruned_configs.append(cfg) return pruned_configs def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config: """ Construct a pointwise triton config with some adjustment heuristics based on size_hints. Size_hints is a tuple of numels in each tile dimension and will be rounded up to the nearest power of 2. """ # Ideally we want to read this from some device config maxGridSize = [2147483647, 65535, 65535] target = conditional_product(x, y, z) if conditional_product(*size_hints) < target: target //= 8 # shrink sizes to size hints x = min(x, size_hints[0]) if y: y = min(y, size_hints[1]) if z: z = min(z, size_hints[2]) # if we are below original block size, scale up where we can; # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension while x < size_hints[0] and ( x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target ): x *= 2 while ( y and y < size_hints[1] and ( y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target ) ): y *= 2 while ( z and z < size_hints[2] and ( z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target ) ): z *= 2 cfg = {"XBLOCK": x} if y: cfg["YBLOCK"] = y if z: cfg["ZBLOCK"] = z num_warps = next_power_of_2(min(max(conditional_product(x, y, z) // 256, 1), 8)) return Config(cfg, num_warps=num_warps, num_stages=num_stages) def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config: """ Construct a reduction triton config with some adjustment heuristics based on size_hints. Size_hints is a tuple of numels in each tile dimension and will be rounded up to the nearest power of 2. """ target = conditional_product(x, r) if conditional_product(*size_hints) < target: target //= 8 # shrink sizes to size hints x = min(x, size_hints[0]) r = min(r, size_hints[1]) # if we are below original block size, scale up where we can while x < size_hints[0] and conditional_product(x, r) < target: x *= 2 while r < size_hints[1] and conditional_product(x, r) < target: r *= 2 cfg = {"XBLOCK": x, "RBLOCK": r} num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 2), 8)) return Config(cfg, num_warps=num_warps, num_stages=num_stages) def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2): """ Construct a tile reduction triton config with some adjustment heuristics based on size_hints. Size_hints is a tuple of numels in each tile dimension and will be rounded up to the nearest power of 2. 
""" target = conditional_product(x, y, r) if conditional_product(*size_hints) < target: target //= 8 # shrink sizes to size hints x = min(x, size_hints[0]) y = min(y, size_hints[1]) r = min(r, size_hints[2]) # if we are below original block size, scale up where we can while x < size_hints[0] and conditional_product(x, y, r) < target: x *= 2 while r < size_hints[2] and conditional_product(x, y, r) < target: r *= 2 while y < size_hints[1] and conditional_product(x, y, r) < target: y *= 2 cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r} num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8)) return Config(cfg, num_warps=num_warps, num_stages=num_stages) def pointwise(size_hints, meta, tile_hint=None, filename=None): """ Construct @triton.heuristics() based on size_hints. """ numel = functools.reduce(operator.mul, size_hints) bs = max(256, min(numel // 128, 1024)) if len(size_hints) == 1: return cached_autotune([triton_config(size_hints, bs)], meta=meta) if len(size_hints) == 2: if ( not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE ) and not config.max_autotune: return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta) return cached_autotune( [ triton_config(size_hints, 32, 32), triton_config(size_hints, 64, 64), # ~8% better for fp16 triton_config(size_hints, 256, 16), triton_config(size_hints, 16, 256), triton_config(size_hints, bs, 1), triton_config(size_hints, 1, bs), ], meta=meta, filename=filename, ) if len(size_hints) == 3: if not config.triton.autotune_pointwise: return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta) return cached_autotune( [ triton_config(size_hints, 16, 16, 16), triton_config(size_hints, 64, 8, 8), triton_config(size_hints, 8, 64, 8), triton_config(size_hints, 8, 8, 64), triton_config(size_hints, bs, 1, 1), triton_config(size_hints, 1, bs, 1), triton_config(size_hints, 1, 1, bs), ], meta=meta, filename=filename, ) raise NotImplementedError(f"size_hints: {size_hints}") def reduction(size_hints, reduction_hint=False, meta=None, filename=None): """args to @triton.heuristics()""" assert meta is not None rnumel = size_hints[-1] if len(size_hints) == 2: contiguous_config = triton_config_reduction( size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048), num_stages=1 ) outer_config = triton_config_reduction(size_hints, 128, 8) tiny_config = triton_config_reduction( size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048) ) if config.max_autotune: pass # skip all these cases elif reduction_hint == ReductionHint.INNER: return cached_autotune([contiguous_config], meta=meta) elif reduction_hint == ReductionHint.OUTER: return cached_autotune([outer_config], meta=meta) elif reduction_hint == ReductionHint.OUTER_TINY: return cached_autotune([tiny_config], meta=meta) if not config.triton.autotune_pointwise: return cached_autotune( [triton_config_reduction(size_hints, 32, 128)], meta=meta ) return cached_autotune( [ contiguous_config, outer_config, tiny_config, triton_config_reduction(size_hints, 64, 64), triton_config_reduction(size_hints, 8, 512), ], meta=meta, filename=filename, ) raise NotImplementedError(f"size_hints: {size_hints}") def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None): xnumel, rnumel = size_hints configs = [ triton_config_reduction(size_hints, xblock, rnumel) for xblock in (1, 8, 32, 128) if rnumel * xblock <= 4096 and xblock <= xnumel ] # TODO(jansel): we should be able to improve these heuristics if reduction_hint == 
def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    xnumel, rnumel = size_hints

    configs = [
        triton_config_reduction(size_hints, xblock, rnumel)
        for xblock in (1, 8, 32, 128)
        if rnumel * xblock <= 4096 and xblock <= xnumel
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]

    return cached_autotune(
        configs,
        meta=meta,
        filename=filename,
    )


def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
    )


def conv_heuristics():
    configs = [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2
        # ),
    ]
    key = [
        "BATCH",
        "IN_C",
        "IN_H",
        "IN_W",
        "KERNEL_N",
        "KERNEL_H",
        "KERNEL_W",
        "OUT_H",
        "OUT_W",
        # parameters of conv
        "stride_h",
        "stride_w",
        "padding_h",
        "padding_w",
        "dilation_h",
        "dilation_w",
        "output_padding_h",
        "output_padding_w",
        "groups",
    ]
    prune_configs_by = {
        "early_config_prune": conv_early_config_prune,
        "perf_model": estimate_conv_time,
        "top_k": 10,
    }
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)


def grid(xnumel, ynumel=None, znumel=None):
    """Helper function to compute triton grids"""

    def get_grid_dim(numel, block_name, block):
        if numel is None:
            return 1
        label = block_name[0]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1"
                f"({label.lower()}numel=={numel}, {block_name}={block})."
            )
        return cdiv(numel, block)

    def grid_fn(meta):
        return (
            get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)),
            get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)),
            get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)),
        )

    return grid_fn
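# Usage sketch (not part of the original source), with hypothetical numels:
# grid(xnumel) returns a callable that the generated launcher invokes with the
# chosen config's kwargs, e.g.
#
#     grid_fn = grid(1024)
#     grid_fn({"XBLOCK": 256})  # -> (4, 1, 1), i.e. cdiv(1024, 256) x-blocks
#
# Missing dimensions default to 1, matching Triton's 3-tuple launch grid.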