|| 
							- import builtins
 
- import copy
 
- import functools
 
- import hashlib
 
- import json
 
- import logging
 
- import operator
 
- import os.path
 
- import re
 
- import threading
 
- from typing import List
 
- import torch
 
- from torch._dynamo.utils import dynamo_timed
 
- from .. import config
 
- from ..codecache import cache_dir
 
- from ..ir import ReductionHint, TileHint
 
- from ..utils import conditional_product, has_triton
 
- from .conv_perf_model import (
 
-     early_config_prune as conv_early_config_prune,
 
-     estimate_conv_time,
 
- )
 
- log = logging.getLogger(__name__)
 
- if has_triton():
 
-     import triton
 
-     from triton import cdiv, Config, next_power_of_2
 
-     from triton.runtime.jit import get_cuda_stream, KernelInterface
 
- else:
 
-     cdiv = None
 
-     Config = object
 
-     get_cuda_stream = None
 
-     KernelInterface = object
 
-     next_power_of_2 = None
 
-     triton = None
 
class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names):
        """
        Args:
            fn: kernel object; this class reads ``fn.arg_names`` and
                ``fn.constexprs`` and passes ``fn`` to ``triton.compile``.
            meta: dict of keyword arguments forwarded to ``triton.compile``;
                its ``"constants"`` and ``"device"`` entries are used here.
            configs: candidate configs to compile (and later benchmark).
            save_cache_hook: optional callable invoked with the winning
                config once autotuning has picked one.
            mutated_arg_names: names of kernel arguments that are cloned
                before benchmarking.
        """
        super().__init__()
        self.fn = fn
        self.meta = meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.launchers = []
        self.lock = threading.Lock()
        # Only provide a default cache location when the user has not set one.
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.meta.get("device", 0)),
            )

    def precompile(self, warm_cache_only_with_cc=None):
        """
        Compile every pending config exactly once.

        Does nothing when launchers already exist. When
        ``warm_cache_only_with_cc`` is truthy, ``_precompile_config`` only
        warms the compile cache and returns ``None`` for each config.
        """
        with self.lock:
            if self.launchers:
                return
            self.launchers = [
                self._precompile_config(c, warm_cache_only_with_cc)
                for c in self.configs
            ]
            self.configs = None

    def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int):
        """Ahead of time compile a given autotuner config."""
        # Work on a copy so per-config constants do not leak into self.meta.
        compile_meta = copy.deepcopy(self.meta)
        for k, v in cfg.kwargs.items():
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages

        if warm_cache_only_with_cc:
            triton.compile(
                self.fn,
                warm_cache_only=True,
                cc=warm_cache_only_with_cc,
                **compile_meta,
            )
            return

        # load binary to the correct device
        with torch.cuda.device(compile_meta["device"]):
            # need to initialize context
            torch.cuda.synchronize(torch.cuda.current_device())
            binary = triton.compile(
                self.fn,
                **compile_meta,
            )

        # Arguments passed on to the compiled binary: everything that is not
        # a constexpr.
        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        # Launcher signature: all argument names minus trailing ones that
        # the config supplies.
        def_args = list(self.fn.arg_names)
        while def_args and def_args[-1] in cfg.kwargs:
            def_args.pop()

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "torch": torch,
            "set_device": torch.cuda.set_device,
            "current_device": torch.cuda.current_device,
        }

        # The launcher source is generated from the kernel's own argument
        # names so that it takes explicit positional parameters.
        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
                            stream, bin.cu_function, None, None, None,
                            {', '.join(call_args)})
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        return launcher

    def _call_pre_hook(self, launcher, args):
        """Invoke the config's pre_hook (if any) with named args plus config kwargs."""
        pre_hook = launcher.config.pre_hook
        if pre_hook is None:
            return
        # ``**`` needs a mapping, so the zip must be materialised as a dict;
        # the argument names live on the wrapped kernel, not on this object.
        pre_hook({**dict(zip(self.fn.arg_names, args)), **launcher.config.kwargs})

    def bench(self, launcher, *args, grid):
        """Measure the performance of a given launcher"""
        stream = get_cuda_stream(torch.cuda.current_device())

        def kernel_call():
            self._call_pre_hook(launcher, args)
            launcher(
                *args,
                grid=grid,
                stream=stream,
            )

        from triton.testing import do_bench

        return do_bench(kernel_call, rep=40, fast_flush=True)

    @dynamo_timed
    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        from ..compile_fx import clone_preserve_strides

        # clone inplace buffers to avoid autotune contaminating them if
        # the kernel does in-place stores. avoid cloning other buffers because
        # it leads to increase memory use
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        timings = {
            launcher: self.bench(launcher, *cloned_args, **kwargs)
            for launcher in self.launchers
        }
        # Keep only the launcher with the smallest measured time.
        self.launchers = [builtins.min(timings, key=timings.get)]
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config)

    def run(self, *args, grid, stream):
        """
        Launch the kernel, compiling and autotuning first when needed.

        Raises:
            RuntimeError: when the launcher raises a TypeError whose message
                matches the "function takes exactly N arguments (M given)"
                pattern; the original TypeError is chained as the cause.
        """
        if len(self.launchers) == 0:
            self.precompile()
        if len(self.launchers) > 1:
            self.autotune_to_one_config(*args, grid=grid)

        (launcher,) = self.launchers
        self._call_pre_hook(launcher, args)
        try:
            result = launcher(
                *args,
                grid=grid,
                stream=stream,
            )
        except TypeError as e:
            if re.match(r"function takes exactly \d+ arguments \(\d+ given\)", str(e)):
                raise RuntimeError(
                    """Consider updating Triton with
`pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`"""
                ) from e
            raise
        return result
 
def hash_configs(configs: List[Config]):
    """
    Fingerprint a list of configs so a change in the candidate set can
    be detected. Each config contributes its sorted kwargs, num_warps
    and num_stages; the result is a SHA-256 hex digest.
    """
    digest = hashlib.sha256()
    for candidate in configs:
        ordered_kwargs = sorted(candidate.kwargs.items())
        line = f"{ordered_kwargs} {candidate.num_warps} {candidate.num_stages}\n"
        digest.update(line.encode("utf-8"))
    return digest.hexdigest()
 
def load_cached_autotuning(
    cache_filename: str, configs_hash: str, configs: List[Config]
):
    """
    Read a cached autotuning result from disk

    Returns the single entry of ``configs`` whose kwargs all equal the
    values stored in the cache file, or ``None`` when the file is missing,
    is not valid JSON, does not hold a JSON object, was written for a
    different ``configs_hash``, or does not match exactly one config.
    """
    # Open directly instead of checking os.path.exists() first: the file
    # may disappear (or still be mid-write) between the check and the read.
    try:
        with open(cache_filename, "r") as fd:
            best_config = json.loads(fd.read())
    except FileNotFoundError:
        return None
    except ValueError:
        # Truncated or corrupt cache content (json.JSONDecodeError is a
        # ValueError): treat it as a cache miss rather than crashing.
        return None

    if not isinstance(best_config, dict):
        return None
    if best_config.get("configs_hash") != configs_hash:
        return None

    matching_configs = [
        cfg
        for cfg in configs
        if all(val == best_config.get(key) for key, val in cfg.kwargs.items())
    ]
    if len(matching_configs) != 1:
        return None

    return matching_configs[0]
 
def cached_autotune(
    configs: List[Config],
    meta,
    filename=None,
):
    """
    A copy of triton.autotune that calls our subclass.  Our subclass
    has additional debugging, error handling, and on-disk caching.

    Duplicate configs are removed first. When more than one config remains
    and ``filename`` is given, a ``.best_config`` file next to it is
    consulted, and a hook is set up to write the winning config back.
    ``"mutated_arg_names"`` is popped out of ``meta``.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename

    save_cache_hook = None
    wants_disk_cache = filename is not None and len(configs) > 1
    if wants_disk_cache:
        # on disk caching logic
        stem, _ext = os.path.splitext(filename)
        cache_filename = stem + ".best_config"
        configs_hash = hash_configs(configs)

        cached_best = load_cached_autotuning(cache_filename, configs_hash, configs)
        if cached_best:
            configs = [cached_best]

        def save_cache_hook(cfg):
            with open(cache_filename, "w") as fd:
                record = dict(cfg.kwargs)
                record["configs_hash"] = configs_hash
                fd.write(json.dumps(record))

    mutated_arg_names = meta.pop("mutated_arg_names", ())

    def decorator(fn):
        autotuner_kwargs = {
            "meta": meta,
            "configs": configs,
            "save_cache_hook": save_cache_hook,
            "mutated_arg_names": mutated_arg_names,
        }
        return CachingAutotuner(fn, **autotuner_kwargs)

    return decorator
 
def unique_configs(configs: List[Config]):
    """Drop configs whose kwargs items repeat an earlier entry, keeping order."""
    already_seen = set()
    kept = []
    for candidate in configs:
        signature = tuple(candidate.kwargs.items())
        if signature in already_seen:
            continue
        already_seen.add(signature)
        kept.append(candidate)
    return kept
 
def triton_config(size_hints, x, y=None, z=None, num_stages=1) -> Config:
    """
    Build a pointwise triton config from requested block sizes.

    The requested x/y/z blocks are first clamped to the matching entry of
    size_hints. Each block is then doubled while it is still below its
    size hint and either the grid in that dimension would exceed the
    hard-coded limit or the combined block is still below the target.
    """
    # Ideally we want to read this from some device config
    max_grid_x, max_grid_y, max_grid_z = 2147483647, 65535, 65535

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    def below_target():
        # Reads the current x/y/z of the enclosing call each time.
        return conditional_product(x, y, z) < target

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up
    # the corresponding dimension
    while x < size_hints[0] and (x * max_grid_x < size_hints[0] or below_target()):
        x *= 2
    while (
        y
        and y < size_hints[1]
        and (y * max_grid_y < size_hints[1] or below_target())
    ):
        y *= 2
    while (
        z
        and z < size_hints[2]
        and (z * max_grid_z < size_hints[2] or below_target())
    ):
        z *= 2

    cfg = {"XBLOCK": x}
    # YBLOCK / ZBLOCK only appear when the corresponding block is truthy.
    cfg.update(
        {name: block for name, block in (("YBLOCK", y), ("ZBLOCK", z)) if block}
    )

    block_elems = conditional_product(x, y, z)
    num_warps = next_power_of_2(min(max(block_elems // 256, 1), 8))
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
 
def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config:
    """
    Build a reduction triton config from requested XBLOCK / RBLOCK sizes.

    Both blocks are clamped to size_hints[0] / size_hints[1], then doubled
    (x first, then r) while they stay below their hint and the combined
    block is below the target.
    """
    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x_limit = size_hints[0]
    x = min(x, x_limit)
    r_limit = size_hints[1]
    r = min(r, r_limit)

    # if we are below original block size, scale up where we can
    while x < x_limit and conditional_product(x, r) < target:
        x *= 2
    while r < r_limit and conditional_product(x, r) < target:
        r *= 2

    block_elems = conditional_product(x, r)
    num_warps = next_power_of_2(min(max(block_elems // 128, 2), 8))
    return Config(
        {"XBLOCK": x, "RBLOCK": r}, num_warps=num_warps, num_stages=num_stages
    )
 
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2):
    """
    Build a tiled reduction triton config from requested block sizes.

    XBLOCK / YBLOCK / RBLOCK are clamped to size_hints[0..2], then doubled
    in the order x, r, y while each stays below its hint and the combined
    block is below the target.
    """
    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x_limit = size_hints[0]
    x = min(x, x_limit)
    y_limit = size_hints[1]
    y = min(y, y_limit)
    r_limit = size_hints[2]
    r = min(r, r_limit)

    # if we are below original block size, scale up where we can
    while x < x_limit and conditional_product(x, y, r) < target:
        x *= 2
    while r < r_limit and conditional_product(x, y, r) < target:
        r *= 2
    while y < y_limit and conditional_product(x, y, r) < target:
        y *= 2

    block_elems = conditional_product(x, y, r)
    num_warps = next_power_of_2(min(max(block_elems // 256, 1), 8))
    return Config(
        {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r},
        num_warps=num_warps,
        num_stages=num_stages,
    )
 
def pointwise(size_hints, meta, tile_hint=None, filename=None):
    """
    Construct @triton.heuristics() based on size_hints.

    Handles 1, 2 or 3 entries in size_hints; anything else raises
    NotImplementedError. Multi-config results pass ``filename`` on to
    cached_autotune, single-config results do not.
    """
    numel = functools.reduce(operator.mul, size_hints)
    bs = max(256, min(numel // 128, 1024))
    ndim = len(size_hints)

    if ndim == 1:
        return cached_autotune([triton_config(size_hints, bs)], meta=meta)

    if ndim == 2:
        single_only = (
            not config.triton.autotune_pointwise or tile_hint == TileHint.SQUARE
        )
        if single_only and not config.max_autotune:
            return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta)
        block_shapes = (
            (32, 32),
            (64, 64),  # ~8% better for fp16
            (256, 16),
            (16, 256),
            (bs, 1),
            (1, bs),
        )
        candidates = [triton_config(size_hints, *shape) for shape in block_shapes]
        return cached_autotune(candidates, meta=meta, filename=filename)

    if ndim == 3:
        if not config.triton.autotune_pointwise:
            return cached_autotune([triton_config(size_hints, 16, 16, 16)], meta=meta)
        block_shapes = (
            (16, 16, 16),
            (64, 8, 8),
            (8, 64, 8),
            (8, 8, 64),
            (bs, 1, 1),
            (1, bs, 1),
            (1, 1, bs),
        )
        candidates = [triton_config(size_hints, *shape) for shape in block_shapes]
        return cached_autotune(candidates, meta=meta, filename=filename)

    raise NotImplementedError(f"size_hints: {size_hints}")
 
def reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    """
    args to @triton.heuristics()

    Only two-entry size_hints are supported; anything else raises
    NotImplementedError. Unless config.max_autotune is set, a matching
    reduction_hint selects one fixed config.
    """
    assert meta is not None
    rnumel = size_hints[-1]
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    contiguous_r = rnumel if 256 <= rnumel < 2048 else 2048
    contiguous_config = triton_config_reduction(
        size_hints, 1, contiguous_r, num_stages=1
    )
    outer_config = triton_config_reduction(size_hints, 128, 8)
    tiny_x = 2 * (256 // rnumel) if rnumel <= 256 else 1
    tiny_config = triton_config_reduction(size_hints, tiny_x, min(rnumel, 2048))

    # With max_autotune the hint-specific shortcuts are skipped entirely.
    if not config.max_autotune:
        hinted_choices = (
            (ReductionHint.INNER, contiguous_config),
            (ReductionHint.OUTER, outer_config),
            (ReductionHint.OUTER_TINY, tiny_config),
        )
        for hint, chosen in hinted_choices:
            if reduction_hint == hint:
                return cached_autotune([chosen], meta=meta)

    if not config.triton.autotune_pointwise:
        fallback = triton_config_reduction(size_hints, 32, 128)
        return cached_autotune([fallback], meta=meta)

    candidates = [
        contiguous_config,
        outer_config,
        tiny_config,
        triton_config_reduction(size_hints, 64, 64),
        triton_config_reduction(size_hints, 8, 512),
    ]
    return cached_autotune(candidates, meta=meta, filename=filename)
 
def persistent_reduction(size_hints, reduction_hint=False, meta=None, filename=None):
    """
    Build the autotuner decorator for a reduction whose RBLOCK request is
    the full rnumel.

    size_hints must unpack to (xnumel, rnumel). Candidate XBLOCKs are
    1, 8, 32 and 128, kept only when rnumel * xblock <= 4096 and
    xblock <= xnumel; reduction_hint may narrow the list further.
    """
    xnumel, rnumel = size_hints

    configs = []
    for xblock in (1, 8, 32, 128):
        if rnumel * xblock <= 4096 and xblock <= xnumel:
            configs.append(triton_config_reduction(size_hints, xblock, rnumel))

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        if rnumel <= 256:
            tiny_x = 2 * (256 // rnumel)
        else:
            tiny_x = 1
        configs = [triton_config_reduction(size_hints, tiny_x, rnumel)]

    return cached_autotune(configs, meta=meta, filename=filename)
 
def template(num_stages, num_warps, meta, filename=None):
    """
    Compile a triton template

    Wraps a single triton.Config with empty kwargs in cached_autotune.
    ``filename`` is accepted but not forwarded.
    """
    only_config = triton.Config({}, num_stages=num_stages, num_warps=num_warps)
    return cached_autotune([only_config], meta=meta)
 
def conv_heuristics():
    """
    Return the triton.autotune decorator for conv kernels: a fixed list of
    candidate configs, the autotune key names, and pruning via the conv
    perf model (top 10).
    """
    # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
    tile_table = (
        (128, 128, 32, 2, 8),
        (256, 64, 32, 2, 8),
        (256, 32, 32, 4, 4),
        (256, 32, 64, 4, 4),
        (256, 16, 32, 4, 2),
        (64, 128, 32, 4, 8),
        (128, 64, 32, 4, 4),
        (64, 64, 32, 4, 4),
        (128, 16, 32, 4, 4),
        (128, 128, 128, 3, 8),
        (256, 64, 128, 3, 8),
        (256, 32, 128, 4, 4),
        (64, 128, 128, 4, 4),
        (128, 64, 128, 4, 4),
        (128, 32, 64, 4, 2),
        (64, 64, 64, 4, 2),
        # disabled: (128, 16, 64, 4, 2),
    )
    configs = [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
            num_stages=stages,
            num_warps=warps,
        )
        for block_m, block_n, block_k, stages, warps in tile_table
    ]

    shape_keys = [
        "BATCH",
        "IN_C",
        "IN_H",
        "IN_W",
        "KERNEL_N",
        "KERNEL_H",
        "KERNEL_W",
        "OUT_H",
        "OUT_W",
    ]
    # parameters of conv
    conv_param_keys = [
        "stride_h",
        "stride_w",
        "padding_h",
        "padding_w",
        "dilation_h",
        "dilation_w",
        "output_padding_h",
        "output_padding_w",
        "groups",
    ]
    key = shape_keys + conv_param_keys

    prune_configs_by = dict(
        early_config_prune=conv_early_config_prune,
        perf_model=estimate_conv_time,
        top_k=10,
    )
    return triton.autotune(configs, key, prune_configs_by=prune_configs_by)
 
def grid(xnumel, ynumel=None, znumel=None):
    """
    Helper function to compute triton grids

    Returns a function that maps a meta dict (read for "XBLOCK", "YBLOCK"
    and "ZBLOCK") to a 3-tuple of grid dimensions. A dimension whose numel
    is None contributes 1; otherwise it is cdiv(numel, block).
    """
    dims = (("XBLOCK", xnumel), ("YBLOCK", ynumel), ("ZBLOCK", znumel))

    def blocks_for(numel, block_name, block):
        if numel is None:
            return 1
        if numel == 1:
            axis = block_name[0].lower()
            assert block == 1, (
                f"TritonKernel.indexing assumes {axis}numel == 1 => {block_name} == 1"
                f"({axis}numel=={numel}, {block_name}={block})."
            )
        return cdiv(numel, block)

    def grid_fn(meta):
        return tuple(
            blocks_for(numel, block_name, meta.get(block_name, None))
            for block_name, numel in dims
        )

    return grid_fn
 
 
  |