import collections import contextlib import dataclasses import functools import hashlib from itertools import count from typing import Any, Dict, List from torch._dynamo.utils import dynamo_timed from .. import codecache, config, ir from ..codecache import cpp_compile_command, get_code_path from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product from ..virtualized import V from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter pexpr = PythonPrinter().doprint def buffer_reuse_key(node: ir.Buffer): size = node.get_size() stride = node.get_stride() last_element = sympy_dot([s - 1 for s in size], stride) return ( node.get_device(), node.get_dtype(), V.graph.sizevars.simplify(sympy_product(size)), # Detect gaps in tensor storage caused by strides V.graph.sizevars.size_hint(last_element), ) def make_buffer_reuse(old, new, del_func, declare, ending, as_strided): assert old.get_dtype() == new.get_dtype() del_line = "" if old.get_name() not in V.graph.get_output_names(): del_line = del_func(old.get_name()) if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): return f"{declare}{new.get_name()} = {old.get_name()}{del_line}{ending}" return ( f"{declare}{new.get_name()} = {as_strided}({old.get_name()}, " f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, " f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}{ending}" ) def make_buffer_allocation(buffer): device = buffer.get_device() dtype = buffer.get_dtype() shape = tuple(buffer.get_size()) stride = tuple(buffer.get_stride()) return ( f"{buffer.get_name()} = empty_strided(" f"{V.graph.sizevars.codegen_shape_tuple(shape)}, " f"{V.graph.sizevars.codegen_shape_tuple(stride)}, " f"device='{device.type}', dtype={dtype})" ) def make_cpp_buffer_allocation(buffer): from .cpp import DTYPE_TO_ATEN # TODO: map layout and device here dtype = buffer.get_dtype() shape = tuple(buffer.get_size()) stride = tuple(buffer.get_stride()) return ( f"auto {buffer.get_name()} = at::empty_strided(" f"{V.graph.sizevars.codegen_shape_tuple(shape)}, " f"{V.graph.sizevars.codegen_shape_tuple(stride)}, " f"{DTYPE_TO_ATEN[dtype]}); " ) class MemoryPlanningState: def __init__(self): super().__init__() self.reuse_pool: Dict[ Any, List["FreeIfNotReusedLine"] ] = collections.defaultdict(list) def __contains__(self, key): return bool(self.reuse_pool.get(key, None)) def pop(self, key) -> "FreeIfNotReusedLine": item = self.reuse_pool[key].pop() assert not item.is_reused return item def push(self, key, item: "FreeIfNotReusedLine"): assert not item.is_reused self.reuse_pool[key].append(item) @dataclasses.dataclass class EnterCudaDeviceContextManagerLine: device_idx: int def codegen(self, code: IndentedBuffer): # Note _DeviceGuard has less overhead than device, but only accepts # integers code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):") class ExitCudaDeviceContextManagerLine: pass class MemoryPlanningLine: def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine": """First pass to find reuse""" return self def codegen(self, code: IndentedBuffer): """Second pass to output code""" pass @dataclasses.dataclass class AllocateLine(MemoryPlanningLine): node: ir.Buffer def plan(self, state: MemoryPlanningState): if self.node.get_name() in V.graph.removed_buffers: return NullLine() # try to reuse a recently freed buffer key = buffer_reuse_key(self.node) if key in state: free_line = state.pop(key) free_line.is_reused = True return ReuseLine(free_line.node, self.node) return self def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers code.writeline(make_buffer_allocation(self.node)) @dataclasses.dataclass class CppAllocateLine(AllocateLine): def plan(self, state: MemoryPlanningState): if self.node.get_name() in V.graph.removed_buffers: return NullLine() # try to reuse a recently freed buffer key = buffer_reuse_key(self.node) if key in state: free_line = state.pop(key) free_line.is_reused = True return CppReuseLine(free_line.node, self.node) return self def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers code.writeline(make_cpp_buffer_allocation(self.node)) @dataclasses.dataclass class FreeIfNotReusedLine(MemoryPlanningLine): node: ir.Buffer is_reused: bool = False def plan(self, state: MemoryPlanningState): assert not self.is_reused if self.node.get_name() in V.graph.removed_buffers: return NullLine() state.push(buffer_reuse_key(self.node), self) return self def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers if not self.is_reused: code.writeline(f"del {self.node.get_name()}") @dataclasses.dataclass class CppFreeIfNotReusedLine(FreeIfNotReusedLine): node: ir.Buffer is_reused: bool = False def codegen(self, code: IndentedBuffer): assert (self.node.get_name()) not in V.graph.removed_buffers if not self.is_reused: code.writeline(f"{self.node.get_name()}.reset();") @dataclasses.dataclass class ReuseLine(MemoryPlanningLine): node: ir.Buffer reused_as: ir.Buffer def plan(self, state: MemoryPlanningState): assert self.node.get_name() not in V.graph.removed_buffers assert self.reused_as.get_name() not in V.graph.removed_buffers return self def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers assert self.reused_as.get_name() not in V.graph.removed_buffers code.writeline( make_buffer_reuse( self.node, self.reused_as, del_func=lambda name: f"; del {name}", declare="", ending="", as_strided="as_strided", ) + " # reuse" ) @dataclasses.dataclass class CppReuseLine(ReuseLine): node: ir.Buffer reused_as: ir.Buffer def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers assert self.reused_as.get_name() not in V.graph.removed_buffers code.writeline( make_buffer_reuse( self.node, self.reused_as, del_func=lambda name: f"; {name}.reset()", declare="auto ", ending=";", as_strided="at::as_strided", ) + " // reuse" ) @dataclasses.dataclass class FreeLine(MemoryPlanningLine): node: ir.Buffer def plan(self, state: MemoryPlanningState): if self.node.get_name() in V.graph.removed_buffers: return NullLine() return self def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers code.writeline(f"del {self.node.get_name()}") class NullLine(MemoryPlanningLine): pass class WrapperCodeGen(CodeGen): """ The outer wrapper that calls the kernels. """ def __init__(self): super().__init__() self._names_iter = count() self.header = IndentedBuffer() self.prefix = IndentedBuffer() self.wrapper_call = IndentedBuffer() self.kernels = {} self.lines = [] self.header.splice( f""" from ctypes import c_void_p, c_long import torch import math import random from torch import empty_strided, as_strided, device from {codecache.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels aten = torch.ops.aten assert_size_stride = torch._C._dynamo.guards.assert_size_stride async_compile = AsyncCompile() """ ) if has_triton(): self.header.splice( """ import triton import triton.language as tl from torch._inductor.triton_ops.autotune import grid from torch._C import _cuda_getCurrentRawStream as get_cuda_stream """ ) self.write_prefix() for name, value in V.graph.constants.items(): # include a hash so our code cache gives different constants different files hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest() self.header.writeline(f"{name} = None # {hashed}") self.allocated = set() self.freed = set() # maps from reusing buffer to reused buffer self.reuses = dict() self.write_get_cuda_stream = functools.lru_cache(None)( self.write_get_cuda_stream ) @functools.lru_cache(None) def add_import_once(line): self.header.writeline(line) self.add_import_once = add_import_once self._metas = {} def add_meta_once(self, meta): meta = repr(meta) if meta not in self._metas: var = f"meta{len(self._metas)}" self._metas[meta] = var self.header.writeline(f"{var} = {meta}") return self._metas[meta] @cache_on_self def get_output_refs(self): return [x.codegen_reference() for x in V.graph.graph_outputs] def write_prefix(self): self.prefix.splice( """ async_compile.wait(globals()) del async_compile def call(args): """ ) with self.prefix.indent(): if config.triton.debug_sync_graph: self.prefix.writeline("torch.cuda.synchronize()") inp_len = len(V.graph.graph_inputs.keys()) if inp_len != 0: lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}" self.prefix.writeline(f"{lhs} = args") self.prefix.writeline("args.clear()") for name in V.graph.randomness_seeds: self.prefix.writeline( f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})" ) V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs) def append_precomputed_sizes_to_prefix(self): with self.prefix.indent(): V.graph.sizevars.codegen_precomputed_sizes(self.prefix) def write_get_cuda_stream(self, index): name = f"stream{index}" self.writeline(f"{name} = get_cuda_stream({index})") return name def next_kernel_suffix(self): return f"{next(self._names_iter)}" def write_allocate_line(self, buffer): self.writeline(AllocateLine(buffer)) def get_deferred_line(self, name, layout): return DeferredLine( name, f"{name} = {layout.view.codegen_reference()} # alias" ) def codegen_allocation(self, buffer): name = buffer.get_name() if name in V.graph.removed_buffers or name in self.allocated: return self.allocated.add(name) if isinstance( buffer, (ir.ExternKernelAlloc, ir.MultiOutput), ): return layout = buffer.get_layout() if isinstance(layout, ir.MutationLayout): return if isinstance(layout, ir.AliasedLayout): assert isinstance(layout.view, ir.ReinterpretView) if not layout.maybe_guard_aligned(): V.graph.unaligned_buffers.add(name) self.codegen_allocation(layout.view.data) allocation = self.get_deferred_line(name, layout) self.writeline(allocation) return self.write_allocate_line(buffer) def write_del_line(self, name): self.writeline(f"del {name}") def write_free_if_not_reused_line(self, buffer): self.writeline(FreeIfNotReusedLine(buffer)) def codegen_free(self, buffer): name = buffer.get_name() # can be freed but not reused if isinstance(buffer, ir.InputBuffer): self.write_del_line(name) return if not self.can_reuse(buffer): return self.freed.add(name) layout = buffer.get_layout() if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)): self.write_del_line(name) return self.write_free_if_not_reused_line(buffer) def can_reuse(self, buffer): name = buffer.get_name() if ( name in V.graph.removed_buffers or name in V.graph.graph_inputs or name in V.graph.constants or name in self.freed ): return False return True def did_reuse(self, buffer, reused_buffer): # Check whether a given buffer was reused by a possible reuser in the wrapper codegen # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed return ( buffer.get_name() in self.reuses and self.reuses[buffer.get_name()] == reused_buffer.get_name() ) def write_reuse_line(self, input_buffer, output_buffer): self.writeline(ReuseLine(input_buffer, output_buffer)) def codegen_inplace_reuse(self, input_buffer, output_buffer): assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer) self.codegen_allocation(input_buffer) self.freed.add(input_buffer.get_name()) self.allocated.add(output_buffer.get_name()) self.reuses[output_buffer.get_name()] = input_buffer.get_name() self.write_reuse_line(input_buffer, output_buffer) def codegen_cuda_device_guard_enter(self, device_idx): self.lines.append(EnterCudaDeviceContextManagerLine(device_idx)) def codegen_cuda_device_guard_exit(self): self.lines.append(ExitCudaDeviceContextManagerLine()) def generate_return(self, output_refs): if output_refs: self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") else: self.wrapper_call.writeline("return ()") def generate_end(self, result): return def generate_extern_kernel_out( self, output_view, codegen_reference, args, kernel, cpp_kernel ): if output_view: args.append(f"out={output_view.codegen_reference()}") else: args.append(f"out={codegen_reference}") self.writeline(f"{kernel}({', '.join(args)})") @dynamo_timed def generate(self): result = IndentedBuffer() result.splice(self.header) out_names = V.graph.get_output_names() with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) if config.profiler_mark_wrapper_call: self.wrapper_call.writeline( "from torch.profiler import record_function" ) self.wrapper_call.writeline( "with record_function('inductor_wrapper_call'):" ) stack.enter_context(self.wrapper_call.indent()) while ( self.lines and isinstance(self.lines[-1], MemoryPlanningLine) and self.lines[-1].node.name not in out_names ): # these lines will be pointless self.lines.pop() # codegen allocations in two passes planning_state = MemoryPlanningState() for i in range(len(self.lines)): if isinstance(self.lines[i], MemoryPlanningLine): self.lines[i] = self.lines[i].plan(planning_state) device_cm_stack = contextlib.ExitStack() for line in self.lines: if isinstance(line, MemoryPlanningLine): line.codegen(self.wrapper_call) elif isinstance(line, EnterCudaDeviceContextManagerLine): line.codegen(self.wrapper_call) device_cm_stack.enter_context(self.wrapper_call.indent()) self.wrapper_call.writeline( f"torch.cuda.set_device({line.device_idx}) # no-op to ensure context" ) elif isinstance(line, ExitCudaDeviceContextManagerLine): device_cm_stack.close() else: self.wrapper_call.writeline(line) output_refs = self.get_output_refs() if config.triton.debug_sync_graph: self.wrapper_call.writeline("torch.cuda.synchronize()") self.generate_return(output_refs) self.append_precomputed_sizes_to_prefix() result.splice(self.prefix) with result.indent(): result.splice(self.wrapper_call) self.generate_end(result) self.add_benchmark_harness(result) return result.getvalue() def add_benchmark_harness(self, output): """ Append a benchmark harness to generated code for debugging """ if not config.benchmark_harness: return def add_fake_input(name, shape, stride, device, dtype): output.writeline( f"{name} = rand_strided(" f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, " f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, " f"device='{device}', dtype={dtype})" ) output.writelines(["", "", 'if __name__ == "__main__":']) with output.indent(): output.splice( """ from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance """, strip=True, ) for name, value in V.graph.constants.items(): add_fake_input( name, value.size(), value.stride(), value.device, value.dtype ) for name, value in V.graph.graph_inputs.items(): shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()] stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()] add_fake_input( name, shape, stride, value.get_device(), value.get_dtype() ) output.writeline( f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))" ) def define_kernel(self, name: str, kernel: str): self.header.splice(f"\n\n{name} = {kernel}") def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None): return def wrap_kernel_call(self, name, call_args): return "{}({})".format(name, ", ".join(call_args)) def generate_kernel_call(self, name, call_args): self.writeline( self.wrap_kernel_call(name, call_args), ) def call_kernel(self, name: str, kernel: Kernel): tmp = IndentedBuffer() kernel.call_kernel(self, tmp, name) for line in tmp.getvalue().split("\n"): line = line.strip() if line: self.writeline(line) def writeline(self, line): self.lines.append(line) class CppWrapperCodeGen(WrapperCodeGen): """ The outer wrapper that calls the kernels. """ call_func_id = count() def __init__(self): self._call_func_id = next(CppWrapperCodeGen.call_func_id) super().__init__() @cache_on_self def get_output_refs(self): def has_cpp_codegen_func(x): return hasattr(x, "cpp_wrapper_codegen_reference") and callable( x.cpp_wrapper_codegen_reference ) return [ x.cpp_wrapper_codegen_reference() if has_cpp_codegen_func(x) else x.codegen_reference() for x in V.graph.graph_outputs ] def write_prefix(self): self.prefix.splice( """ async_compile.wait(globals()) del async_compile from torch.utils.cpp_extension import load_inline wrapper = ( ''' #include #include template KernelFunc load_cpp_kernel(const char* so_filename) { KernelFunc kernel_cpp; auto kernel_cpp_lib = dlopen(so_filename, RTLD_NOW); assert(kernel_cpp_lib != nullptr); *(void **) (&kernel_cpp) = dlsym(kernel_cpp_lib, "kernel"); return kernel_cpp; } """ ) with self.wrapper_call.indent(): inputs_len = len(V.graph.graph_inputs.keys()) output_refs = self.get_output_refs() if output_refs: if len(output_refs) == 1: output_types = "at::Tensor" else: output_types = "std::vector" else: output_types = "void" inputs_types = "std::vector" self.wrapper_call.writeline( f"{output_types} call_{self._call_func_id}({inputs_types} args) {{" ) if inputs_len != 0: inputs_keys_str = ", ".join(V.graph.graph_inputs.keys()) self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};") for idx, input_key in enumerate(V.graph.graph_inputs.keys()): self.wrapper_call.writeline(f"{input_key} = args[{idx}];") for name in V.graph.randomness_seeds: self.wrapper_call.writeline(f"at::Tensor {name};") self.wrapper_call.writeline( f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);" ) V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs) def write_allocate_line(self, buffer): self.writeline(CppAllocateLine(buffer)) def write_del_line(self, name): self.writeline(f"{name}.reset();") return def write_free_if_not_reused_line(self, buffer): self.writeline(CppFreeIfNotReusedLine(buffer)) return def write_reuse_line(self, input_buffer, output_buffer): self.writeline(CppReuseLine(input_buffer, output_buffer)) def get_deferred_line(self, name, layout): return DeferredLine( name, f"auto {name} = {layout.view.codegen_reference()}; // alias" ) def get_kernel_path(self, code): from ..codecache import pick_vec_isa picked_vec_isa = pick_vec_isa() ext = "so" extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa) # \n is required to match with the CodeCache behavior # For reductions, the code string gotten from code.getvalue() will use backslash '\' # at the end of lines for readability purpose: # #pragma omp declare reduction(xxx :\ # omp_out.value = xxx,\ # While the code string loaded during the execution will escape the backslash '\': # #pragma omp declare reduction(xxx : omp_out.value = xxx, # Use code.getrawvalue() here to escape the backslash to # make sure the same code string is used during compilation and execution, # so that the hash value is the same. source_code = "\n" + code.getrawvalue() _, _, kernel_path = get_code_path(source_code, ext, extra) return kernel_path def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None): kernel_path = self.get_kernel_path(kernel) self.writeline( f'static auto {name} = load_cpp_kernel("{kernel_path}");' ) def wrap_kernel_call(self, name, call_args): return "{}({});".format(name, ", ".join(call_args)) def generate_return(self, output_refs): if output_refs: if len(output_refs) == 1: self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )") else: self.wrapper_call.writeline( "return std::vector({" + ", ".join(output_refs) + "}); }''' )" ) else: self.wrapper_call.writeline("return; }''' )") def generate_end(self, result): shared = codecache.get_shared() warning_all_flag = codecache.get_warning_all_flag() cpp_flags = codecache.cpp_flags() ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths() optimization_flags = codecache.optimization_flags() use_custom_generated_macros = codecache.use_custom_generated_macros() extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}" extra_ldflags = f"{shared} {lpaths} {libs}" extra_include_paths = f"{ipaths}" # get the hash of the wrapper code to name the extension wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue()) result.splice( f""" module = load_inline( name='inline_extension_{wrapper_call_hash}', cpp_sources=[wrapper], functions=['call_{self._call_func_id}'], extra_cflags=['{extra_cflags}'], extra_ldflags=['{extra_ldflags}'], extra_include_paths=['{extra_include_paths}']) """ ) # Wrap the func to support setting result._boxed_call = True result.splice( f""" def _wrap_func(f): def g(args): return f(args) return g call = _wrap_func(module.call_{self._call_func_id}) """ ) def generate_extern_kernel_out( self, output_view, codegen_reference, args, kernel, cpp_kernel ): if output_view: output_as_strided = f"{output_view.codegen_reference()}" output_name = f"{output_view.get_name()}_as_strided" self.writeline(f"auto {output_name} = {output_as_strided};") args.insert(0, output_name) else: args.insert(0, f"{codegen_reference}") self.writeline(f"{cpp_kernel}({', '.join(args)});")