123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import Dict, List, Optional, Sequence, Tuple
- from torchgen import dest
- # disable import sorting to avoid circular dependency.
- from torchgen.api.types import DispatcherSignature # isort:skip
- from torchgen.context import method_with_native_function
- from torchgen.model import BackendIndex, DispatchKey, NativeFunction, Variant
- from torchgen.selective_build.selector import SelectiveBuilder
- from torchgen.utils import concatMap, Target
# Generates RegisterKernelStub.cpp, which provides placeholder (stub) kernels for custom operators. These stubs are
# used on the model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    """Emit a placeholder C++ CPU kernel definition for a custom operator.

    The generated function's name is prefixed with
    ``wrapper_CPU_<overload_name>_``. Its body does nothing except return one
    of its own arguments (or nothing at all when the operator has no returns).
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the C++ stub kernel definition for ``f``.

        :param f: the native function (custom operator) to stub out.
        :return: C++ source for the stub, or ``None`` when ``f`` has no
            ``Variant.function`` variant.
        :raises RuntimeError: if ``f`` has a single return that is neither an
            out argument nor type-compatible with any non-out argument.
        :raises AssertionError: if ``f`` has several returns whose count
            differs from its number of out arguments.
        """
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None  # type narrowing only
        ret_name = self._return_expression(f)
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """

    @staticmethod
    def _return_expression(f: NativeFunction) -> str:
        """Build the C++ expression the stub returns ("" when ``f`` returns nothing)."""
        returns = f.func.returns
        outs = f.func.arguments.out

        if len(returns) == 0:
            return ""

        if len(returns) == 1:
            if outs:
                # Out variant: hand back the (first) out argument.
                return outs[0].name
            # Otherwise return the first input whose type matches the return type.
            ret_name = next(
                (
                    a.name
                    for a in f.func.arguments.flat_non_out
                    if a.type == returns[0].type
                ),
                "",
            )
            if not ret_name:
                raise RuntimeError(f"Can't handle this return type {f.func}")
            return ret_name

        # Multiple returns are only supported for out variants, where the stub
        # returns a tuple of references to the out arguments. This is an explicit
        # raise rather than `assert` so the check survives `python -O`.
        if len(outs) != len(returns):
            raise AssertionError(
                "Out variant number of returns need to match the number of out arguments."
                f" Got outs {outs} but returns {returns}"
            )
        tensor_type = "at::Tensor &"
        comma = ", "
        return (
            f"::std::tuple<{comma.join([tensor_type] * len(returns))}>(\n"
            f"                {comma.join([r.name for r in outs])}\n"
            "            )"
        )
def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    backend_index: BackendIndex,
    rocm: bool,
) -> Tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param backend_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: a pair ``(anonymous_definition, static_init_dispatch_registrations)``.
        The first element is the kernel definitions for every function in
        ``native_functions``. The second holds one ``TORCH_LIBRARY_IMPL`` block
        per namespace (CPU dispatch key) that registers those kernels into PyTorch.
    """
    dispatch_key = DispatchKey.CPU

    def render(target: Target, functions: Sequence[NativeFunction]) -> str:
        # Run dest.RegisterDispatchKey over `functions` for one codegen target
        # and join the emitted snippets with newlines.
        return "\n".join(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    target,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                functions,
            )
        )

    # TORCH_LIBRARY_IMPL registers into a single namespace, so group by namespace.
    # Groups are built by appending, hence never empty.
    ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    static_init_dispatch_registrations = "".join(
        f"\nTORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{\n"
        f"{render(Target.REGISTRATION, functions)}\n"
        "};"
        for namespace, functions in ns_grouped_native_functions.items()
    )
    anonymous_definition = render(Target.ANONYMOUS_DEFINITION, native_functions)
    return anonymous_definition, static_init_dispatch_registrations
|