# generator.py

import json
import logging
import math
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.static_runtime import config

logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
        if not annotation:
            continue
        alias_set = getattr(annotation, "alias_set", ())
        if alias_set:
            return True
    return False


BLOCKED_OPS = frozenset(
    (
        # non-CPU ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training-related ops
        "_make_dual",
        # cannot be called directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these were added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
    )
)


def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    base_op_name = ""
    func = None
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func
    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False
    for arg in func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
            return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out-variant ops, we also need to check the arguments of the functional variant.
    for arg in g.functional.func.schema_order_arguments():
        maybe_method = ivalue_type_conversion_method(arg.type)
        if not maybe_method:
            # Type conversion is not supported yet.
            logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
            return False

    if not g.structured:
        # For an unstructured op, check that it has an out-variant implementation;
        # the minimum requirement is that the output tensor is the last parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            return False

    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True


def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Return the method call expression on a `c10::IValue` that converts its contained
    value to the expected `arg_type`. For example, for `arg_type` == BaseTy.Tensor,
    this function returns ".toTensor()", so that it can be appended to the IValue's
    variable name to get a value of the expected type.
    """
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        base_ty_object = arg_type.elem.name
    else:
        return None

    if base_ty_object not in type_conversion_methods:
        return None
    methods = type_conversion_methods[base_ty_object]
    if isinstance(arg_type, BaseType):
        return methods[0]
    return methods[1]
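
# Illustrative note (a sketch, not part of the generated output): for a plain
# `Tensor` argument the table above yields (True, "toTensor()"), so the caller
# emits a reference binding such as
#   const auto& self = p_node->Input(0).toTensor();
# (the argument name `self` is hypothetical), while an optional tensor yields
# (False, "toOptional<at::Tensor>()") and is bound by value.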


should_use_int_tensor_ops_ = frozenset(
    (
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    )
)
should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))


def should_use_int_tensor(op_name: str) -> bool:
    return op_name in should_use_int_tensor_ops_


def should_use_complex_tensor(op_name: str) -> bool:
    return op_name in should_use_complex_tensor_ops_


test_tensor_dim_ops_1_ = frozenset(
    (
        "addmv",
        "index_add",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    )
)
test_tensor_dim_ops_2_ = frozenset(
    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
)


def test_tensor_dim(op_name: str) -> int:
    if op_name in test_tensor_dim_ops_1_:
        return 1
    if op_name in test_tensor_dim_ops_2_:
        return 2
    return 3


test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)


def test_tensor_shape(op_name: str) -> str:
    if op_name in test_tensor_shape_json:
        return test_tensor_shape_json[op_name]
    else:
        return ""


def test_value_expression(
    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
    tensor_size_ex = test_tensor_shape(op_name)
    if tensor_size_ex == "":
        num_tensors = 16 if index == 0 else 64
        num_dim = test_tensor_dim(op_name)
        size_per_dim = math.ceil(num_tensors / float(num_dim))
        size_per_dim += size_per_dim % 2
        tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
    if should_use_int_tensor(op_name):
        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
    elif should_use_complex_tensor(op_name):
        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
    else:
        tensor_expression = f"at::rand({tensor_size_ex})"

    value_expressions = {
        BaseTy.Tensor: tensor_expression,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    base_ty_object = None
    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "not expected type"
    value_expression = value_expressions[base_ty_object]
    return value_expression
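
# Worked example of the default shape logic above (a sketch, not generated
# output): with index == 0 and the default test_tensor_dim of 3, num_tensors
# is 16 and size_per_dim = ceil(16 / 3) = 6 (already even), so a Tensor
# argument becomes `at::rand({6,6,6})`; with index != 0, num_tensors is 64 and
# the expression becomes `at::rand({22,22,22})`.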


def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {}
    for arg in schema.schema_order_arguments():
        test_value_exp = test_value_expression(arg.type, index, schema_name)
        arg_map[arg.name] = test_value_exp
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = []
    for arg_name, arg_value in arg_map.items():
        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
    return ";\n ".join(arg_populations) + ";"


def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    assert not schema.is_out_fn()
    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())


generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}


def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]:
    def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
        t = arg.type
        add_optional = False
        if isinstance(t, OptionalType):
            t = t.elem
            add_optional = True
        assert isinstance(t, BaseType)
        type_str = None
        if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
            type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
        if type_str and add_optional:
            type_str = f"{type_str}?"
        return ("%" + arg.name, type_str)

    return [ir_argument(arg) for arg in schema.schema_order_arguments()]


def generate_arg_extraction(schema: FunctionSchema) -> str:
    arg_populations = []
    for i, arg in enumerate(schema.schema_order_arguments()):
        maybe_method = ivalue_type_conversion_method(arg.type)
        assert maybe_method
        is_reference, type_conversion_method = maybe_method
        reference = "&" if is_reference else ""
        arg_populations.append(
            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
        )
    return ";\n ".join(arg_populations) + ";"
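
# Example of the extraction code this produces (a sketch; the schema and the
# argument names `self`/`dim` are hypothetical): for a schema whose first two
# arguments are `self: Tensor` and `dim: int`, the joined result is roughly
#   const auto& self = p_node->Input(0).toTensor();
#   const auto dim = p_node->Input(1).toInt();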


def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.functional)
    if g.structured or kernel is None:
        return cpp.name(g.functional.func)
    return kernel.kernel


def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    kernel = backend_index.get_kernel(g.out)
    if g.structured or kernel is None:
        return cpp.name(g.out.func)
    return kernel.kernel


def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "cpu" if g.structured else "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    schema = g.view.func
    kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    if kernel:
        kernel_name = kernel.kernel
    arg_names = (arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'


def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    schema = g.out.func
    assert schema.is_out_fn()
    arg_names = []
    kernel_name = get_out_kernel_name(g, backend_index)
    if g.structured:
        # structured op starts with the output tensor argument.
        arg_names = [out_arg.name for out_arg in schema.arguments.out]
    else:
        arg_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)
    if not g.structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)
    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
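
# Shape of the generated call (a sketch; `<kernel>` and the argument names are
# illustrative):
#   structured op:   at::cpu::<kernel>(out, self, ...)    -- out tensor(s) first
#   unstructured op: at::native::<kernel>(self, ..., out)  -- single out tensor last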


no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
        "prod",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    schema_str = str(schema)
    type_variant_op_name = schema_str[: schema_str.find("(")]
    return type_variant_op_name not in no_memory_resize_ops
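
# For example (a sketch): given a schema string such as
# "dot(Tensor self, Tensor tensor) -> Tensor", type_variant_op_name is "dot",
# which appears in no_memory_resize_ops, so should_check_resize returns False.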


def op_name_from_group(g: NativeFunctionsGroup) -> str:
    return g.functional.func.name.name.base


class GenOpDispatcher:
    def out_variant(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_generator(
                    g, backend_index
                )
                generated_type_variants.append(generated_type_variant)
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def view(
        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
    ) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_generator(g, backend_index)
                generated_type_variants.append(generated_type_variant)
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(generated_type_variants)
        generated = f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""
        return generated

    def out_variant_op_generator(
        self, g: NativeFunctionsGroup, backend_index: BackendIndex
    ) -> str:
        functional = g.functional
        schema = str(functional.func)
        populated_argument = generate_arg_extraction(g.functional.func)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        assert len(g.out.func.arguments.out) == 1
        out_variable_name = str(g.out.func.arguments.out[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          if (p_node->Output(0).isNone()) {{
            p_node->Output(0) = {functional_variant_call};
            return;
          }}
          auto& {out_variable_name} = p_node->Output(0).toTensor();
          fastResizeToZero({out_variable_name});
          {out_variant_call};
        }};
      }}"""
        return generated

    def view_op_generator(
        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
    ) -> str:
        schema = str(g.view.func)
        populated_argument = generate_arg_extraction(g.view.func)
        functional_variant_call = generate_call_to_view_ops(g, backend_index)
        generated = f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          p_node->Output(0) = {functional_variant_call};
        }};
      }}"""
        return generated


class GenOpTestCase:
    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsGroup)
                generated_type_variant = self.out_variant_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
        if not groups:
            return ""
        generated_type_variants = []
        for g in groups:
            with native_function_manager(g):
                assert is_supported(g)
                assert isinstance(g, NativeFunctionsViewGroup)
                generated_type_variant = self.view_op_test_case_generator(g)
                generated_type_variants.append(generated_type_variant)
        return "\n".join(generated_type_variants)

    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
        schema = g.functional.func
        schema_str = str(schema)
        assert schema_str.find("(") > 0
        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
        op_name = op_name_from_group(g)
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        arg_declarations = ", ".join(
            (
                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
                for arg_name, arg_type in arg_types
            )
        )
        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        test_value_definitions2 = generate_test_value_definitions(schema, 1)
        test_value_names2 = generate_test_value_names(schema, 1)
        check_resize = "true" if should_check_resize(schema) else "false"
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";
  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
  {test_value_definitions2}
  std::vector<IValue> args2{{{test_value_names2}}};
  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
}}
"""
        return generated

    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
        schema = g.view.func
        schema_str = str(schema)
        assert schema_str.find("(") > 0
        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
        op_name = g.view.root_name
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        arg_declarations = ", ".join(
            (
                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
                for arg_name, arg_type in arg_types
            )
        )
        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        generated = f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";
  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args);
}}
"""
        return generated