wrapper.py

import collections
import contextlib
import dataclasses
import functools
import hashlib
from itertools import count
from typing import Any, Dict, List

from torch._dynamo.utils import dynamo_timed

from .. import codecache, config, ir
from ..codecache import cpp_compile_command, get_code_path
from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter

pexpr = PythonPrinter().doprint


def buffer_reuse_key(node: ir.Buffer):
    size = node.get_size()
    stride = node.get_stride()
    last_element = sympy_dot([s - 1 for s in size], stride)
    return (
        node.get_device(),
        node.get_dtype(),
        V.graph.sizevars.simplify(sympy_product(size)),
        # Detect gaps in tensor storage caused by strides
        V.graph.sizevars.size_hint(last_element),
    )
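

# Worked example (hypothetical shapes): a (4, 4) buffer with stride (8, 1) places its
# last element at offset 3*8 + 3*1 = 27, while a contiguous (4, 4) buffer with stride
# (4, 1) ends at offset 15. Both hold 16 elements, but their size_hint(last_element)
# entries differ, so buffer_reuse_key() keeps them in separate reuse pools and a padded
# buffer is never handed out in place of a contiguous one.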


def make_buffer_reuse(old, new, del_func, declare, ending, as_strided):
    assert old.get_dtype() == new.get_dtype()
    del_line = ""
    if old.get_name() not in V.graph.get_output_names():
        del_line = del_func(old.get_name())

    if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
        return f"{declare}{new.get_name()} = {old.get_name()}{del_line}{ending}"

    return (
        f"{declare}{new.get_name()} = {as_strided}({old.get_name()}, "
        f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, "
        f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}{ending}"
    )


def make_buffer_allocation(buffer):
    device = buffer.get_device()
    dtype = buffer.get_dtype()
    shape = tuple(buffer.get_size())
    stride = tuple(buffer.get_stride())
    return (
        f"{buffer.get_name()} = empty_strided("
        f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
        f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
        f"device='{device.type}', dtype={dtype})"
    )


def make_cpp_buffer_allocation(buffer):
    from .cpp import DTYPE_TO_ATEN

    # TODO: map layout and device here
    dtype = buffer.get_dtype()
    shape = tuple(buffer.get_size())
    stride = tuple(buffer.get_stride())
    return (
        f"auto {buffer.get_name()} = at::empty_strided("
        f"{V.graph.sizevars.codegen_shape_tuple(shape)}, "
        f"{V.graph.sizevars.codegen_shape_tuple(stride)}, "
        f"{DTYPE_TO_ATEN[dtype]}); "
    )
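

# For illustration (hypothetical buffer name and shape): for a contiguous float32 CUDA
# buffer "buf0" of shape (8, 16), make_buffer_allocation() emits a wrapper line roughly
# like
#     buf0 = empty_strided((8, 16), (16, 1), device='cuda', dtype=torch.float32)
# and make_cpp_buffer_allocation() emits the analogous at::empty_strided(...) call for
# the C++ wrapper.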


class MemoryPlanningState:
    """Pool of freed-but-not-yet-reused buffers, keyed by buffer_reuse_key()."""

    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            Any, List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)

    def __contains__(self, key):
        return bool(self.reuse_pool.get(key, None))

    def pop(self, key) -> "FreeIfNotReusedLine":
        item = self.reuse_pool[key].pop()
        assert not item.is_reused
        return item

    def push(self, key, item: "FreeIfNotReusedLine"):
        assert not item.is_reused
        self.reuse_pool[key].append(item)


@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
    device_idx: int

    def codegen(self, code: IndentedBuffer):
        # Note _DeviceGuard has less overhead than device, but only accepts
        # integers
        code.writeline(f"with torch.cuda._DeviceGuard({self.device_idx}):")


class ExitCudaDeviceContextManagerLine:
    pass


class MemoryPlanningLine:
    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse"""
        return self

    def codegen(self, code: IndentedBuffer):
        """Second pass to output code"""
        pass
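

# Memory planning is a two-pass protocol over these line objects: plan() runs first and
# may rewrite a line (an AllocateLine whose reuse key matches a freed buffer becomes a
# ReuseLine, and lines that touch removed buffers collapse to NullLine), then codegen()
# emits the final wrapper text for whatever line survived planning.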


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return ReuseLine(free_line.node, self.node)

        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_buffer_allocation(self.node))


@dataclasses.dataclass
class CppAllocateLine(AllocateLine):
    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        key = buffer_reuse_key(self.node)
        if key in state:
            free_line = state.pop(key)
            free_line.is_reused = True
            return CppReuseLine(free_line.node, self.node)

        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(make_cpp_buffer_allocation(self.node))


@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState):
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()
        state.push(buffer_reuse_key(self.node), self)
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(f"del {self.node.get_name()}")


@dataclasses.dataclass
class CppFreeIfNotReusedLine(FreeIfNotReusedLine):
    node: ir.Buffer
    is_reused: bool = False

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        if not self.is_reused:
            code.writeline(f"{self.node.get_name()}.reset();")


@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    node: ir.Buffer
    reused_as: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            make_buffer_reuse(
                self.node,
                self.reused_as,
                del_func=lambda name: f"; del {name}",
                declare="",
                ending="",
                as_strided="as_strided",
            )
            + " # reuse"
        )


@dataclasses.dataclass
class CppReuseLine(ReuseLine):
    node: ir.Buffer
    reused_as: ir.Buffer

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        code.writeline(
            make_buffer_reuse(
                self.node,
                self.reused_as,
                del_func=lambda name: f"; {name}.reset()",
                declare="auto ",
                ending=";",
                as_strided="at::as_strided",
            )
            + " // reuse"
        )


@dataclasses.dataclass
class FreeLine(MemoryPlanningLine):
    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()
        return self

    def codegen(self, code: IndentedBuffer):
        assert self.node.get_name() not in V.graph.removed_buffers
        code.writeline(f"del {self.node.get_name()}")


class NullLine(MemoryPlanningLine):
    pass
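

# Rough sketch of how planning rewrites the recorded lines (hypothetical buffer names):
# when a FreeIfNotReusedLine for buf0 is planned before an AllocateLine with a matching
# reuse key, the free is marked is_reused (so it emits nothing) and the allocation is
# replaced by a ReuseLine, so instead of
#     del buf0
#     buf1 = empty_strided(...)
# the generated wrapper contains
#     buf1 = buf0; del buf0 # reuse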


class WrapperCodeGen(CodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    def __init__(self):
        super().__init__()
        self._names_iter = count()
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        self.kernels = {}
        self.lines = []
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels
                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                async_compile = AsyncCompile()
            """
        )
        if has_triton():
            self.header.splice(
                """
                import triton
                import triton.language as tl
                from torch._inductor.triton_ops.autotune import grid
                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                """
            )
        self.write_prefix()

        for name, value in V.graph.constants.items():
            # include a hash so our code cache gives different constants different files
            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
            self.header.writeline(f"{name} = None # {hashed}")

        self.allocated = set()
        self.freed = set()
        # maps from reusing buffer to reused buffer
        self.reuses = dict()

        self.write_get_cuda_stream = functools.lru_cache(None)(
            self.write_get_cuda_stream
        )

        @functools.lru_cache(None)
        def add_import_once(line):
            self.header.writeline(line)

        self.add_import_once = add_import_once
        self._metas = {}

    def add_meta_once(self, meta):
        meta = repr(meta)
        if meta not in self._metas:
            var = f"meta{len(self._metas)}"
            self._metas[meta] = var
            self.header.writeline(f"{var} = {meta}")
        return self._metas[meta]

    @cache_on_self
    def get_output_refs(self):
        return [x.codegen_reference() for x in V.graph.graph_outputs]

    def write_prefix(self):
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile
            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline("torch.cuda.synchronize()")
            inp_len = len(V.graph.graph_inputs.keys())
            if inp_len != 0:
                lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}"
                self.prefix.writeline(f"{lhs} = args")
                self.prefix.writeline("args.clear()")
            for name in V.graph.randomness_seeds:
                self.prefix.writeline(
                    f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})"
                )
            V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs)

    def append_precomputed_sizes_to_prefix(self):
        with self.prefix.indent():
            V.graph.sizevars.codegen_precomputed_sizes(self.prefix)

    def write_get_cuda_stream(self, index):
        name = f"stream{index}"
        self.writeline(f"{name} = get_cuda_stream({index})")
        return name

    def next_kernel_suffix(self):
        return f"{next(self._names_iter)}"

    def write_allocate_line(self, buffer):
        self.writeline(AllocateLine(buffer))

    def get_deferred_line(self, name, layout):
        return DeferredLine(
            name, f"{name} = {layout.view.codegen_reference()} # alias"
        )

    def codegen_allocation(self, buffer):
        name = buffer.get_name()
        if name in V.graph.removed_buffers or name in self.allocated:
            return
        self.allocated.add(name)

        if isinstance(
            buffer,
            (ir.ExternKernelAlloc, ir.MultiOutput),
        ):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            return
        if isinstance(layout, ir.AliasedLayout):
            assert isinstance(layout.view, ir.ReinterpretView)
            if not layout.maybe_guard_aligned():
                V.graph.unaligned_buffers.add(name)
            self.codegen_allocation(layout.view.data)
            allocation = self.get_deferred_line(name, layout)
            self.writeline(allocation)
            return

        self.write_allocate_line(buffer)

    def write_del_line(self, name):
        self.writeline(f"del {name}")

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(FreeIfNotReusedLine(buffer))

    def codegen_free(self, buffer):
        name = buffer.get_name()

        # can be freed but not reused
        if isinstance(buffer, ir.InputBuffer):
            self.write_del_line(name)
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        layout = buffer.get_layout()
        if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)):
            self.write_del_line(name)
            return

        self.write_free_if_not_reused_line(buffer)

    def can_reuse(self, buffer):
        name = buffer.get_name()
        if (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in self.freed
        ):
            return False
        return True

    def did_reuse(self, buffer, reused_buffer):
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        return (
            buffer.get_name() in self.reuses
            and self.reuses[buffer.get_name()] == reused_buffer.get_name()
        )

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(ReuseLine(input_buffer, output_buffer))

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        self.freed.add(input_buffer.get_name())
        self.allocated.add(output_buffer.get_name())
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.write_reuse_line(input_buffer, output_buffer)

    def codegen_cuda_device_guard_enter(self, device_idx):
        self.lines.append(EnterCudaDeviceContextManagerLine(device_idx))

    def codegen_cuda_device_guard_exit(self):
        self.lines.append(ExitCudaDeviceContextManagerLine())

    def generate_return(self, output_refs):
        if output_refs:
            self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )")
        else:
            self.wrapper_call.writeline("return ()")

    def generate_end(self, result):
        return

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        if output_view:
            args.append(f"out={output_view.codegen_reference()}")
        else:
            args.append(f"out={codegen_reference}")
        self.writeline(f"{kernel}({', '.join(args)})")

    @dynamo_timed
    def generate(self):
        result = IndentedBuffer()
        result.splice(self.header)

        out_names = V.graph.get_output_names()
        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
            if config.profiler_mark_wrapper_call:
                self.wrapper_call.writeline(
                    "from torch.profiler import record_function"
                )
                self.wrapper_call.writeline(
                    "with record_function('inductor_wrapper_call'):"
                )
                stack.enter_context(self.wrapper_call.indent())

            while (
                self.lines
                and isinstance(self.lines[-1], MemoryPlanningLine)
                and self.lines[-1].node.name not in out_names
            ):
                # these lines will be pointless
                self.lines.pop()

            # codegen allocations in two passes
            planning_state = MemoryPlanningState()
            for i in range(len(self.lines)):
                if isinstance(self.lines[i], MemoryPlanningLine):
                    self.lines[i] = self.lines[i].plan(planning_state)

            device_cm_stack = contextlib.ExitStack()
            for line in self.lines:
                if isinstance(line, MemoryPlanningLine):
                    line.codegen(self.wrapper_call)
                elif isinstance(line, EnterCudaDeviceContextManagerLine):
                    line.codegen(self.wrapper_call)
                    device_cm_stack.enter_context(self.wrapper_call.indent())
                    self.wrapper_call.writeline(
                        f"torch.cuda.set_device({line.device_idx}) # no-op to ensure context"
                    )
                elif isinstance(line, ExitCudaDeviceContextManagerLine):
                    device_cm_stack.close()
                else:
                    self.wrapper_call.writeline(line)

            output_refs = self.get_output_refs()
            if config.triton.debug_sync_graph:
                self.wrapper_call.writeline("torch.cuda.synchronize()")
            self.generate_return(output_refs)

        self.append_precomputed_sizes_to_prefix()
        result.splice(self.prefix)

        with result.indent():
            result.splice(self.wrapper_call)

        self.generate_end(result)
        self.add_benchmark_harness(result)

        return result.getvalue()

    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging
        """
        if not config.benchmark_harness:
            return

        def add_fake_input(name, shape, stride, device, dtype):
            output.writeline(
                f"{name} = rand_strided("
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, "
                f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, "
                f"device='{device}', dtype={dtype})"
            )

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )
            for name, value in V.graph.constants.items():
                add_fake_input(
                    name, value.size(), value.stride(), value.device, value.dtype
                )
            for name, value in V.graph.graph_inputs.items():
                shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
                stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
                add_fake_input(
                    name, shape, stride, value.get_device(), value.get_dtype()
                )
            output.writeline(
                f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
            )

    def define_kernel(self, name: str, kernel: str):
        self.header.splice(f"\n\n{name} = {kernel}")

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        return

    def wrap_kernel_call(self, name, call_args):
        return "{}({})".format(name, ", ".join(call_args))

    def generate_kernel_call(self, name, call_args):
        self.writeline(
            self.wrap_kernel_call(name, call_args),
        )

    def call_kernel(self, name: str, kernel: Kernel):
        tmp = IndentedBuffer()
        kernel.call_kernel(self, tmp, name)
        for line in tmp.getvalue().split("\n"):
            line = line.strip()
            if line:
                self.writeline(line)

    def writeline(self, line):
        self.lines.append(line)
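

# For orientation, the Python wrapper that generate() assembles has roughly this shape
# (an illustrative sketch with hypothetical buffer, kernel, and argument names):
#
#     from ctypes import c_void_p, c_long
#     import torch
#     ...
#     async_compile = AsyncCompile()
#     # kernel definitions added via define_kernel() go here
#     async_compile.wait(globals())
#     del async_compile
#
#     def call(args):
#         arg0_1, = args
#         args.clear()
#         buf0 = empty_strided((8, 16), (16, 1), device='cuda', dtype=torch.float32)
#         stream0 = get_cuda_stream(0)
#         # kernel calls recorded via call_kernel() / generate_kernel_call() go here
#         return (buf0, )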


class CppWrapperCodeGen(WrapperCodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    call_func_id = count()

    def __init__(self):
        self._call_func_id = next(CppWrapperCodeGen.call_func_id)
        super().__init__()

    @cache_on_self
    def get_output_refs(self):
        def has_cpp_codegen_func(x):
            return hasattr(x, "cpp_wrapper_codegen_reference") and callable(
                x.cpp_wrapper_codegen_reference
            )

        return [
            x.cpp_wrapper_codegen_reference()
            if has_cpp_codegen_func(x)
            else x.codegen_reference()
            for x in V.graph.graph_outputs
        ]

    def write_prefix(self):
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile
            from torch.utils.cpp_extension import load_inline
            wrapper = (
            '''
            #include <dlfcn.h>
            #include <assert.h>
            template <typename KernelFunc>
            KernelFunc load_cpp_kernel(const char* so_filename) {
                KernelFunc kernel_cpp;
                auto kernel_cpp_lib = dlopen(so_filename, RTLD_NOW);
                assert(kernel_cpp_lib != nullptr);
                *(void **) (&kernel_cpp) = dlsym(kernel_cpp_lib, "kernel");
                return kernel_cpp;
            }
            """
        )
        with self.wrapper_call.indent():
            inputs_len = len(V.graph.graph_inputs.keys())
            output_refs = self.get_output_refs()
            if output_refs:
                if len(output_refs) == 1:
                    output_types = "at::Tensor"
                else:
                    output_types = "std::vector<at::Tensor>"
            else:
                output_types = "void"
            inputs_types = "std::vector<at::Tensor>"
            self.wrapper_call.writeline(
                f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
            )
            if inputs_len != 0:
                inputs_keys_str = ", ".join(V.graph.graph_inputs.keys())
                self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};")
                for idx, input_key in enumerate(V.graph.graph_inputs.keys()):
                    self.wrapper_call.writeline(f"{input_key} = args[{idx}];")
            for name in V.graph.randomness_seeds:
                self.wrapper_call.writeline(f"at::Tensor {name};")
                self.wrapper_call.writeline(
                    f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
                )
            V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs)

    def write_allocate_line(self, buffer):
        self.writeline(CppAllocateLine(buffer))

    def write_del_line(self, name):
        self.writeline(f"{name}.reset();")
        return

    def write_free_if_not_reused_line(self, buffer):
        self.writeline(CppFreeIfNotReusedLine(buffer))
        return

    def write_reuse_line(self, input_buffer, output_buffer):
        self.writeline(CppReuseLine(input_buffer, output_buffer))

    def get_deferred_line(self, name, layout):
        return DeferredLine(
            name, f"auto {name} = {layout.view.codegen_reference()}; // alias"
        )

    def get_kernel_path(self, code):
        from ..codecache import pick_vec_isa

        picked_vec_isa = pick_vec_isa()
        ext = "so"
        extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa)
        # The leading \n is required to match the CodeCache behavior.
        # For reductions, the code string returned by code.getvalue() uses a backslash '\'
        # at the end of lines for readability, e.g.:
        #     #pragma omp declare reduction(xxx :\
        #         omp_out.value = xxx,\
        # while the code string loaded during execution has the backslash escaped:
        #     #pragma omp declare reduction(xxx :    omp_out.value = xxx,
        # Use code.getrawvalue() here to escape the backslash so that the same code string
        # is used during compilation and execution, and the hash values match.
        source_code = "\n" + code.getrawvalue()
        _, _, kernel_path = get_code_path(source_code, ext, extra)
        return kernel_path

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        kernel_path = self.get_kernel_path(kernel)
        self.writeline(
            f'static auto {name} = load_cpp_kernel<void (*)({arg_types})>("{kernel_path}");'
        )

    def wrap_kernel_call(self, name, call_args):
        return "{}({});".format(name, ", ".join(call_args))

    def generate_return(self, output_refs):
        if output_refs:
            if len(output_refs) == 1:
                self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
            else:
                self.wrapper_call.writeline(
                    "return std::vector<at::Tensor>({"
                    + ", ".join(output_refs)
                    + "}); }''' )"
                )
        else:
            self.wrapper_call.writeline("return; }''' )")

    def generate_end(self, result):
        shared = codecache.get_shared()
        warning_all_flag = codecache.get_warning_all_flag()
        cpp_flags = codecache.cpp_flags()
        ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths()
        optimization_flags = codecache.optimization_flags()
        use_custom_generated_macros = codecache.use_custom_generated_macros()

        extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
        extra_ldflags = f"{shared} {lpaths} {libs}"
        extra_include_paths = f"{ipaths}"

        # get the hash of the wrapper code to name the extension
        wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
        result.splice(
            f"""
            module = load_inline(
                name='inline_extension_{wrapper_call_hash}',
                cpp_sources=[wrapper],
                functions=['call_{self._call_func_id}'],
                extra_cflags=['{extra_cflags}'],
                extra_ldflags=['{extra_ldflags}'],
                extra_include_paths=['{extra_include_paths}'])
            """
        )

        # Wrap the func to support setting result._boxed_call = True
        result.splice(
            f"""
            def _wrap_func(f):
                def g(args):
                    return f(args)
                return g
            call = _wrap_func(module.call_{self._call_func_id})
            """
        )

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        if output_view:
            output_as_strided = f"{output_view.codegen_reference()}"
            output_name = f"{output_view.get_name()}_as_strided"
            self.writeline(f"auto {output_name} = {output_as_strided};")
            args.insert(0, output_name)
        else:
            args.insert(0, f"{codegen_reference}")
        self.writeline(f"{cpp_kernel}({', '.join(args)});")
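

# End to end, CppWrapperCodeGen emits a Python module whose `wrapper` string holds the
# C++ source of call_<id>(std::vector<at::Tensor> args); generate_end() then compiles
# it with torch.utils.cpp_extension.load_inline and rebinds `call` to the compiled
# entry point, so callers invoke the C++ wrapper exactly as they would the default
# Python one.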