import builtins
import copy
import functools
import hashlib
import json
import logging
import operator
import os.path
import re
import threading
from typing import List

import torch

from torch._dynamo.utils import dynamo_timed

from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import conditional_product, has_triton
from .conv_perf_model import (
    early_config_prune as conv_early_config_prune,
    estimate_conv_time,
)

log = logging.getLogger(__name__)

if has_triton():
    import triton
    from triton import cdiv, Config, next_power_of_2
    from triton.runtime.jit import get_cuda_stream, KernelInterface
else:
    cdiv = None
    Config = object
    get_cuda_stream = None
    KernelInterface = object
    next_power_of_2 = None
    triton = None

class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
        super().__init__()
        self.fn = fn
        self.meta = meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.launchers = []
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.meta.get("device", 0)),
            )

    def precompile(self, warm_cache_only_with_cc=None):
        with self.lock:
            if self.launchers:
                return
            self.launchers = [
                self._precompile_config(c, warm_cache_only_with_cc)
                for c in self.configs
            ]
            self.configs = None

    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int):
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.meta)
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages

        if warm_cache_only_with_cc:
            triton.compile(
                self.fn,
                warm_cache_only=True,
                cc=warm_cache_only_with_cc,
                **compile_meta,
            )
            return

        # load binary to the correct device
        with torch.cuda.device(compile_meta["device"]):
            # need to initialize context
            torch.cuda.synchronize(torch.cuda.current_device())
            binary = triton.compile(
                self.fn,
                **compile_meta,
            )

        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        def_args = list(self.fn.arg_names)
        while def_args and def_args[-1] in cfg.kwargs:
            def_args.pop()

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "torch": torch,
            "set_device": torch.cuda.set_device,
            "current_device": torch.cuda.current_device,
        }
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
                              stream, bin.cu_function, None, None, None,
                              {', '.join(call_args)})
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        return launcher
    def bench(self, launcher, *args, grid):
        """Measure the performance of a given launcher"""
        stream = get_cuda_stream(torch.cuda.current_device())

        def kernel_call():
            if launcher.config.pre_hook is not None:
                launcher.config.pre_hook(
                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
                )
            launcher(
                *args,
                grid=grid,
                stream=stream,
            )

        from triton.testing import do_bench

        return do_bench(kernel_call, rep=40, fast_flush=True)

    @dynamo_timed
    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        from ..compile_fx import clone_preserve_strides

        # Clone in-place buffers so autotuning does not contaminate them when
        # the kernel does in-place stores.  Avoid cloning other buffers because
        # it leads to increased memory use.
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        timings = {
            launcher: self.bench(launcher, *cloned_args, **kwargs)
            for launcher in self.launchers
        }
        self.launchers = [builtins.min(timings, key=timings.get)]
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config)
    def run(self, *args, grid, stream):
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                self.precompile()
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid)

        (launcher,) = self.launchers
        if launcher.config.pre_hook is not None:
            launcher.config.pre_hook(
                {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
            )
        try:
            result = launcher(
                *args,
                grid=grid,
                stream=stream,
            )
        except TypeError as e:
            if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
                raise RuntimeError(
                    """Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
                ) from e
            else:
                raise e

        return result
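

# Rough lifecycle of a CachingAutotuner (an illustrative sketch of how the
# pieces in this file fit together, not a verbatim trace of generated code):
#   1. One of the decorator builders below (pointwise / reduction / template /
#      ...) calls cached_autotune(), which wraps a generated Triton kernel in
#      a CachingAutotuner.
#   2. The first run() triggers precompile(), turning every candidate Config
#      into a launcher ahead of time.
#   3. If more than one launcher remains, autotune_to_one_config() benchmarks
#      them, keeps only the fastest, and (when save_cache_hook is set) records
#      the winner in a ".best_config" file on disk.
#   4. Later run() calls dispatch directly to the single surviving launcher.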


def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode(
                "utf-8"
            )
        )
    return hasher.hexdigest()


def load_cached_autotuning(
    cache_filename: str, configs_hash: str, configs: List[Config]
):
    """
    Read a cached autotuning result from disk
    """
    if not os.path.exists(cache_filename):
        return None

    with open(cache_filename, "r") as fd:
        best_config = json.loads(fd.read())
    if best_config.get("configs_hash") != configs_hash:
        return None

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]


def cached_autotune(
    configs: List[Config],
    meta,
    filename=None,
):
    """
    A copy of triton.autotune that calls our subclass.  Our subclass
    has additional debugging, error handling, and on-disk caching.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename

    # on disk caching logic
    if filename is not None and len(configs) > 1:
        cache_filename = os.path.splitext(filename)[0] + ".best_config"
        configs_hash = hash_configs(configs)
        best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
        if best_config:
            configs = [best_config]

        def save_cache_hook(cfg):
            with open(cache_filename, "w") as fd:
                fd.write(json.dumps({**cfg.kwargs, "configs_hash": configs_hash}))

    else:
        save_cache_hook = None

    mutated_arg_names = meta.pop("mutated_arg_names", ())

    def decorator(fn):
        return CachingAutotuner(
            fn,
            meta=meta,
            configs=configs,
            save_cache_hook=save_cache_hook,
            mutated_arg_names=mutated_arg_names,
        )

    return decorator
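

# The ".best_config" file written by save_cache_hook above is a small JSON
# object holding the winning config's kwargs plus the hash of the candidate
# set, e.g. (illustrative contents only):
#
#     {"XBLOCK": 64, "YBLOCK": 64, "configs_hash": "3fa2..."}
#
# load_cached_autotuning() only reuses it when the hash still matches the
# current candidates and exactly one of them has those kwargs.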


def unique_configs(configs: List[Config]):
    """Remove duplicate configurations"""
    seen = set()
    pruned_configs = []
    for cfg in configs:
        key = tuple(cfg.kwargs.items())
        if key not in seen:
            seen.add(key)
            pruned_configs.append(cfg)
    return pruned_configs


def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.
    """
    # Ideally we want to read this from some device config
    maxGridSize = [2147483647, 65535, 65535]

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < size_hints[0] and (
        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < size_hints[1]
        and (
            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < size_hints[2]
        and (
            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
        )
    ):
        z *= 2

    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    num_warps = next_power_of_2(min(max(conditional_product(x, y, z) // 256, 1), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
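

# Illustrative outcomes of the heuristic above (assuming conditional_product
# multiplies its non-None arguments and next_power_of_2 rounds up to a power
# of two):
#
#     triton_config([16384], 1024) -> Config({"XBLOCK": 1024}, num_warps=4)
#         # block already fits; warps scale with the 1024-element tile
#     triton_config([32], 1024)    -> Config({"XBLOCK": 32}, num_warps=1)
#         # block shrunk to the size hint; warps scale down with it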


def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.
    """
    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    r = min(r, size_hints[1])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints[1] and conditional_product(x, r) < target:
        r *= 2

    cfg = {"XBLOCK": x, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 2), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
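

# Illustrative outcome: triton_config_reduction([8, 512], 64, 64) first shrinks
# XBLOCK to the x size hint (8), then regrows RBLOCK toward the original
# 64 * 64 target, yielding roughly
#     Config({"XBLOCK": 8, "RBLOCK": 512}, num_warps=8, num_stages=2)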


def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. Size_hints is a tuple of numels in
    each tile dimension and will be rounded up to the nearest power of 2.
    """
    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    y = min(y, size_hints[1])
    r = min(r, size_hints[2])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints[2] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints[1] and conditional_product(x, y, r) < target:
        y *= 2

    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = next_power_of_2(min(max(conditional_product(x, y, r) // 256, 1), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)


def pointwise(size_hints, meta, tile_hint=None, filename=None):
    """
    Construct @triton.heuristics() based on size_hints.
    """
    numel = functools.reduce(operator.mul, size_hints)
    bs = max(256, min(numel // 128, 1024))

    if len(size_hints) == 1:
        return cached_autotune([triton_config(size_hints, bs)], meta=meta)
    if len(size_hints) == 2:
        if (
            not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
        ) and not config.max_autotune:
            return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
        return cached_autotune(
            [
                triton_config(size_hints, 32, 32),
                triton_config(size_hints, 64, 64),  # ~8% better for fp16
                triton_config(size_hints, 256, 16),
                triton_config(size_hints, 16, 256),
                triton_config(size_hints, bs, 1),
                triton_config(size_hints, 1, bs),
            ],
            meta=meta,
            filename=filename,
        )
    if len(size_hints) == 3:
        if not config.triton.autotune_pointwise:
            return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
        return cached_autotune(
            [
                triton_config(size_hints, 16, 16, 16),
                triton_config(size_hints, 64, 8, 8),
                triton_config(size_hints, 8, 64, 8),
                triton_config(size_hints, 8, 8, 64),
                triton_config(size_hints, bs, 1, 1),
                triton_config(size_hints, 1, bs, 1),
                triton_config(size_hints, 1, 1, bs),
            ],
            meta=meta,
            filename=filename,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")


def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    """args to @triton.heuristics()"""
    assert meta is not None
    rnumel = size_hints[-1]
    if len(size_hints) == 2:
        contiguous_config = triton_config_reduction(
            size_hints, 1, (rnumel if 256 <= rnumel < 2048 else 2048), num_stages=1
        )
        outer_config = triton_config_reduction(size_hints, 128, 8)
        tiny_config = triton_config_reduction(
            size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, 2048)
        )
        if config.max_autotune:
            pass  # skip all these cases
        elif reduction_hint == ReductionHint.INNER:
            return cached_autotune([contiguous_config], meta=meta)
        elif reduction_hint == ReductionHint.OUTER:
            return cached_autotune([outer_config], meta=meta)
        elif reduction_hint == ReductionHint.OUTER_TINY:
            return cached_autotune([tiny_config], meta=meta)
        if not config.triton.autotune_pointwise:
            return cached_autotune(
                [triton_config_reduction(size_hints, 32, 128)], meta=meta
            )
        return cached_autotune(
            [
                contiguous_config,
                outer_config,
                tiny_config,
                triton_config_reduction(size_hints, 64, 64),
                triton_config_reduction(size_hints, 8, 512),
            ],
            meta=meta,
            filename=filename,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")


def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    xnumel, rnumel = size_hints

    configs = [
        triton_config_reduction(size_hints, xblock, rnumel)
        for xblock in (1, 8, 32, 128)
        if rnumel * xblock <= 4096 and xblock <= xnumel
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]

    return cached_autotune(
        configs,
        meta=meta,
        filename=filename,
    )
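

# Illustrative example: for size_hints == (1024, 512) with no reduction hint,
# only xblock in (1, 8) satisfies rnumel * xblock <= 4096, so the candidate
# set is roughly [{"XBLOCK": 1, "RBLOCK": 512}, {"XBLOCK": 8, "RBLOCK": 512}]
# and the two are autotuned against each other on first use.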


def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], meta=meta
    )


def conv_heuristics():
    configs = [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 32}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=4, num_warps=2
        ),
        # triton.Config(
        #     {"BLOCK_M": 128, "BLOCK_N": 16, "BLOCK_K": 64}, num_stages=4, num_warps=2
        # ),
    ]
    key = [
        "BATCH",
        "IN_C",
        "IN_H",
        "IN_W",
        "KERNEL_N",
        "KERNEL_H",
        "KERNEL_W",
        "OUT_H",
        "OUT_W",
        # parameters of conv
        "stride_h",
        "stride_w",
        "padding_h",
        "padding_w",
        "dilation_h",
        "dilation_w",
        "output_padding_h",
        "output_padding_w",
        "groups",
    ]
    prune_configs_by = {
        "early_config_prune": conv_early_config_prune,
        "perf_model": estimate_conv_time,
        "top_k": 10,
    }
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)


def grid(xnumel, ynumel=None, znumel=None):
    """Helper function to compute triton grids"""

    def get_grid_dim(numel, block_name, block):
        if numel is None:
            return 1
        label = block_name[0]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes {label.lower()}numel == 1 => {block_name} == 1"
                f"({label.lower()}numel=={numel}, {block_name}={block})."
            )
        return cdiv(numel, block)

    def grid_fn(meta):
        return (
            get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", None)),
            get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)),
            get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)),
        )

    return grid_fn
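

# Illustrative use: grid(1024) returns a grid_fn such that
# grid_fn({"XBLOCK": 256}) == (4, 1, 1); numels that are not given (Y/Z here)
# collapse to a grid dimension of 1.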