# native_functions.py (2.3 KB)
  1. from typing import List, Optional, Union
  2. import torchgen.api.meta as meta
  3. import torchgen.api.structured as structured
  4. from torchgen.api.types import kernel_signature
  5. from torchgen.context import with_native_function_and_index
  6. from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
  7. from torchgen.utils import mapMaybe
  8. @with_native_function_and_index
  9. def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
  10. sig = kernel_signature(f, backend_index)
  11. metadata = backend_index.get_kernel(f)
  12. if metadata is None:
  13. return None
  14. if "legacy::" in metadata.kernel:
  15. return None
  16. else:
  17. prefix = "static" if backend_index.external else "TORCH_API"
  18. return f"{prefix} {sig.decl(name=metadata.kernel)};"
  19. @with_native_function_and_index
  20. def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
  21. meta_name = meta.name(g)
  22. out_args = structured.impl_arguments(g)
  23. metadata = backend_index.get_kernel(g)
  24. if metadata is None:
  25. return []
  26. prefix = "" if backend_index.external else "TORCH_API "
  27. return [
  28. f"""\
  29. struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
  30. void impl({', '.join(a.decl() for a in out_args)});
  31. }};
  32. """
  33. ]
  34. # Generates NativeFunctions.h, a list of forward declarations of all
  35. # actual kernel definitions we keep in aten/src/ATen/native/
  36. @with_native_function_and_index
  37. def compute_native_function_declaration(
  38. g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
  39. ) -> List[str]:
  40. metadata = backend_index.get_kernel(g)
  41. if isinstance(g, NativeFunctionsGroup):
  42. if metadata is not None and metadata.structured:
  43. if backend_index.external:
  44. # Structured hasn't been tested with external backends yet.
  45. raise AssertionError(
  46. "Structured external backend functions are not implemented yet."
  47. )
  48. else:
  49. return gen_structured(g, backend_index)
  50. else:
  51. return list(
  52. mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
  53. )
  54. else:
  55. x = gen_unstructured(g, backend_index)
  56. return [] if x is None else [x]