custom_ops.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from collections import defaultdict
  2. from dataclasses import dataclass
  3. from typing import Dict, List, Optional, Sequence, Tuple
  4. from torchgen import dest
  5. # disable import sorting to avoid circular dependency.
  6. from torchgen.api.types import DispatcherSignature # isort:skip
  7. from torchgen.context import method_with_native_function
  8. from torchgen.model import BackendIndex, DispatchKey, NativeFunction, Variant
  9. from torchgen.selective_build.selector import SelectiveBuilder
  10. from torchgen.utils import concatMap, Target
  11. # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
  12. # model authoring side.
  13. @dataclass(frozen=True)
  14. class ComputeNativeFunctionStub:
  15. @method_with_native_function
  16. def __call__(self, f: NativeFunction) -> Optional[str]:
  17. if Variant.function not in f.variants:
  18. return None
  19. sig = DispatcherSignature.from_schema(
  20. f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
  21. )
  22. assert sig is not None
  23. if len(f.func.returns) == 0:
  24. ret_name = ""
  25. elif len(f.func.returns) == 1:
  26. if f.func.arguments.out:
  27. ret_name = f.func.arguments.out[0].name
  28. else:
  29. ret_name = next(
  30. (
  31. a.name
  32. for a in f.func.arguments.flat_non_out
  33. if a.type == f.func.returns[0].type
  34. ),
  35. "",
  36. )
  37. if not ret_name:
  38. raise Exception(f"Can't handle this return type {f.func}")
  39. else:
  40. assert len(f.func.arguments.out) == len(f.func.returns), (
  41. "Out variant number of returns need to match the number of out arguments."
  42. f" Got outs {str(f.func.arguments.out)} but returns {str(f.func.returns)}"
  43. )
  44. # returns a tuple of out arguments
  45. tensor_type = "at::Tensor &"
  46. comma = ", "
  47. ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
  48. {comma.join([r.name for r in f.func.arguments.out])}
  49. )"""
  50. ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
  51. return f"""
  52. {sig.defn()} {{
  53. {ret_str}
  54. }}
  55. """
  56. def gen_custom_ops_registration(
  57. *,
  58. native_functions: Sequence[NativeFunction],
  59. selector: SelectiveBuilder,
  60. backend_index: BackendIndex,
  61. rocm: bool,
  62. ) -> Tuple[str, str]:
  63. """
  64. Generate custom ops registration code for dest.RegisterDispatchKey.
  65. :param native_functions: a sequence of `NativeFunction`
  66. :param selector: for selective build.
  67. :param backend_index: kernels for all the ops.
  68. :param rocm: bool for dest.RegisterDispatchKey.
  69. :return: generated C++ code to register custom operators into PyTorch
  70. """
  71. dispatch_key = DispatchKey.CPU
  72. static_init_dispatch_registrations = ""
  73. ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
  74. for native_function in native_functions:
  75. ns_grouped_native_functions[native_function.namespace].append(native_function)
  76. for namespace, functions in ns_grouped_native_functions.items():
  77. if len(functions) == 0:
  78. continue
  79. dispatch_registrations_body = "\n".join(
  80. list(
  81. concatMap(
  82. dest.RegisterDispatchKey(
  83. backend_index,
  84. Target.REGISTRATION,
  85. selector,
  86. rocm=rocm,
  87. symint=False,
  88. class_method_name=None,
  89. skip_dispatcher_op_registration=False,
  90. ),
  91. functions,
  92. )
  93. )
  94. )
  95. static_init_dispatch_registrations += f"""
  96. TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
  97. {dispatch_registrations_body}
  98. }};"""
  99. anonymous_definition = "\n".join(
  100. list(
  101. concatMap(
  102. dest.RegisterDispatchKey(
  103. backend_index,
  104. Target.ANONYMOUS_DEFINITION,
  105. selector,
  106. rocm=rocm,
  107. symint=False,
  108. class_method_name=None,
  109. skip_dispatcher_op_registration=False,
  110. ),
  111. native_functions,
  112. )
  113. )
  114. )
  115. return anonymous_definition, static_init_dispatch_registrations