import argparse
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union

import yaml

# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
from torchgen import dest
from torchgen.api import cpp as aten_cpp
from torchgen.api.types import CppSignature, CppSignatureGroup, CType, NamedCType
from torchgen.context import method_with_native_function, with_native_function_and_index
from torchgen.executorch.api import et_cpp
from torchgen.executorch.api.custom_ops import (
    ComputeNativeFunctionStub,
    gen_custom_ops_registration,
)
from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.executorch.api.unboxing import Unboxing
from torchgen.gen import (
    get_custom_build_selector,
    get_native_function_declarations,
    get_native_function_schema_registrations,
    LineLoader,
    parse_native_yaml,
    ParsedYaml,
)
from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    is_cuda_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    Variant,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
)


def static_dispatch(
    sig: Union[CppSignature, ExecutorchCppSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend kernel and dispatch to it.
    If zero or more than one kernel exists, error out. A simplified version of
    register_dispatch_key.py.
    Arguments:
        sig: A CppSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch for.
        backend_indices: All available backends.
    Returns:
        C++ code to call backend-specific functions, e.g., "return at::native::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    backends = [b for b in backend_indices if b.has_kernel(f)]
    static_block = None
    if len(backends) == 1:
        backend_metadata = backends[0].get_kernel(f)
        if backend_metadata:
            args = ", ".join(a.name for a in sig.arguments())
            # Here we are assuming there's no difference between CppSignature and NativeSignature for Executorch.
            static_block = f"return ::{backend_metadata.cpp_namespace}::{backend_metadata.kernel}({args});"
    else:
        static_block = f"""
ET_ASSERT_UNREACHABLE_MSG("The number of native function(s) binding to {f.func.name} is {len(backends)}.");
"""
    return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
    {static_block}
}}
"""


# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    static_dispatch_backend_indices: List[BackendIndex]
    selector: SelectiveBuilder
    use_aten_lib: bool
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None
        if Variant.function not in f.variants:
            return None
        sig: Union[CppSignature, ExecutorchCppSignature] = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            if self.use_aten_lib
            else ExecutorchCppSignature.from_native_function(f)
        )
        if self.use_aten_lib and not self.is_custom_op(f):
            comma = ", "
            return f"""
// {f.namespace}::{f.func}
TORCH_API inline {sig.decl()} {{
    return at::{sig.name()}({comma.join(e.name for e in sig.arguments())});
}}
"""
        else:
            return static_dispatch(
                sig,
                f,
                backend_indices=self.static_dispatch_backend_indices,
            )


# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    selector: SelectiveBuilder
    use_aten_lib: bool

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return ""
        sig: Union[CppSignature, ExecutorchCppSignature]
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if self.use_aten_lib:
            sig = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
            argument_type_gen = aten_cpp.argumenttype_type
            return_type_gen = aten_cpp.returns_type
        else:
            sig = ExecutorchCppSignature.from_native_function(f)
            argument_type_gen = et_cpp.argumenttype_type
            return_type_gen = et_cpp.returns_type

        # parse arguments into C++ code
        binding_list, code_list = Unboxing(
            argument_type_gen=argument_type_gen
        ).convert_arguments(sig.arguments())

        # for each C++ argument, generate the conversion code
        code_connector = "\n\t"
        arg_connector = ", "
        args_str = f"{arg_connector.join(e.name for e in binding_list)}"

        if len(f.func.returns) == 0:
            if len(f.func.arguments.out) == 0:
                raise Exception(
                    f"Can't handle native function {f.func} with no returns and no out yet."
                )
            out = f.func.arguments.out[0]
            return_assignment = f"""stack[{len(binding_list)}] = &{out.name};"""
            ret_prefix = ""
        else:
            if len(f.func.arguments.out) == 0:
                return_assignment = (
                    f"""*stack[{len(binding_list)}] = EValue(result_);"""
                )
                ret_prefix = return_type_gen(f.func.returns).cpp_type() + " result_ = "
            else:
                return_assignment = ""
                ret_prefix = ""

        return f"""
Operator(
    "{f.namespace}::{f.func.name}",
    [](EValue** stack) {{
        {code_connector.join(code_list)}
        EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}");
        {ret_prefix}torch::executor::{f.namespace}::{sig.name()}({args_str});
        {return_assignment}
    }}
),
"""


def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> None:
    def key_func(fn: Union[NativeFunction, NativeFunctionsGroup]) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "unboxed_ops": [ComputeCodegenUnboxedKernels(selector, use_aten_lib)(fn)],
        },
        num_shards=1,
        sharded_keys={"unboxed_ops"},
    )


@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata = backend_index.get_kernel(g)
    if metadata is None:
        return []
    prefix = "static" if backend_index.external else "TORCH_API"
    return [f"{prefix} {sig.decl(name=metadata.kernel)};"]


def gen_functions_declarations(
    *,
    native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
    """
    Generates namespace-separated C++ function API inline declarations/definitions.
    Native functions are grouped by namespace, and the generated code is wrapped
    inside namespace blocks.
    E.g., for `custom_1::foo.out` in the yaml file we generate a C++ API symbol
    `torch::executor::custom_1::foo_out`. This way we avoid a symbol conflict when
    another `custom_2::foo.out` is also available.
    """
    ns_grouped_functions = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_functions[native_function.namespace].append(native_function)
    functions_declarations = ""
    newline = "\n"
    for namespace in ns_grouped_functions:
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        declarations = list(
            mapMaybe(
                ComputeFunction(
                    static_dispatch_backend_indices=static_dispatch_idx,
                    selector=selector,
                    use_aten_lib=use_aten_lib,
                    is_custom_op=lambda f: custom_ops_native_functions is not None
                    and f in custom_ops_native_functions,
                ),
                ns_grouped_functions[namespace],
            )
        )
        functions_declarations += f"""
{ns_helper.prologue}
{newline.join(declarations)}
{ns_helper.epilogue}
"""
    return functions_declarations
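

# For illustration only: if the selected yaml contains both `custom_1::foo.out`
# and `custom_2::foo.out`, two separate namespace blocks are emitted (their
# prologue/epilogue text comes from NamespaceHelper), so the generated symbols
# end up as `torch::executor::custom_1::foo_out` and
# `torch::executor::custom_2::foo_out` and never collide.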


def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    custom_ops_native_functions: Sequence[NativeFunction],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    use_aten_lib: bool,
) -> None:
    aten_headers = ["#include <ATen/Functions.h>"]
    if custom_ops_native_functions:
        cpu_fm.write_with_template(
            "CustomOpsNativeFunctions.h",
            "NativeFunctions.h",
            lambda: {
                "nativeFunctions_declarations": get_native_function_declarations(
                    grouped_native_functions=custom_ops_native_functions,
                    backend_indices=backend_indices,
                    native_function_decl_gen=dest.compute_native_function_declaration,
                ),
            },
        )
        aten_headers.append('#include "CustomOpsNativeFunctions.h"')
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": aten_headers
            if use_aten_lib
            else ['#include "NativeFunctions.h"'],
            "Functions_declarations": gen_functions_declarations(
                native_functions=native_functions,
                static_dispatch_idx=static_dispatch_idx,
                selector=selector,
                use_aten_lib=use_aten_lib,
                custom_ops_native_functions=custom_ops_native_functions,
            ),
        },
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "nativeFunctions_declarations": get_native_function_declarations(
                grouped_native_functions=native_functions,
                backend_indices=backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration
                if use_aten_lib
                else compute_native_function_declaration,
            ),
        },
    )


def gen_custom_ops(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    rocm: bool,
) -> None:
    dispatch_key = DispatchKey.CPU
    backend_index = backend_indices[dispatch_key]
    (
        anonymous_definition,
        static_init_dispatch_registrations,
    ) = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        backend_index=backend_index,
        rocm=rocm,
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}CustomOps.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": '#include "CustomOpsNativeFunctions.h"',
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": anonymous_definition,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )
    cpu_fm.write_with_template(
        f"Register{dispatch_key}Stub.cpp",
        "RegisterDispatchKeyCustomOps.cpp",
        lambda: {
            "ops_headers": "",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                mapMaybe(ComputeNativeFunctionStub(), native_functions)
            ),
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
        },
    )

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions,
        schema_selector=selector,
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": schema_registrations,
            "aten_schema_registrations": aten_schema_registrations,
        },
    )


def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
    """Translates the Executorch DSL dialect to use the same syntax as
    native_functions.yaml. The major difference is that the Executorch DSL dialect
    supports an "op" key, which refers to the operator name in native_functions.yaml.

    For example, a functions.yaml may have the following entry:

        - op: add.out
          ...

    It needs to be translated to the following:

        - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
          ...

    We look up the operator schema for "add.out" in aten_yaml_path and add it to
    the original functions.yaml. We also add the required field "variants", which
    for Executorch is always "function".

    In ATen mode we don't have to do the translation because native_yaml_path is
    the same as native_functions.yaml.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file, native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
        out_file: The IO object that we are writing into.
    Returns:
        None
    """
    if use_aten_lib:
        with open(aten_yaml_path, "r") as aten_yaml:
            out_file.writelines(aten_yaml.readlines())
        return

    aten_parsed_yaml = parse_native_yaml(
        aten_yaml_path,
        tags_yaml_path,
        None,
        skip_native_fns_gen=False,
    )
    aten_native_functions = aten_parsed_yaml.native_functions
    schema_dict = {
        f"{f.namespace}::{f.func.name}": str(f.func) for f in aten_native_functions
    }
    if (
        not native_yaml_path
        or not os.path.exists(native_yaml_path)
        or os.stat(native_yaml_path).st_size == 0
    ):
        return
    with open(native_yaml_path, "r") as native_yaml:
        native_es = yaml.load(native_yaml, Loader=LineLoader)
        if not native_es:
            return
        for e in native_es:
            assert isinstance(e.get("__line__"), int), e
            loc = Location(native_yaml_path, e.pop("__line__"))
            with context(lambda: f"in {loc}:\n "):
                if "variants" not in e:
                    e["variants"] = "function"
                if "func" in e:
                    continue
                assert isinstance(e.get("op"), str), e
                opname = e.pop("op")
                if "::" not in opname:
                    opname = "aten::" + opname
                assert opname in schema_dict
                e["func"] = schema_dict.get(opname)
        yaml.dump(native_es, out_file, width=1000)
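

# For illustration only: an Executorch-dialect entry such as
#
#   - op: add.out
#
# is rewritten (assuming the schema is found in native_functions.yaml) to
#
#   - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#     variants: function
#
# before being handed to the regular native_functions.yaml parser.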


def convert_backend_indices(
    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> Dict[DispatchKey, BackendIndex]:
    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            index={},
        )
    )
    for k, v in bs.items():
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k),
            index=v,
        )
    return indices


def parse_yaml(
    path: Optional[str],
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> Tuple[
    List[NativeFunction], Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
]:
    if path and os.path.exists(path) and os.stat(path).st_size > 0:
        parsed_yaml = parse_native_yaml(
            path,
            tags_yaml_path,
            None,
            skip_native_fns_gen=skip_native_fns_gen,
        )
        native_functions = list(filter(function_filter, parsed_yaml.native_functions))
        op_names = [f.func.name for f in native_functions]

        def map_index(
            m: Dict[OperatorName, BackendMetadata]
        ) -> Dict[OperatorName, BackendMetadata]:
            return {op: m[op] for op in m if op in op_names}

        backend_indices = {
            k: map_index(b.index) for (k, b) in parsed_yaml.backend_indices.items()
        }
        return native_functions, backend_indices
    else:
        return [], {}


def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: Optional[str],
    custom_ops_yaml_path: Optional[str],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> Tuple[ParsedYaml, Optional[ParsedYaml]]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
        tags_yaml_path: Path to a tags.yaml file to satisfy codegen parsing.
            It is not optional.
        aten_yaml_path: Path to the ATen operator yaml file, native_functions.yaml.
        native_yaml_path: Path to a functions.yaml file to parse.
            If the path does not exist in the filesystem, it is treated as an
            empty file. If `custom_ops_yaml_path` exists, the contents of that
            file are appended to the yaml input to be parsed.
        custom_ops_yaml_path: Path to a custom_ops.yaml file to parse. If
            the path does not exist in the filesystem, it is ignored.
        selector: For selective build.
        use_aten_lib: We use this flag to determine if we want to generate native
            functions. In ATen mode we should generate out= variants.
    Returns:
        A tuple with two elements:
        [0]: The parsed results of concatenating the contents of
             `native_yaml_path` and `custom_ops_yaml_path`.
        [1]: The parsed results of the contents of `custom_ops_yaml_path`, if
             present. If not present, None.
    """
    import tempfile

    # Only include selected ops: this is because we want to avoid parsing and
    # generating code for operators that are not needed by this build.
    def function_filter(f: NativeFunction) -> bool:
        return selector.is_native_function_selected(f)

    with tempfile.TemporaryDirectory() as tmpdirname:
        translated_yaml_path = os.path.join(tmpdirname, "translated.yaml")
        with open(translated_yaml_path, "w") as translated:
            translate_native_yaml(
                tags_yaml_path,
                aten_yaml_path,
                native_yaml_path,
                use_aten_lib,
                translated,
            )
        translated_functions, translated_backend_indices = parse_yaml(
            translated_yaml_path, tags_yaml_path, function_filter, not use_aten_lib
        )
        custom_ops_functions, custom_ops_backend_indices = parse_yaml(
            custom_ops_yaml_path, tags_yaml_path, function_filter, True
        )
        combined_functions = translated_functions + custom_ops_functions
        combined_backend_indices: Dict[
            DispatchKey, Dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        combined_backend_indices.update(translated_backend_indices)

        for dk in custom_ops_backend_indices:
            if dk not in combined_backend_indices:
                combined_backend_indices.update({dk: custom_ops_backend_indices[dk]})
            else:
                combined_backend_indices[dk] = {
                    **combined_backend_indices[dk],
                    **custom_ops_backend_indices[dk],
                }

        combined_yaml = ParsedYaml(
            combined_functions, convert_backend_indices(combined_backend_indices)
        )
        custom_ops_parsed_yaml = ParsedYaml(
            custom_ops_functions, convert_backend_indices(custom_ops_backend_indices)
        )

    return combined_yaml, custom_ops_parsed_yaml
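

# A minimal sketch of driving parse_yaml_files() directly (paths here are
# illustrative; main() below builds the same call from command-line flags):
#
#   parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
#       tags_yaml_path="aten/src/ATen/native/tags.yaml",
#       aten_yaml_path="aten/src/ATen/native/native_functions.yaml",
#       native_yaml_path="functions.yaml",
#       custom_ops_yaml_path=None,
#       selector=SelectiveBuilder.get_nop_selector(),
#       use_aten_lib=False,
#   )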


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate operator source files")
    # Although we don't refer to --source-path directly, make_file_manager()
    # expects it to point to a directory that contains a templates/ subdirectory
    # containing the file templates.
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for kernel templates",
    )
    parser.add_argument(
        "--functions-yaml-path",
        "--functions_yaml_path",
        help="path to the functions.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--custom-ops-yaml-path",
        "--custom_ops_yaml_path",
        help="path to the custom_ops.yaml file to use. Optional, but at least "
        "one of --functions-yaml-path and --custom-ops-yaml-path must be "
        "specified.",
    )
    parser.add_argument(
        "--aten-yaml-path",
        "--aten_yaml_path",
        help="path to native_functions.yaml file.",
    )
    # Note that make_file_manager() also looks at --install-dir.
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/generated",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    # Although we don't refer to --dry-run directly, make_file_manager() looks
    # for it.
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in codegen system.",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--use-aten-lib",
        "--use_aten_lib",
        action="store_true",
        help="a boolean flag to indicate whether we use ATen kernels or not; "
        "in the future this flag will be per operator",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources"],
        default=["headers", "sources"],
        help="Generate only a subset of files",
    )
    options = parser.parse_args()
    assert options.tags_path, "tags.yaml is required by codegen yaml parsing."

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )
    parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
        aten_yaml_path=options.aten_yaml_path,
        tags_yaml_path=options.tags_path,
        native_yaml_path=options.functions_yaml_path,
        custom_ops_yaml_path=options.custom_ops_yaml_path,
        selector=selector,
        use_aten_lib=options.use_aten_lib,
    )
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )
    custom_ops_native_functions = (
        custom_ops_parsed_yaml.native_functions if custom_ops_parsed_yaml else []
    )

    cpu_fm = make_file_manager(options=options)

    static_dispatch_idx: List[BackendIndex] = [backend_indices[DispatchKey.CPU]]

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            custom_ops_native_functions=custom_ops_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            use_aten_lib=options.use_aten_lib,
        )

    if "sources" in options.generate:
        gen_unboxing(
            native_functions=native_functions,
            cpu_fm=cpu_fm,
            selector=selector,
            use_aten_lib=options.use_aten_lib,
        )
        if custom_ops_native_functions:
            gen_custom_ops(
                native_functions=custom_ops_native_functions,
                selector=selector,
                backend_indices=backend_indices,
                cpu_fm=cpu_fm,
                rocm=options.rocm,
            )

    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem
        for fm, prefix in [
            (cpu_fm, ""),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))


if __name__ == "__main__":
    main()
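

# Example invocation (paths are illustrative, and this assumes the module is
# runnable as torchgen.gen_executorch; --source-path must point at a directory
# containing the templates/ subdirectory used by make_file_manager()):
#
#   python -m torchgen.gen_executorch \
#       --source-path=<source dir> \
#       --aten-yaml-path=aten/src/ATen/native/native_functions.yaml \
#       --tags-path=aten/src/ATen/native/tags.yaml \
#       --functions-yaml-path=functions.yaml \
#       --install-dir=build/generated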