gen_lazy_tensor.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. import argparse
  2. import os
  3. import pathlib
  4. import re
  5. from collections import Counter, namedtuple
  6. from typing import (
  7. Any,
  8. Callable,
  9. Dict,
  10. Iterable,
  11. Iterator,
  12. List,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. )
  19. import yaml
  20. import torchgen.dest as dest
  21. from torchgen.api.lazy import setValueT
  22. from torchgen.api.types import BaseCppType
  23. from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
  24. from torchgen.gen import get_grouped_native_functions, parse_native_yaml
  25. from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
  26. from torchgen.selective_build.selector import SelectiveBuilder
  27. from torchgen.utils import concatMap, FileManager, NamespaceHelper, YamlLoader
  28. from .gen_backend_stubs import (
  29. error_on_missing_kernels,
  30. gen_dispatcher_registrations,
  31. gen_dispatchkey_nativefunc_headers,
  32. parse_backend_yaml,
  33. )
  34. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  35. #
  36. # Lazy Tensor Codegen
  37. #
  38. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  39. # Overview
  40. # ~~~~~~~~
  41. #
  42. # This codegen script builds on existing data models and helpers used
  43. # by all ATen backends, and adds new functionality specific to lazy
  44. # tensor backends.
  45. #
  46. # Inputs:
  47. # - <backend>_native_functions.yaml: controls which operators are
  48. # supported by the backend.
  49. #
  50. # Outputs:
  51. # (for all backends)
# LazyIr.h defines Lazy IR classes to be constructed during tracing
# - opt-in: also generate 'lowering' methods for the TorchScript backend only
# LazyNonNativeIr.h defines Lazy IR classes for the 'non_native' entries listed in the backend yaml
  54. # <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
  55. # - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
  56. # <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
  57. # ops
  58. #
  59. # Register<DispatchKey>.cpp registers all op implementations with the dispatcher
  60. # RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
  61. #
  62. # Validation Helpers:
  63. # - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
  64. # implementations in torch/csrc/lazy/core/shape_inference.*
  65. # - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
  66. # (non-codegen) implementation file
  67. #
  68. #
  69. # About the Data Model
  70. # ~~~~~~~~~~~~~~~~~~~~
  71. #
  72. # Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
  73. # we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
  74. # (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
  75. # Backends can list ops in two categories:
  76. # - `supported` ops require hand-implementations but still get codegenned declarations and registrations
  77. # - `full_codegen` ops get implementations (and IR classes) generated too
  78. #
  79. # Each native function is modeled as an object with a schema, and each schema has objects representing their
  80. # arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
  81. # backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
  82. # types (stringref) with actual string objects, and this is done by manipulating the data model objects.
  83. # - see api/lazy.py for the lazy data model
  84. #
  85. # Once the data model is set up, the rest of this script processes a number of templates for output CPP file
  86. # and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`. These
  87. # helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
  88. #
  89. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  90. # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
  91. # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
  92. ParsedExternalYaml = namedtuple(
  93. "ParsedExternalYaml",
  94. ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
  95. )
  96. def parse_native_functions_keys(
  97. backend_yaml_path: str,
  98. grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
  99. ) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
  100. native_functions_map: Dict[OperatorName, NativeFunction] = {
  101. f.func.name: f
  102. for f in concatMap(
  103. lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
  104. grouped_native_functions,
  105. )
  106. }
  107. with open(backend_yaml_path, "r") as f:
  108. yaml_values = yaml.load(f, Loader=YamlLoader)
  109. assert isinstance(yaml_values, dict)
  110. full_codegen = yaml_values.pop("full_codegen", [])
  111. non_native = yaml_values.pop("non_native", [])
  112. ir_gen = yaml_values.pop("ir_gen", [])
  113. assert isinstance(full_codegen, list)
  114. assert isinstance(non_native, list)
  115. assert isinstance(ir_gen, list)
  116. full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
  117. ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
  118. return full_codegen_opnames, non_native, ir_gen_opnames
  119. def validate_shape_inference_header(
  120. shape_inference_hdr: str, expected_shape_infr_decls: List[str]
  121. ) -> None:
  122. try:
  123. with open(shape_inference_hdr, "r") as f:
  124. shape_infr_decls = f.read()
  125. shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
  126. except IOError as e:
  127. raise AssertionError(
  128. f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
  129. ) from e
  130. shape_infr_regex = r"compute_shape_(\w+)"
  131. actual_shape_infr_name_counts = Counter(
  132. re.findall(shape_infr_regex, shape_infr_decls)
  133. )
  134. # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.
  135. missing_decls = [
  136. decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
  137. ]
  138. if missing_decls:
  139. raise Exception(
  140. f"""Missing shape inference function.\n
  141. Please add declare this function in {shape_inference_hdr}:\n
  142. and implement it in the the corresponding shape_inference.cpp file.\n
  143. {os.linesep.join(missing_decls)}"""
  144. )
  145. # Some helper functions for the codegen.
  146. def get_ltc_helper_fns() -> str:
  147. return """\
  148. at::Tensor to_meta(const at::Tensor& tensor) {
  149. // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
  150. if (!tensor.defined()) return tensor;
  151. auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
  152. /*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \
  153. /*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt);
  154. // needs to handle wrapped numbers, so dtype promotion works properly.
  155. if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
  156. out.unsafeGetTensorImpl()->set_wrapped_number(true);
  157. }
  158. return out;
  159. }
  160. c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
  161. if (tensor.has_value()) {
  162. return to_meta(*tensor);
  163. }
  164. return c10::nullopt;
  165. }
  166. std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
  167. std::vector<at::Tensor> outs;
  168. outs.reserve(t_list.size());
  169. for (const auto& tensor : t_list) {
  170. outs.push_back(to_meta(tensor));
  171. }
  172. return outs;
  173. }
  174. """
  175. class default_args:
  176. node_base: str = "Node"
  177. node_base_hdr: Optional[str] = None
  178. shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
  179. tensor_class: str = "torch::lazy::LazyTensor"
  180. tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
  181. lazy_ir_generator: Type[GenLazyIR] = GenLazyIR
  182. native_func_definition_generator: Type[
  183. GenLazyNativeFuncDefinition
  184. ] = GenLazyNativeFuncDefinition
  185. backend_name: str = "TorchScript"
  186. def main() -> None:
  187. parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
  188. parser.add_argument(
  189. "-s",
  190. "--source-yaml",
  191. "--source_yaml",
  192. help="path to source yaml file containing operator external definitions",
  193. )
  194. parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
  195. parser.add_argument(
  196. "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
  197. )
  198. parser.add_argument(
  199. "--impl-path",
  200. "--impl_path",
  201. type=str,
  202. default=None,
  203. help="path to the source C++ file containing kernel definitions",
  204. )
  205. parser.add_argument(
  206. "--gen-ts-lowerings",
  207. "--gen_ts_lowerings",
  208. action="store_true",
  209. help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
  210. )
  211. parser.add_argument(
  212. "--node-base",
  213. "--node_base",
  214. type=str,
  215. default=default_args.node_base,
  216. help="Name of backend specific custom Lazy IR Node base class",
  217. )
  218. parser.add_argument(
  219. "--node-base-hdr",
  220. "--node_base_hdr",
  221. type=str,
  222. default=default_args.node_base_hdr,
  223. help="Path to header file defining custom Lazy IR Node base class",
  224. )
  225. parser.add_argument(
  226. "--shape-inference-hdr",
  227. "--shape_inference_hdr",
  228. type=str,
  229. default=default_args.shape_inference_hdr,
  230. help="Path to header file defining custom Lazy shape inference functions",
  231. )
  232. parser.add_argument(
  233. "--tensor-class",
  234. "--tensor_class",
  235. type=str,
  236. default=default_args.tensor_class,
  237. help="Name of backend specific custom Lazy Tensor class",
  238. )
  239. parser.add_argument(
  240. "--tensor-class-hdr",
  241. "--tensor_class_hdr",
  242. type=str,
  243. default=default_args.tensor_class_hdr,
  244. help="Path to header file defining custom Lazy Tensor class",
  245. )
  246. parser.add_argument(
  247. "--backend-name",
  248. "--backend_name",
  249. type=str,
  250. default=default_args.backend_name,
  251. help="Name of the backend to generate",
  252. )
  253. options = parser.parse_args()
  254. # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
  255. torch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
  256. aten_path = str(torch_root / "aten" / "src" / "ATen")
  257. lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
  258. if options.gen_ts_lowerings:
  259. lazy_ir_generator = GenTSLazyIR
  260. native_func_definition_generator: Type[
  261. GenLazyNativeFuncDefinition
  262. ] = default_args.native_func_definition_generator
  263. run_gen_lazy_tensor(
  264. aten_path,
  265. options.source_yaml,
  266. options.output_dir,
  267. options.dry_run,
  268. options.impl_path,
  269. options.node_base,
  270. options.node_base_hdr,
  271. options.tensor_class,
  272. options.tensor_class_hdr,
  273. options.shape_inference_hdr,
  274. lazy_ir_generator,
  275. native_func_definition_generator,
  276. options.backend_name,
  277. )
  278. def run_gen_lazy_tensor(
  279. aten_path: str,
  280. source_yaml: str,
  281. output_dir: str,
  282. dry_run: bool,
  283. impl_path: Optional[str],
  284. node_base: str = default_args.node_base,
  285. node_base_hdr: Optional[str] = default_args.node_base_hdr,
  286. tensor_class: str = default_args.tensor_class,
  287. tensor_class_hdr: str = default_args.tensor_class_hdr,
  288. shape_inference_hdr: str = default_args.shape_inference_hdr,
  289. lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator,
  290. native_func_definition_generator: Type[
  291. GenLazyNativeFuncDefinition
  292. ] = default_args.native_func_definition_generator,
  293. # build_in_tree is true for TS backend and affects include paths
  294. build_in_tree: bool = False,
  295. # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
  296. # it must match how ATen was built
  297. per_operator_headers: bool = False,
  298. backend_name: str = default_args.backend_name,
  299. gen_forced_fallback_code: bool = False,
  300. use_lazy_shape: bool = True,
  301. # the following arguments are temporary customization points for xla backend migration.
  302. # do not rely on them otherwise, they should be removed once migration is complete
  303. backend_namespace: str = "torch::lazy",
  304. get_tensorlist: str = "GetTensorList",
  305. get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
  306. try_get_tensor: str = "TryGetLtcTensor",
  307. metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
  308. create_tensor: str = "LazyTensor::Create",
  309. create_from_first_tensor: bool = False,
  310. create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
  311. tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
  312. lazy_value_class: str = "torch::lazy::Value",
  313. lazy_tensor_ptr: str = "LazyTensorPtr",
  314. get_device_fn: str = "torch::lazy::GetBackendDevice",
  315. ) -> None:
  316. lv_tokens = lazy_value_class.split("::")
  317. lv_class = lv_tokens[-1]
  318. lv_ns = "::".join(lv_tokens[:-1])
  319. setValueT(BaseCppType(lv_ns, lv_class))
  320. template_dir = os.path.join(aten_path, "templates")
  321. def make_file_manager(install_dir: str) -> FileManager:
  322. return FileManager(
  323. install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
  324. )
  325. fm = make_file_manager(output_dir)
  326. native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
  327. tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
  328. parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
  329. native_functions, backend_indices = (
  330. parsed_yaml.native_functions,
  331. parsed_yaml.backend_indices,
  332. )
  333. grouped_native_functions = get_grouped_native_functions(native_functions)
  334. def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
  335. """
  336. We sort the native function because of the note in concat_map_codegen.
  337. TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
  338. """
  339. func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
  340. return str(func.name.name)
  341. grouped_native_functions = sorted(
  342. grouped_native_functions, key=sort_native_function
  343. )
  344. parsed_backend_yaml = parse_backend_yaml(
  345. source_yaml, grouped_native_functions, backend_indices
  346. )
  347. backend_key = parsed_backend_yaml.backend_key
  348. autograd_key = parsed_backend_yaml.autograd_key
  349. cpp_namespace = parsed_backend_yaml.cpp_namespace
  350. backend_indices = parsed_backend_yaml.backend_indices
  351. # the following 3 keys are all processed differently
  352. # for full_codegen, we generate IR, kernels, etc
  353. # for ir_gen, we generate only IR
  354. # non_native is used to register kernels not declared in
  355. # native_functions.yaml
  356. full_codegen, non_native, ir_gen = parse_native_functions_keys(
  357. source_yaml, grouped_native_functions
  358. )
  359. def concat_map_codegen(
  360. func: Callable[[NativeFunction], Sequence[str]],
  361. xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
  362. ops_list: List[OperatorName] = full_codegen,
  363. ) -> Iterator[str]:
  364. """
  365. We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
  366. only code-gen additional entries for the inplace variant for the native functions.
  367. """
  368. for x in xs:
  369. fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
  370. for f in fs:
  371. if f.func.name in ops_list:
  372. for r in func(f):
  373. yield r
  374. selector = SelectiveBuilder.get_nop_selector()
  375. assert backend_key is not None
  376. class_name = backend_indices[backend_key].native_function_class_name()
  377. if impl_path is not None:
  378. error_on_missing_kernels(
  379. native_functions,
  380. backend_indices,
  381. backend_key,
  382. autograd_key,
  383. class_name,
  384. impl_path,
  385. full_codegen,
  386. )
  387. """ Validate Shape Inference Definitions
  388. Generated lazy native functions all perform shape inference, by first using a meta:: kernel
  389. if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
  390. knows the call signature for compute_shape_{op} becuase it matches the nativefunction (and meta::) signature,
  391. so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
  392. to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
  393. the expected signature which can be copy-pasted into shape_inference.h.
  394. compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
  395. to structured kernels.
  396. See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
  397. """
  398. if shape_inference_hdr is not None:
  399. expected_shape_infr_decls = list(
  400. concat_map_codegen(
  401. dest.GenLazyShapeInferenceDefinition(
  402. backend_indices[backend_key], tensor_class
  403. ),
  404. grouped_native_functions,
  405. )
  406. )
  407. validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
  408. assert class_name is not None
  409. # Generate nativefunction declarations
  410. # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
  411. # may want to register their own lazy kernels instead of registering the TS ones.
  412. # The registration will lazily happen when init_ts_backend is called.
  413. gen_dispatchkey_nativefunc_headers(
  414. fm,
  415. class_name,
  416. cpp_namespace,
  417. backend_indices,
  418. grouped_native_functions,
  419. backend_key,
  420. autograd_key,
  421. backend_name,
  422. )
  423. # Generate Dispatcher registrations which hook up the nativefunctions
  424. for dispatch_key in (
  425. [backend_key] if autograd_key is None else [backend_key, autograd_key]
  426. ):
  427. gen_dispatcher_registrations(
  428. fm,
  429. output_dir,
  430. class_name,
  431. backend_indices,
  432. grouped_native_functions,
  433. backend_key,
  434. dispatch_key,
  435. selector,
  436. build_in_tree=build_in_tree,
  437. per_operator_headers=per_operator_headers,
  438. backend_name=backend_name,
  439. eager_registration=False,
  440. )
  441. # Generate native function impls that build IR nodes
  442. ns_helper = NamespaceHelper(cpp_namespace)
  443. fm.write_with_template(
  444. f"{backend_key}NativeFunctions.cpp",
  445. "DispatchKeyNativeFunctions.cpp",
  446. lambda: {
  447. "includes": [
  448. f"#include <{path}>"
  449. for path in [
  450. tensor_class_hdr,
  451. shape_inference_hdr,
  452. "ATen/Functions.h",
  453. "ATen/native/TensorConversions.h",
  454. "ATen/NativeFunctions.h",
  455. "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
  456. "ATen/MetaFunctions.h",
  457. "ATen/Operators.h",
  458. "ATen/native/CPUFallback.h",
  459. "torch/csrc/lazy/core/ir_builder.h",
  460. "torch/csrc/lazy/core/lazy_graph_executor.h",
  461. "torch/csrc/lazy/core/metrics.h",
  462. "torch/csrc/lazy/core/shape.h",
  463. f"{output_dir}/{backend_key}NativeFunctions.h",
  464. f"{output_dir}/LazyIr.h",
  465. ]
  466. + (
  467. ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
  468. if gen_forced_fallback_code
  469. else []
  470. )
  471. ],
  472. "helper_fns": get_ltc_helper_fns(),
  473. "native_functions_include": "",
  474. "namespace_prologue": ns_helper.prologue,
  475. "namespace_epilogue": ns_helper.epilogue,
  476. "native_function_definitions": list(
  477. concat_map_codegen(
  478. native_func_definition_generator(
  479. f"{backend_key}NativeFunctions",
  480. backend_indices[backend_key],
  481. tensor_class,
  482. gen_forced_fallback_code,
  483. backend_namespace,
  484. get_tensorlist,
  485. get_tensor_or_wrap_number,
  486. try_get_tensor,
  487. metrics_counter,
  488. create_tensor,
  489. create_from_first_tensor,
  490. create_aten_from_ltc_tensor,
  491. tuple_aten_from_ltc_tensors,
  492. lazy_tensor_ptr,
  493. get_device_fn,
  494. ),
  495. grouped_native_functions,
  496. )
  497. ),
  498. },
  499. )
  500. # Generate IR node classes
  501. lazy_ir_obj = lazy_ir_generator(
  502. backend_indices[backend_key], backend_name, node_base, use_lazy_shape
  503. )
  504. fm.write_with_template(
  505. "LazyIr.h",
  506. "LazyIr.h",
  507. lambda: {
  508. "lazy_ir_sysinc": [
  509. f"#include <{path}>"
  510. for path in [
  511. "ATen/core/Formatting.h",
  512. "c10/core/ScalarType.h",
  513. "c10/util/Optional.h",
  514. "torch/csrc/lazy/core/hash.h",
  515. "torch/csrc/lazy/core/ir.h",
  516. "torch/csrc/lazy/core/shape.h",
  517. "vector",
  518. ]
  519. ],
  520. "lazy_ir_inc": [f'#include "{node_base_hdr}"']
  521. if node_base_hdr is not None
  522. else [],
  523. "ir_declarations": list(
  524. concat_map_codegen(
  525. lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
  526. )
  527. ),
  528. "namespace_prologue": ns_helper.prologue,
  529. "namespace_epilogue": ns_helper.epilogue,
  530. },
  531. )
  532. # Generate Non Native IR Node classes
  533. fm.write_with_template(
  534. "LazyNonNativeIr.h",
  535. "LazyNonNativeIr.h",
  536. lambda: {
  537. "lazy_non_native_ir_inc": [
  538. f"#include <{path}>"
  539. for path in [
  540. "torch/csrc/lazy/core/ir.h",
  541. "torch/csrc/lazy/core/ir_builder.h",
  542. "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
  543. "torch/csrc/lazy/core/shape_inference.h",
  544. ]
  545. + ([node_base_hdr] if node_base_hdr else [])
  546. if path
  547. ],
  548. "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
  549. non_native, lazy_ir_obj
  550. ),
  551. "namespace_prologue": ns_helper.prologue,
  552. "namespace_epilogue": ns_helper.epilogue,
  553. },
  554. )
# Run the command-line entry point only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()