12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from typing import List, Optional, Union
- import torchgen.api.meta as meta
- import torchgen.api.structured as structured
- from torchgen.api.types import kernel_signature
- from torchgen.context import with_native_function_and_index
- from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
- from torchgen.utils import mapMaybe
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
    """Build the forward declaration for a single unstructured kernel.

    Args:
        f: The native function to declare.
        backend_index: The backend whose kernel for ``f`` is looked up.

    Returns:
        ``None`` when ``backend_index.get_kernel(f)`` yields nothing or when
        the kernel name contains ``"legacy::"``. Otherwise the kernel
        declaration terminated by ``;`` and prefixed with ``static`` for an
        external backend or ``TORCH_API`` for any other backend.
    """
    signature = kernel_signature(f, backend_index)
    kernel_info = backend_index.get_kernel(f)

    # Nothing to declare: no kernel for this backend, or a legacy kernel.
    if kernel_info is None or "legacy::" in kernel_info.kernel:
        return None

    linkage = "TORCH_API" if not backend_index.external else "static"
    declaration = signature.decl(name=kernel_info.kernel)
    return f"{linkage} {declaration};"
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Build the struct declaration for a structured kernel group.

    Args:
        g: The group of native functions sharing one structured kernel.
        backend_index: The backend whose kernel for ``g`` is looked up.

    Returns:
        An empty list when ``backend_index.get_kernel(g)`` yields nothing.
        Otherwise a one-element list holding the text of a
        ``structured_<kernel>`` struct that derives from
        ``at::meta::structured_<meta name>`` and declares a single ``impl``
        method. The struct name carries a ``TORCH_API`` prefix unless the
        backend is external.
    """
    base_name = meta.name(g)
    impl_args = structured.impl_arguments(g)
    kernel_info = backend_index.get_kernel(g)
    if kernel_info is None:
        return []

    export = "TORCH_API " if not backend_index.external else ""
    impl_params = ", ".join(arg.decl() for arg in impl_args)
    # The leading backslash suppresses the newline right after the opening quotes.
    struct_decl = f"""\
struct {export}structured_{kernel_info.kernel} : public at::meta::structured_{base_name} {{
void impl({impl_params});
}};
"""
    return [struct_decl]
# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
    """Return the kernel declarations for a native function or function group.

    Args:
        g: Either a single native function or a group of them.
        backend_index: The backend whose kernels are being declared.

    Returns:
        A list of declaration strings, which may be empty. A single function
        yields at most one unstructured declaration. A group whose kernel
        metadata is marked structured yields the structured struct
        declaration. Any other group yields one unstructured declaration per
        member function that produces one.

    Raises:
        AssertionError: If the group is structured and the backend is
            external, a combination that is not implemented.
    """
    kernel_info = backend_index.get_kernel(g)

    # A lone function: at most one unstructured declaration.
    if not isinstance(g, NativeFunctionsGroup):
        single = gen_unstructured(g, backend_index)
        return [single] if single is not None else []

    is_structured = kernel_info is not None and kernel_info.structured
    if not is_structured:

        def declare(fn: NativeFunction) -> Optional[str]:
            return gen_unstructured(fn, backend_index)

        return list(mapMaybe(declare, g.functions()))

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
|