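# This module generates the "wrapper" code that surrounds the compiled
# kernels: WrapperCodeGen emits the Python entry point (buffer allocation,
# memory reuse planning, kernel calls), and CppWrapperCodeGen is an
# experimental variant that emits a C++ entry point loaded via load_inline.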
import collections
import contextlib
import dataclasses
import functools
import hashlib
from itertools import count
from typing import Any, Dict, List, Optional

from torch._dynamo.utils import dynamo_timed

from .. import codecache, config, ir
from ..codecache import cpp_compile_command, get_code_path
from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter

pexpr = PythonPrinter().doprint

def buffer_reuse_key(node: ir.Buffer):
    """Compute the key used to match freed buffers against new allocations."""
    size = node.get_size()
    stride = node.get_stride()
    last_element = sympy_dot([s - 1 for s in size], stride)
    return (
        node.get_device(),
        node.get_dtype(),
        V.graph.sizevars.simplify(sympy_product(size)),
        # detect gaps in tensor storage caused by strides
        V.graph.sizevars.size_hint(last_element),
    )
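
# Illustrative key (hypothetical values, not from the original file): for a
# contiguous float32 CUDA tensor of size (4, 8) and stride (8, 1) the key is
# roughly
#     (device(type='cuda', index=0), torch.float32, 32, 31)
# where 31 = sympy_dot((3, 7), (8, 1)) is the flat offset of the last element.
# Including that offset keeps buffers whose strides leave gaps in storage from
# matching densely packed buffers with the same element count.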


def make_buffer_reuse(old, new, del_func, declare, ending, as_strided):
    assert old.get_dtype() == new.get_dtype()
    del_line = ""
    if old.get_name() not in V.graph.get_output_names():
        del_line = del_func(old.get_name())

    if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
        return f"{declare}{new.get_name()} = {old.get_name()}{del_line}{ending}"

    return (
        f"{declare}{new.get_name()} = {as_strided}({old.get_name()}, "
        f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, "
        f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}{ending}"
    )


def make_buffer_allocation(buffer):
    device = buffer.get_device()
    dtype = buffer.get_dtype()
    shape = tuple(buffer.get_size())
    stride = tuple(buffer.get_stride())
    return (
        f"{buffer.get_name()} = empty_strided("
        f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
        f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
        f"device='{device.type}', dtype={dtype})"
    )
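
# Example of an emitted allocation line (hypothetical buffer name "buf0"):
#     buf0 = empty_strided((4, 8), (8, 1), device='cuda', dtype=torch.float32)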


def make_cpp_buffer_allocation(buffer):
    from .cpp import DTYPE_TO_ATEN

    # TODO: map layout and device here
    dtype = buffer.get_dtype()
    shape = tuple(buffer.get_size())
    stride = tuple(buffer.get_stride())
    return (
        f"auto {buffer.get_name()} = at::empty_strided("
        f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
        f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
        f"{DTYPE_TO_ATEN[dtype]}); "
    )
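
# Rough shape of the emitted C++ counterpart (sizes/strides rendered by
# codegen_shape_tuple, dtype mapped through DTYPE_TO_ATEN, e.g.
# torch.float32 -> at::kFloat):
#     auto buf0 = at::empty_strided(<sizes>, <strides>, at::kFloat);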


class MemoryPlanningState:
    """Pool of freed buffers, keyed by buffer_reuse_key, available for reuse."""

    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            Any, List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)

    def __contains__(self, key):
        return bool(self.reuse_pool.get(key, None))

    def pop(self, key) -> "FreeIfNotReusedLine":
        item = self.reuse_pool[key].pop()
        assert not item.is_reused
        return item

    def push(self, key, item: "FreeIfNotReusedLine"):
        assert not item.is_reused
        self.reuse_pool[key].append(item)
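
# Illustrative use of the pool (hypothetical buffers buf0/buf1): freeing buf0
# pushes its FreeIfNotReusedLine under buffer_reuse_key(buf0); a later
# allocation whose key matches pops it instead of allocating fresh storage:
#
#     state.push(buffer_reuse_key(buf0), FreeIfNotReusedLine(buf0))
#     if buffer_reuse_key(buf1) in state:  # same device/dtype/numel/last offset
#         free_line = state.pop(buffer_reuse_key(buf1))
#
# list.pop() gives LIFO order, so the most recently freed buffer is preferred.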


@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
    device_idx: int

    def codegen(self, code: IndentedBuffer):
        # Note: _DeviceGuard has less overhead than torch.cuda.device, but it
        # only accepts integer device indices
        code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):")


class ExitCudaDeviceContextManagerLine:
    pass


class MemoryPlanningLine:
    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse"""
        return self

    def codegen(self, code: IndentedBuffer):
        """Second pass to output code"""
        pass
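
# How the two passes are driven (see WrapperCodeGen.generate() below): pass 1
# maps plan() over every MemoryPlanningLine, threading a single
# MemoryPlanningState through, so an AllocateLine whose reuse key matches a
# pending FreeIfNotReusedLine collapses into a ReuseLine; pass 2 calls
# codegen() on the surviving lines to emit the actual wrapper text.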


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return ReuseLine(free_line.node, self.node)

        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_buffer_allocation(self.node))


@dataclasses.dataclass
class CppAllocateLine(AllocateLine):
    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return CppReuseLine(free_line.node, self.node)

        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_cpp_buffer_allocation(self.node))


@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState):
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()
        state.push(buffer_reuse_key(self.node), self)
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(f"del {self.node.get_name()}")


@dataclasses.dataclass
class CppFreeIfNotReusedLine(FreeIfNotReusedLine):
    node: ir.Buffer
    is_reused: bool = False

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(f"{self.node.get_name()}.reset();")


@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    node: ir.Buffer
    reused_as: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            make_buffer_reuse(
                self.node,
                self.reused_as,
                del_func=lambda name: f"; del {name}",
                declare="",
                ending="",
                as_strided="as_strided",
            )
            + " # reuse"
        )
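
# Example emitted reuse (hypothetical buffer names): when sizes and strides
# match exactly,
#     buf3 = buf0; del buf0  # reuse
# otherwise the old storage is reinterpreted,
#     buf3 = as_strided(buf0, (8, 4), (4, 1)); del buf0  # reuse
# (the del is omitted when the old buffer is a graph output).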


@dataclasses.dataclass
class CppReuseLine(ReuseLine):
    node: ir.Buffer
    reused_as: ir.Buffer

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            make_buffer_reuse(
                self.node,
                self.reused_as,
                del_func=lambda name: f"; {name}.reset()",
                declare="auto ",
                ending=";",
                as_strided="at::as_strided",
            )
            + " // reuse"
        )


@dataclasses.dataclass
class FreeLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(f"del {self.node.get_name()}")


class NullLine(MemoryPlanningLine):
    pass


class WrapperCodeGen(CodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    def __init__(self):
        super().__init__()
        self._names_iter = count()
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        self.kernels = {}
        self.lines = []
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels

                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                async_compile = AsyncCompile()

            """
        )

        if has_triton():
            self.header.splice(
                """
                import triton
                import triton.language as tl
                from torch._inductor.triton_ops.autotune import grid
                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                """
            )

        self.write_prefix()

        for name, value in V.graph.constants.items():
            # include a hash so our code cache gives different constants different files
            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
            self.header.writeline(f"{name} = None # {hashed}")

        self.allocated = set()
        self.freed = set()
        # maps from reusing buffer to reused buffer
        self.reuses = dict()

        self.write_get_cuda_stream = functools.lru_cache(None)(
            self.write_get_cuda_stream
        )

        @functools.lru_cache(None)
        def add_import_once(line):
            self.header.writeline(line)

        self.add_import_once = add_import_once
        self._metas = {}

    def add_meta_once(self, meta):
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
        return self._metas[meta]

    @cache_on_self
    def get_output_refs(self):
        return [x.codegen_reference() for x in V.graph.graph_outputs]

    def write_prefix(self):
        self.prefix.splice(
            """

            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline("torch.cuda.synchronize()")
            inp_len = len(V.graph.graph_inputs.keys())
            if inp_len != 0:
                lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
                self.prefix.writeline(f"{lhs} = args")
                self.prefix.writeline("args.clear()")
            for name in V.graph.randomness_seeds:
                self.prefix.writeline(
                    f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
                )
            V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs)

    def append_precomputed_sizes_to_prefix(self):
        with self.prefix.indent():
            V.graph.sizevars.codegen_precomputed_sizes(self.prefix)

    def write_get_cuda_stream(self, index):
        name = f"stream{index}"
        self.writeline(f"{name} = get_cuda_stream({index})")
        return name
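
    # Example emitted line (deduplicated per index by the functools.lru_cache
    # wrapping applied in __init__):
    #     stream0 = get_cuda_stream(0)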

    def next_kernel_suffix(self):
        return f"{next(self._names_iter)}"

    def write_allocate_line(self, buffer):
        self.writeline(AllocateLine(buffer))

    def get_deferred_line(self, name, layout):
        return DeferredLine(
            name, f"{name} = {layout.view.codegen_reference()} # alias"
        )

    def codegen_allocation(self, buffer):
        name = buffer.get_name()
        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)

        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            return
        if isinstance(layout, ir.AliasedLayout):
            assert isinstance(layout.view, ir.ReinterpretView)
            if not layout.maybe_guard_aligned():
                V.graph.unaligned_buffers.add(name)
            self.codegen_allocation(layout.view.data)
            allocation = self.get_deferred_line(name, layout)
            self.writeline(allocation)
            return

        self.write_allocate_line(buffer)

    def write_del_line(self, name):
        self.writeline(f"del {name}")

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(FreeIfNotReusedLine(buffer))

    def codegen_free(self, buffer):
        name = buffer.get_name()

        # an input buffer can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.write_del_line(name)
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        layout = buffer.get_layout()
        if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
            self.write_del_line(name)
            return

        self.write_free_if_not_reused_line(buffer)

    def can_reuse(self, buffer):
        name = buffer.get_name()
        if (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in self.freed
        ):
            return False
        return True

    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(ReuseLine(input_buffer, output_buffer))

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.write_reuse_line(input_buffer, output_buffer)

    def codegen_cuda_device_guard_enter(self, device_idx):
        self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))

    def codegen_cuda_device_guard_exit(self):
        self.lines.append(ExitCudaDeviceContextManagerLine())

    def generate_return(self, output_refs):
        if output_refs:
            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
        else:
            self.wrapper_call.writeline("return ()")

    def generate_end(self, result):
        return

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        if output_view:
            args.append(f"out={output_view.codegen_reference()}")
        else:
            args.append(f"out={codegen_reference}")
        self.writeline(f"{kernel}({', '.join(args)})")

    @dynamo_timed
    def generate(self):
        result = IndentedBuffer()
        result.splice(self.header)

        out_names = V.graph.get_output_names()
        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.wrapper_call.writeline(
                    "from torch.profiler import record_function"
                )
                self.wrapper_call.writeline(
                    "with record_function('inductor_wrapper_call'):"
                )
                stack.enter_context(self.wrapper_call.indent())
            while (
                self.lines
                and isinstance(self.lines[-1], MemoryPlanningLine)
                and self.lines[-1].node.name not in out_names
            ):
                # these lines will be pointless
                self.lines.pop()

            # codegen allocations in two passes
            planning_state = MemoryPlanningState()
            for i in range(len(self.lines)):
                if isinstance(self.lines[i], MemoryPlanningLine):
                    self.lines[i] = self.lines[i].plan(planning_state)

            device_cm_stack = contextlib.ExitStack()
            for line in self.lines:
                if isinstance(line, MemoryPlanningLine):
                    line.codegen(self.wrapper_call)
                elif isinstance(line, EnterCudaDeviceContextManagerLine):
                    line.codegen(self.wrapper_call)
                    device_cm_stack.enter_context(self.wrapper_call.indent())
                    self.wrapper_call.writeline(
                        f"torch.cuda.set_device({line.device_idx}) # no-op to ensure context"
                    )
                elif isinstance(line, ExitCudaDeviceContextManagerLine):
                    device_cm_stack.close()
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline("torch.cuda.synchronize()")
            self.generate_return(output_refs)

        self.append_precomputed_sizes_to_prefix()
        result.splice(self.prefix)
        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_end(result)

        self.add_benchmark_harness(result)

        return result.getvalue()
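
    # Sketch of the assembled output (hypothetical names): header (imports and
    # compiled kernel definitions), then prefix, then the indented body of
    # wrapper_call, roughly
    #
    #     async_compile.wait(globals())
    #     del async_compile
    #
    #     def call(args):
    #         arg0_1, = args
    #         args.clear()
    #         buf0 = empty_strided((4, 8), (8, 1), device='cuda', dtype=torch.float32)
    #         ...
    #         return (buf0, )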

    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging
        """
        if not config.benchmark_harness:
            return

        def add_fake_input(name, shape, stride, device, dtype):
            output.writeline(
                f"{name} = rand_strided("
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
                f"device='{device}', dtype={dtype})"
            )

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )

            for name, value in V.graph.constants.items():
                add_fake_input(
                    name, value.size(), value.stride(), value.device, value.dtype
                )

            for name, value in V.graph.graph_inputs.items():
                shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
                stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
                add_fake_input(
                    name, shape, stride, value.get_device(), value.get_dtype()
                )

            output.writeline(
                f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
            )
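
    # Example of the appended harness (assuming a single input "arg0_1"):
    #
    #     if __name__ == "__main__":
    #         from torch._dynamo.testing import rand_strided
    #         from torch._inductor.utils import print_performance
    #         arg0_1 = rand_strided((4, 8), (8, 1), device='cuda', dtype=torch.float32)
    #         print_performance(lambda: call([arg0_1]))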

    def define_kernel(self, name: str, kernel: str):
        self.header.splice(f"\n\n{name} = {kernel}")

    def load_kernel(
        self,
        name: Optional[str] = None,
        kernel: Optional[str] = None,
        arg_types: Optional[List] = None,
    ):
        return

    def wrap_kernel_call(self, name, call_args):
        return "{}({})".format(name, ", ".join(call_args))

    def generate_kernel_call(self, name, call_args):
        self.writeline(
            self.wrap_kernel_call(name, call_args),
        )

    def call_kernel(self, name: str, kernel: Kernel):
        tmp = IndentedBuffer()
        kernel.call_kernel(self, tmp, name)
        for line in tmp.getvalue().split("\n"):
            line = line.strip()
            if line:
                self.writeline(line)

    def writeline(self, line):
        self.lines.append(line)


class CppWrapperCodeGen(WrapperCodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    call_func_id = count()

    def __init__(self):
        self._call_func_id = next(CppWrapperCodeGen.call_func_id)
        super().__init__()

    @cache_on_self
    def get_output_refs(self):
        def has_cpp_codegen_func(x):
            return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
                x.cpp_wrapper_codegen_reference
            )

        return [
            x.cpp_wrapper_codegen_reference()
            if has_cpp_codegen_func(x)
            else x.codegen_reference()
            for x in V.graph.graph_outputs
        ]

    def write_prefix(self):
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile
            from torch.utils.cpp_extension import load_inline
            wrapper = (
            '''
            #include <dlfcn.h>
            #include <assert.h>

            template <typename KernelFunc>
            KernelFunc load_cpp_kernel(const char* so_filename) {
                KernelFunc kernel_cpp;
                auto kernel_cpp_lib = dlopen(so_filename, RTLD_NOW);
                assert(kernel_cpp_lib != nullptr);
                *(void **) (&kernel_cpp) = dlsym(kernel_cpp_lib, "kernel");
                return kernel_cpp;
            }
            """
        )
        with self.wrapper_call.indent():
            inputs_len = len(V.graph.graph_inputs.keys())
            output_refs = self.get_output_refs()
            if output_refs:
                if len(output_refs) == 1:
                    output_types = "at::Tensor"
                else:
                    output_types = "std::vector<at::Tensor>"
            else:
                output_types = "void"
            inputs_types = "std::vector<at::Tensor>"

            self.wrapper_call.writeline(
                f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
            )
            if inputs_len != 0:
                inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
                self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    self.wrapper_call.writeline(f"{input_key} = args[{idx}];")

            for name in V.graph.randomness_seeds:
                self.wrapper_call.writeline(f"at::Tensor {name};")
                self.wrapper_call.writeline(
                    f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
                )
            V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs)
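
    # Rough shape of the generated C++ entry point for two inputs and one
    # output (names assumed for illustration):
    #
    #     at::Tensor call_0(std::vector<at::Tensor> args) {
    #         at::Tensor arg0_1, arg1_1;
    #         arg0_1 = args[0];
    #         arg1_1 = args[1];
    #         ...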

    def write_allocate_line(self, buffer):
        self.writeline(CppAllocateLine(buffer))

    def write_del_line(self, name):
        self.writeline(f"{name}.reset();")
        return

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(CppFreeIfNotReusedLine(buffer))
        return

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(CppReuseLine(input_buffer, output_buffer))

    def get_deferred_line(self, name, layout):
        return DeferredLine(
            name, f"auto {name} = {layout.view.codegen_reference()}; // alias"
        )

    def get_kernel_path(self, code):
        from ..codecache import pick_vec_isa

        picked_vec_isa = pick_vec_isa()
        ext = "so"
        extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa)
        # \n is required to match with the CodeCache behavior.
        # For reductions, the code string returned by code.getvalue() uses a
        # backslash '\' at the end of lines for readability, e.g.:
        #     #pragma omp declare reduction(xxx :\
        #     omp_out.value = xxx,\
        # while the code string loaded during execution has the line
        # continuations resolved:
        #     #pragma omp declare reduction(xxx : omp_out.value = xxx,
        # Use code.getrawvalue() here so that the same code string is used
        # during compilation and execution, keeping the hash value identical.
        source_code = "\n" + code.getrawvalue()
        _, _, kernel_path = get_code_path(source_code, ext, extra)
        return kernel_path

    def load_kernel(
        self,
        name: Optional[str] = None,
        kernel: Optional[str] = None,
        arg_types: Optional[List] = None,
    ):
        kernel_path = self.get_kernel_path(kernel)
        self.writeline(
            f'static auto {name} = load_cpp_kernel<void (*)({arg_types})>("{kernel_path}");'
        )

    def wrap_kernel_call(self, name, call_args):
        return "{}({});".format(name, ", ".join(call_args))

    def generate_return(self, output_refs):
        if output_refs:
            if len(output_refs) == 1:
                self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
            else:
                self.wrapper_call.writeline(
                    "return std::vector<at::Tensor>({"
                    + ", ".join(output_refs)
                    + "}); }''' )"
                )
        else:
            self.wrapper_call.writeline("return; }''' )")
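
    # Note: the "}''' )" emitted above closes three things opened in
    # write_prefix(): the C++ call_N function body ("}"), the triple-quoted
    # source string ("'''"), and the parenthesized assignment "wrapper = (".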

    def generate_end(self, result):
        shared = codecache.get_shared()
        warning_all_flag = codecache.get_warning_all_flag()
        cpp_flags = codecache.cpp_flags()
        ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths()
        optimization_flags = codecache.optimization_flags()
        use_custom_generated_macros = codecache.use_custom_generated_macros()

        extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
        extra_ldflags = f"{shared} {lpaths} {libs}"
        extra_include_paths = f"{ipaths}"

        # get the hash of the wrapper code to name the extension
        wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
        result.splice(
            f"""
            module = load_inline(
                name='inline_extension_{wrapper_call_hash}',
                cpp_sources=[wrapper],
                functions=['call_{self._call_func_id}'],
                extra_cflags=['{extra_cflags}'],
                extra_ldflags=['{extra_ldflags}'],
                extra_include_paths=['{extra_include_paths}'])
            """
        )
        # Wrap the func to support setting result._boxed_call = True
        result.splice(
            f"""
            def _wrap_func(f):
                def g(args):
                    return f(args)
                return g
            call = _wrap_func(module.call_{self._call_func_id})
            """
        )

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        if output_view:
            output_as_strided = f"{output_view.codegen_reference()}"
            output_name = f"{output_view.get_name()}_as_strided"
            self.writeline(f"auto {output_name} = {output_as_strided};")
            args.insert(0, output_name)
        else:
            args.insert(0, f"{codegen_reference}")
        self.writeline(f"{cpp_kernel}({', '.join(args)});")