gen_static_runtime_ops.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import argparse
  2. import itertools
  3. import os
  4. from typing import Sequence, TypeVar, Union
  5. from libfb.py.log import set_simple_logging # type: ignore[import]
  6. from torchgen import gen
  7. from torchgen.context import native_function_manager
  8. from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
  9. from torchgen.static_runtime import config, generator
# Note on `group_functions_by_op_name` (defined below): given a list of
# `grouped_native_functions` sorted by their op names, it returns a list of lists,
# each of which groups ops that share the same base name. For example, `mean` and
# `mean.dim` are grouped together by this function.
  13. NativeGroupT = TypeVar(
  14. "NativeGroupT",
  15. bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
  16. )
  17. def group_functions_by_op_name(
  18. grouped_native_functions: Sequence[NativeGroupT],
  19. ) -> Sequence[Sequence[NativeGroupT]]:
  20. if not grouped_native_functions:
  21. return []
  22. groups = []
  23. def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
  24. with native_function_manager(g):
  25. return generator.is_supported(g)
  26. eligible_ops = (g for g in grouped_native_functions if is_supported(g))
  27. groups = [
  28. list(group)
  29. for k, group in (
  30. itertools.groupby(
  31. eligible_ops,
  32. key=lambda g: config.func_name_base_str(g),
  33. )
  34. )
  35. ]
  36. return groups
  37. def clang_format(cpp_file_path: str) -> None:
  38. import subprocess
  39. subprocess.run(["clang-format", "-i", cpp_file_path])
  40. def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
  41. code = "\n".join(cpp_ops)
  42. generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
  43. // AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
  44. #include <torch/csrc/jit/runtime/static/ops.h>
  45. #include <ATen/CPUFunctions.h>
  46. #include <ATen/InferSize.h>
  47. #include <ATen/NativeFunctions.h>
  48. #include <ATen/Parallel.h>
  49. #include <ATen/ScalarOps.h>
  50. #include <ATen/TensorUtils.h>
  51. #include <ATen/cpu/vec/functional.h>
  52. #include <ATen/cpu/vec/vec.h>
  53. #include <ATen/native/EmbeddingBag.h>
  54. #include <ATen/native/Fill.h>
  55. #include <ATen/native/IndexingUtils.h>
  56. #include <ATen/native/NonSymbolicBC.h>
  57. #include <ATen/native/Resize.h>
  58. #include <ATen/native/SharedReduceOps.h>
  59. #include <ATen/native/TensorAdvancedIndexing.h>
  60. #include <ATen/native/cpu/SerialStackImpl.h>
  61. #include <ATen/native/layer_norm.h>
  62. #include <ATen/native/quantized/cpu/fbgemm_utils.h>
  63. #include <ATen/native/quantized/cpu/qembeddingbag.h>
  64. #include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
  65. #include <ATen/quantized/QTensorImpl.h>
  66. #include <ATen/quantized/Quantizer.h>
  67. #include <c10/core/ScalarType.h>
  68. #include <c10/core/WrapDimMinimal.h>
  69. #include <c10/util/irange.h>
  70. #include <torch/csrc/jit/ir/ir.h>
  71. #include <torch/csrc/jit/runtime/static/impl.h>
  72. #include <torch/csrc/jit/runtime/static/te_wrapper.h>
  73. #include <torch/csrc/jit/runtime/vararg_functions.h>
  74. #include <torch/csrc/jit/tensorexpr/ir.h>
  75. #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
  76. #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
  77. #include <torch/csrc/jit/tensorexpr/loopnest.h>
  78. namespace torch {{
  79. namespace jit {{
  80. {code}
  81. }} // namespace jit
  82. }} // namespace torch
  83. """
  84. with open(file_path, "w") as f:
  85. f.write(generated)
  86. clang_format(file_path)
  87. def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
  88. code = "\n".join(cpp_ops)
  89. generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
  90. // AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
  91. #include <gtest/gtest.h>
  92. #include <torch/csrc/jit/runtime/static/impl.h>
  93. #include <torch/torch.h>
  94. #include "test_utils.h"
  95. using namespace caffe2;
  96. using namespace torch;
  97. using namespace torch::jit;
  98. using namespace torch::jit::test;
  99. using c10::IValue;
  100. {code}
  101. """
  102. with open(file_path, "w") as f:
  103. f.write(generated)
  104. clang_format(file_path)
  105. def main() -> None:
  106. parser = argparse.ArgumentParser(description="Generate ATen source files")
  107. parser.add_argument(
  108. "-s",
  109. "--source-path",
  110. help="path to source directory for ATen",
  111. default="caffe2/aten/src/ATen",
  112. )
  113. parser.add_argument(
  114. "-p",
  115. "--generated-ops-cpp-path",
  116. help="path to directory to generate op dispatcher .cpp file",
  117. default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
  118. )
  119. parser.add_argument(
  120. "-t",
  121. "--generated-ops-test-cpp-path",
  122. help="path to directory to generate op dispatcher .cpp file",
  123. default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
  124. )
  125. options = parser.parse_args()
  126. native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
  127. tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
  128. parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path)
  129. native_functions, backend_indices = (
  130. parsed_yaml.native_functions,
  131. parsed_yaml.backend_indices,
  132. )
  133. op_generator = generator.GenOpDispatcher()
  134. test_case_generator = generator.GenOpTestCase()
  135. native_functions_groups = [
  136. g
  137. for g in gen.get_grouped_native_functions(native_functions)
  138. if isinstance(g, NativeFunctionsGroup)
  139. ]
  140. supported_functions_groups = group_functions_by_op_name(native_functions_groups)
  141. out_variant_op_result = [
  142. op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
  143. for groups in supported_functions_groups
  144. ]
  145. out_variant_test_result = [
  146. test_case_generator.out_variant(groups) for groups in supported_functions_groups
  147. ]
  148. native_functions_view_groups = [
  149. g
  150. for g in gen.get_grouped_by_view_native_functions(native_functions)
  151. if isinstance(g, NativeFunctionsViewGroup)
  152. ]
  153. supported_functions_view_groups = group_functions_by_op_name(
  154. native_functions_view_groups
  155. )
  156. view_op_result = [
  157. op_generator.view(groups, backend_indices[DispatchKey.CPU])
  158. for groups in supported_functions_view_groups
  159. ]
  160. view_test_result = [
  161. test_case_generator.view(groups) for groups in supported_functions_view_groups
  162. ]
  163. op_result = out_variant_op_result + ["\n\n"] + view_op_result
  164. test_result = out_variant_test_result + ["\n\n"] + view_test_result
  165. write_cpp(op_result, options.generated_ops_cpp_path)
  166. write_test_cpp(test_result, options.generated_ops_test_cpp_path)
  167. print(
  168. "\ntotal grouped native ops: %d"
  169. % len(gen.get_grouped_native_functions(native_functions))
  170. )
  171. print("grouped native ops with out variant: %d" % len(native_functions_groups))
  172. supported_functions_num = sum(
  173. [len(groups) for groups in supported_functions_groups]
  174. )
  175. print("generated functions groups with out variant: %d" % supported_functions_num)
  176. print("\nview grouped native ops: %d" % len(native_functions_view_groups))
  177. supported_view_functions_num = sum(
  178. [len(groups) for groups in supported_functions_view_groups]
  179. )
  180. print("generated functions view groups: %d" % supported_view_functions_num)
  181. print(
  182. "\noverall generated : %d"
  183. % (supported_functions_num + supported_view_functions_num)
  184. )
  185. if __name__ == "__main__":
  186. set_simple_logging(escape_newlines=False)
  187. main()