gen_executorch.py

import argparse
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union

import yaml

# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
from torchgen import dest
from torchgen.api import cpp as aten_cpp
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
from torchgen.context import method_with_native_function, with_native_function_and_index
from torchgen.executorch.api import et_cpp
from torchgen.executorch.api.custom_ops import (
    ComputeNativeFunctionStub,
    gen_custom_ops_registration,
)
from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.gen import (
    get_custom_build_selector,
    get_native_function_declarations,
    get_native_function_schema_registrations,
    LineLoader,
    parse_native_yaml,
    ParsedYaml,
)
from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    is_cuda_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    Variant,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
)

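# High-level flow (see main() at the bottom of this file):
#   1. parse_yaml_files() translates the Executorch functions.yaml dialect into
#      native_functions.yaml syntax and merges it with custom_ops.yaml.
#   2. gen_headers() writes Functions.h, NativeFunctions.h and, when custom ops
#      are present, CustomOpsNativeFunctions.h.
#   3. gen_unboxing() writes RegisterCodegenUnboxedKernels.cpp, and gen_custom_ops()
#      writes the Register<DispatchKey>CustomOps.cpp, Register<DispatchKey>Stub.cpp
#      and RegisterSchema.cpp sources for custom operators.
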
def static_dispatch(
    sig: Union[CppSignature, ExecutorchCppSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the backend kernel bound to it and dispatch to it directly.
    If zero or more than one kernel is bound, error out. A simplified version of register_dispatch_key.py.
    Arguments:
        sig: The CppSignature for this native function we want to use.
        f: The NativeFunction to generate static dispatch for.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    static_block = None
    if len(backends) == 1:
        backend_metadata = backends[0].get_kernel(f)
        if backend_metadata:
            args = ", ".join(a.name for a in sig.arguments())
            # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
            static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    else:
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
"""
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
    {static_block}
}}
"""

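# Illustrative only: for a hypothetical op "custom::foo.out" whose single bound
# kernel lives in cpp_namespace "torch::executor::native" under the kernel name
# "foo_out", static_dispatch() emits roughly
#
#   // custom::foo.out(...)
#   TORCH_API inline <sig.decl()> {
#       return ::torch::executor::native::foo_out(<sig.arguments()>);
#   }
#
# With zero or more than one bound kernel, the body is an
# ET_ASSERT_UNREACHABLE_MSG(...) call describing the mismatch instead.
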
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    static_dispatch_backend_indices: List[BackendIndex]

    selector: SelectiveBuilder

    use_aten_lib: bool

    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None
        if Variant.function not in f.variants:
            return None
        sig: Union[CppSignature, ExecutorchCppSignature] = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            if self.use_aten_lib
            else ExecutorchCppSignature.from_native_function(f)
        )
        if self.use_aten_lib and not self.is_custom_op(f):
            comma = ", "
            return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""
        else:
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )

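# In ATen mode (use_aten_lib=True), ComputeFunction forwards non-custom ops straight
# to the corresponding at::<name>() call; custom ops, and all ops in non-ATen mode,
# go through static_dispatch() above.
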
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    selector: SelectiveBuilder

    use_aten_lib: bool

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return ""
        sig: Union[CppSignature, ExecutorchCppSignature]
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type

        # parse arguments into C++ code
        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(sig.arguments())

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "
        args_str = f"{arg_connector.join(e.name for e in binding_list)}"

        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
        else:
            if len(f.func.arguments.out) == 0:
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
            else:
                return_assignment = ""
                ret_prefix = ""

        return f"""
Operator(
    "{f.namespace}::{f.func.name}",
    [](EValue** stack) {{
        {code_connector.join(code_list)}
        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});
        {return_assignment}
    }}
),
"""

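# Illustrative only: for a hypothetical out-variant op "custom::foo.out" the entry
# produced above looks roughly like
#
#   Operator(
#       "custom::foo.out",
#       [](EValue** stack) {
#           <EValue -> C++ conversions from Unboxing.convert_arguments()>
#           EXECUTORCH_SCOPE_PROF("native_call_foo.out");
#           torch::executor::custom::<sig.name()>(self, out);
#           stack[<number of arguments>] = &out;
#       }
#   ),
#
# Ops that return a value instead of writing to an out argument get
# "*stack[<number of arguments>] = EValue(result_);" as the final line.
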
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> None:
    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector, use_aten_lib)(fn)],
        },
        num_shards=1,
        sharded_keys={"unboxed_ops"},
    )

@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    prefix = "static" if backend_index.external else "TORCH_API"
    return [f"{prefix} {sig.decl(name=metadata.kernel)};"]

def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
    """
    Generates namespace-separated C++ function API inline declarations/definitions.
    Native functions are grouped by namespace and the generated code is wrapped inside
    namespace blocks.

    E.g., for `custom_1::foo.out` in the yaml file we will generate a C++ API as a symbol
    in `torch::executor::custom_1::foo_out`. This avoids a symbol conflict when
    another `custom_2::foo.out` is available.
    """
    ns_grouped_functions = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_functions[native_function.namespace].append(native_function)
    functions_declarations = ""
    newline = "\n"
    for namespace in ns_grouped_functions:
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        declarations = list(
            mapMaybe(
                ComputeFunction(
                    static_dispatch_backend_indices=static_dispatch_idx,
                    selector=selector,
                    use_aten_lib=use_aten_lib,
                    is_custom_op=lambda f: custom_ops_native_functions is not None
                    and f in custom_ops_native_functions,
                ),
                ns_grouped_functions[namespace],
            )
        )
        functions_declarations += f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
"""
    return functions_declarations

def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    custom_ops_native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    aten_headers = ["#include <ATen/Functions.h>"]
    if custom_ops_native_functions:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": aten_headers
            if use_aten_lib
            else ['#include "NativeFunctions.h"'],
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                static_dispatch_idx=static_dispatch_idx,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "nativeFunctions_declarations": get_native_function_declarations(
                grouped_native_functions=native_functions,
                backend_indices=backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration
                if use_aten_lib
                else compute_native_function_declaration,
            ),
        },
    )

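# gen_headers() emits at most three headers:
#   - CustomOpsNativeFunctions.h (only when custom ops exist), rendered from the
#     NativeFunctions.h template;
#   - Functions.h, the public functional C++ API, which pulls in ATen/Functions.h
#     in ATen mode and NativeFunctions.h otherwise;
#   - NativeFunctions.h, declarations of the kernels each backend is expected to
#     implement.
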
def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    dispatch_key = DispatchKey.CPU
    backend_index = backend_indices[dispatch_key]
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        backend_index=backend_index,
        rocm=rocm,
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )

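# gen_custom_ops() emits Register<DispatchKey>CustomOps.cpp and
# Register<DispatchKey>Stub.cpp (both rendered from the
# RegisterDispatchKeyCustomOps.cpp template, and currently only for DispatchKey.CPU),
# plus RegisterSchema.cpp containing the operator schema registrations.
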
def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
    """Translates the Executorch DSL dialect to use the same syntax as
    native_functions.yaml. The major difference is that the Executorch DSL dialect
    supports an "op" key, which refers to the operator name in native_functions.yaml.

    For example, a functions.yaml may have the following entry:

        - op: add.out
          ...

    It needs to be translated to the following:

        - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
          ...

    We look up the operator schema for "add.out" in aten_yaml_path and add it to
    the original functions.yaml. We also add the required field "variants", which
    for Executorch is always "function".

    For ATen mode we don't have to do the translation because native_yaml_path is
    the same as native_functions.yaml.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
        out_file: The IO object that we are writing into.
    Returns:
        None
    """
    if use_aten_lib:
        with open(aten_yaml_path, "r") as aten_yaml:
            out_file.writelines(aten_yaml.readlines())
        return

    aten_parsed_yaml = parse_native_yaml(
        aten_yaml_path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=False,
    )
    aten_native_functions = aten_parsed_yaml.native_functions
    schema_dict = {
        f"{f.namespace}::{f.func.name}": str(f.func) for f in aten_native_functions
    }

    if (
        not native_yaml_path
        or not os.path.exists(native_yaml_path)
        or os.stat(native_yaml_path).st_size == 0
    ):
        return
    with open(native_yaml_path, "r") as native_yaml:
        native_es = yaml.load(native_yaml, Loader=LineLoader)
        if not native_es:
            return
        for e in native_es:
            assert isinstance(e.get("__line__"), int), e
            loc = Location(native_yaml_path, e.pop("__line__"))
            with context(lambda: f"in {loc}:\n "):
                if "variants" not in e:
                    e["variants"] = "function"
                if "func" in e:
                    continue
                assert isinstance(e.get("op"), str), e
                opname = e.pop("op")
                if "::" not in opname:
                    opname = "aten::" + opname
                assert opname in schema_dict
                e["func"] = schema_dict.get(opname)
        yaml.dump(native_es, out_file, width=1000)

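# Minimal usage sketch (paths are illustrative, not verified against any particular
# checkout): translate an Executorch functions.yaml into native_functions.yaml
# syntax, writing the result into an in-memory buffer.
#
#   import io
#   buf = io.StringIO()
#   translate_native_yaml(
#       tags_yaml_path="aten/src/ATen/native/tags.yaml",
#       aten_yaml_path="aten/src/ATen/native/native_functions.yaml",
#       native_yaml_path="functions.yaml",
#       use_aten_lib=False,
#       out_file=buf,
#   )
#   # buf.getvalue() now holds the entries rewritten with full "func:" schemas and
#   # a "variants: function" field.
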
def convert_backend_indices(
    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> Dict[DispatchKey, BackendIndex]:
    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            index={},
        )
    )
    for k, v in bs.items():
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k),
            index=v,
        )
    return indices

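# Note: because `indices` is a defaultdict, looking up a dispatch key that was never
# present in `bs` yields an empty BackendIndex keyed to DispatchKey.Undefined rather
# than raising KeyError.
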
def parse_yaml(
    path: Optional[str],
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> Tuple[
    List[NativeFunction], Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
]:
    if path and os.path.exists(path) and os.stat(path).st_size > 0:
        parsed_yaml = parse_native_yaml(
            path,
            tags_yaml_path,
            None,
            skip_native_fns_gen=skip_native_fns_gen,
        )
        native_functions = list(filter(function_filter, parsed_yaml.native_functions))
        op_names = [f.func.name for f in native_functions]

        def map_index(
            m: Dict[OperatorName, BackendMetadata]
        ) -> Dict[OperatorName, BackendMetadata]:
            return {op: m[op] for op in m if op in op_names}

        backend_indices = {
            k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items()
        }
        return native_functions, backend_indices
    else:
        return [], {}

def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    custom_ops_yaml_path: Optional[str],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> Tuple[ParsedYaml, Optional[ParsedYaml]]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
            the path does not exist in the filesystem, it is ignored.
        selector: For selective build.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
    Returns:
        A tuple with two elements:
        [0]: The parsed results of concatenating the contents of
             `native_yaml_path` and `custom_ops_yaml_path`.
        [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
             present. If not present, None.
    """
    import tempfile

    # Only include selected ops: this avoids generating code for operators that are
    # not part of the selective build.
    def function_filter(f: NativeFunction) -> bool:
        return selector.is_native_function_selected(f)

    with tempfile.TemporaryDirectory() as tmpdirname:
        translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
        with open(translated_yaml_path, "w") as translated:
            translate_native_yaml(
                tags_yaml_path,
                aten_yaml_path,
                native_yaml_path,
                use_aten_lib,
                translated,
            )
        translated_functions, translated_backend_indices = parse_yaml(
            translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
        )
        custom_ops_functions, custom_ops_backend_indices = parse_yaml(
            custom_ops_yaml_path, tags_yaml_path, function_filter, True
        )

        combined_functions = translated_functions + custom_ops_functions
        combined_backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        combined_backend_indices.update(translated_backend_indices)
        for dk in custom_ops_backend_indices:
            if dk not in combined_backend_indices:
                combined_backend_indices.update({dk: custom_ops_backend_indices[dk]})
            else:
                combined_backend_indices[dk] = {
                    **combined_backend_indices[dk],
                    **custom_ops_backend_indices[dk],
                }

        combined_yaml = ParsedYaml(
            combined_functions, convert_backend_indices(combined_backend_indices)
        )
        custom_ops_parsed_yaml = ParsedYaml(
            custom_ops_functions, convert_backend_indices(custom_ops_backend_indices)
        )

    return combined_yaml, custom_ops_parsed_yaml

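# Note on the merge in parse_yaml_files(): when functions.yaml and custom_ops.yaml
# both provide a kernel for the same dispatch key and operator, the custom_ops.yaml
# entry wins, since it is spread last into combined_backend_indices[dk].
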
def main() -> None:
    parser = argparse.ArgumentParser(description="Generate operator source files")
    # Although we don't refer to --source-path directly, make_file_manager()
    # expects it to point to a directory that contains a templates/ subdirectory
    # containing the file templates.
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for kernel templates",
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--custom-ops-yaml-path",
        "--custom_ops_yaml_path",
        help="path to the custom_ops.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--aten-yaml-path",
        "--aten_yaml_path",
        help="path to native_functions.yaml file.",
    )
    # Note that make_file_manager() also looks at --install-dir.
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/generated",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    # Although we don't refer to --dry-run directly, make_file_manager() looks
    # for it.
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--use-aten-lib",
        "--use_aten_lib",
        action="store_true",
        help="a boolean flag to indicate whether we use ATen kernels or not; "
        "in the future this flag will be per operator",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources"],
        default=["headers", "sources"],
        help="Generate only a subset of files",
    )
    options = parser.parse_args()
    assert options.tags_path, "tags.yaml is required by codegen yaml parsing."

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
        aten_yaml_path=options.aten_yaml_path,
        tags_yaml_path=options.tags_path,
        native_yaml_path=options.functions_yaml_path,
        custom_ops_yaml_path=options.custom_ops_yaml_path,
        selector=selector,
        use_aten_lib=options.use_aten_lib,
    )
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    custom_ops_native_functions = (
        custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
    )

    cpu_fm = make_file_manager(options=options)

    static_dispatch_idx: List[BackendIndex] = [backend_indices[DispatchKey.CPU]]

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            custom_ops_native_functions=custom_ops_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            use_aten_lib=options.use_aten_lib,
        )

    if "sources" in options.generate:
        gen_unboxing(
            native_functions=native_functions,
            cpu_fm=cpu_fm,
            selector=selector,
            use_aten_lib=options.use_aten_lib,
        )
        if custom_ops_native_functions:
            gen_custom_ops(
                native_functions=custom_ops_native_functions,
                selector=selector,
                backend_indices=backend_indices,
                cpu_fm=cpu_fm,
                rocm=options.rocm,
            )

    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem
        for fm, prefix in [
            (cpu_fm, ""),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))

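# Example invocation (flag values are illustrative; adjust paths to your checkout):
#
#   python gen_executorch.py \
#     --source-path=<dir containing templates/> \
#     --install-dir=build/generated \
#     --tags-path=aten/src/ATen/native/tags.yaml \
#     --aten-yaml-path=aten/src/ATen/native/native_functions.yaml \
#     --functions-yaml-path=functions.yaml \
#     --custom-ops-yaml-path=custom_ops.yaml
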
if __name__ == "__main__":
    main()