12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795 |
- import json
- import logging
- import math
- from typing import Dict, List, Optional, Sequence, Tuple, Union
- import torchgen.api.cpp as cpp
- from torchgen.context import native_function_manager
- from torchgen.model import (
- Argument,
- BackendIndex,
- BaseTy,
- BaseType,
- FunctionSchema,
- NativeFunctionsGroup,
- NativeFunctionsViewGroup,
- OptionalType,
- SelfArgument,
- TensorOptionsArguments,
- Type,
- )
- from torchgen.static_runtime import config
# Logger used for the "why was this op skipped" messages in this module.
# getLogger() is called without a name, so this is the root logger.
logger: logging.Logger = logging.getLogger()
def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    """Report whether any of `arguments` carries a non-empty alias set.

    An argument counts only when it has a truthy `annotation` attribute whose
    `alias_set` attribute is itself truthy. Arguments lacking either attribute
    are ignored.
    """
    annotations = (getattr(candidate, "annotation", None) for candidate in arguments)
    return any(
        bool(getattr(note, "alias_set", ())) for note in annotations if note
    )
# Base operator names that is_supported() rejects outright (it logs
# "BLOCKED: <name>" and returns False). The comments give the reason each
# group of names was excluded.
BLOCKED_OPS = frozenset(
    {
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these ones got added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
    }
)
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Decide whether static-runtime code can be generated for the group `g`.

    A group is rejected (most rejections are logged at INFO level) when:
      * config marks it as hand written,
      * its base name is listed in BLOCKED_OPS,
      * any schema argument has no IValue conversion method,
      * its C++ return type is not "at::Tensor" (view groups) or
        "at::Tensor &" (out-variant groups),
      * it is unstructured and its out overload does not end with the
        conventional `Tensor(a!) out) -> Tensor(a!)` / `.out` signature,
      * any of its non-out arguments carries an alias annotation.
    """
    is_view_group = isinstance(g, NativeFunctionsViewGroup)
    if is_view_group:
        base_op_name, func = g.view.root_name, g.view.func
    else:
        base_op_name, func = g.out.func.name.name.base, g.out.func

    if config.is_hand_written(g):
        logger.info(f"HAND WRITTEN: {base_op_name}")
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info(f"BLOCKED: {base_op_name}")
        return False

    # Every argument must be extractable from an IValue.
    if not all(
        ivalue_type_conversion_method(arg.type)
        for arg in func.schema_order_arguments()
    ):
        # Type converting is unsupported yet.
        logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(func)}")
        return False

    if is_view_group:
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        view_ret_type = cpp.returns_type(func.returns, symint=False).cpp_type()
        if view_ret_type != "at::Tensor":
            # Returns a non-Tensor value.
            logger.info(f"NON-TENSOR RET TYPE: {str(func)}")
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    if not all(
        ivalue_type_conversion_method(arg.type)
        for arg in g.functional.func.schema_order_arguments()
    ):
        # Type converting is unsupported yet.
        logger.info(f"NOT SUPPORTED TYPE CONVERTING: {str(g.functional.func)}")
        return False

    if not g.structured:
        # In case of unstructured op, we check if it has out variant implementation.
        # The out variant implementation satisfies the minimum requirement that it
        # has the output tensor as the last parameter.
        has_conventional_out = (
            hasattr(g, "out")
            and str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            and str(func.name).endswith(".out")
        )
        if not has_conventional_out:
            return False

    # TODO: stop type testing by converting to C++
    out_ret_type = cpp.returns_type(func.returns, symint=False).cpp_type()
    if out_ret_type != "at::Tensor &":
        logger.info(f"NON_TENSOR RET TYPE: {str(func)}")
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info(f"INPUTS ALIAS: {base_op_name}")
        return False
    return True
def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Look up how to pull a value of `arg_type` out of a `c10::IValue`.

    Returns a `(is_reference, method_call)` pair, or None when the type is not
    supported. `method_call` is the member call to append after "ivalue." --
    for example `arg_type` == BaseTy.Tensor yields `(True, "toTensor()")` and an
    optional Tensor yields `(False, "toOptional<at::Tensor>()")`.
    generate_arg_extraction binds the extracted value with `const auto&` when
    `is_reference` is true and `const auto` otherwise.
    """
    if isinstance(arg_type, BaseType):
        wants_optional = False
        base_ty_object = arg_type.name
    elif isinstance(arg_type, OptionalType):
        if not isinstance(arg_type.elem, BaseType):
            # ListType is currently unsupported.
            return None
        wants_optional = True
        base_ty_object = arg_type.elem.name
    else:
        return None

    # base type -> (conversion for the plain type, conversion for its optional)
    type_conversion_methods = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }
    methods = type_conversion_methods.get(base_ty_object)
    if methods is None:
        return None
    plain_method, optional_method = methods
    return optional_method if wants_optional else plain_method
# Ops for which should_use_int_tensor() answers True.
should_use_int_tensor_ops_ = frozenset(
    {
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    }
)
# Ops for which should_use_complex_tensor() answers True.
should_use_complex_tensor_ops_ = frozenset({"view_as_real", "imag", "_conj"})
def should_use_int_tensor(op_name: str) -> bool:
    """Return True when `op_name` is listed in should_use_int_tensor_ops_."""
    listed = op_name in should_use_int_tensor_ops_
    return listed
def should_use_complex_tensor(op_name: str) -> bool:
    """Return True when `op_name` is listed in should_use_complex_tensor_ops_."""
    listed = op_name in should_use_complex_tensor_ops_
    return listed
# Ops for which test_tensor_dim() returns 1.
test_tensor_dim_ops_1_ = frozenset(
    {
        "addmv",
        "index_add",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
        "nll_loss_backward",
        "dot",
        "vdot",
        "outer",
        "ger",
    }
)
# Ops for which test_tensor_dim() returns 2.
test_tensor_dim_ops_2_ = frozenset(
    {"addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t"}
)
def test_tensor_dim(op_name: str) -> int:
    """Return the number of dimensions to use for `op_name`'s test tensors.

    Ops in test_tensor_dim_ops_1_ get 1, ops in test_tensor_dim_ops_2_ get 2,
    and every other op gets 3.
    """
    for rank, op_names in ((1, test_tensor_dim_ops_1_), (2, test_tensor_dim_ops_2_)):
        if op_name in op_names:
            return rank
    return 3
# JSON text mapping an op name to the literal shape expression to use for its
# test tensors; parsed once at import time into test_tensor_shape_json.
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)
def test_tensor_shape(op_name: str) -> str:
    """Return the shape expression registered for `op_name`, or "" if none is."""
    return test_tensor_shape_json.get(op_name, "")
def test_value_expression(
    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
    """Return a C++ expression producing a test value for an argument of `arg_type`.

    Tensor arguments use the shape registered via test_tensor_shape() when there
    is one. Otherwise a budget of 16 (index 0) or 64 (any other index) is divided
    by the op's tensor rank, rounded up, and bumped to the next even number to
    give the size of every dimension. Non-tensor base types map to fixed
    literals. `arg_type` must be a BaseType or an OptionalType wrapping one.
    """
    tensor_size_ex = test_tensor_shape(op_name)
    if not tensor_size_ex:
        num_dim = test_tensor_dim(op_name)
        num_tensors = 64 if index != 0 else 16
        size_per_dim = math.ceil(num_tensors / float(num_dim))
        if size_per_dim % 2:
            # Keep every dimension even.
            size_per_dim += 1
        tensor_size_ex = "{%s}" % (",".join([str(size_per_dim)] * num_dim))

    if should_use_int_tensor(op_name):
        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
    elif should_use_complex_tensor(op_name):
        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
    else:
        tensor_expression = f"at::rand({tensor_size_ex})"

    value_expressions = {
        BaseTy.Tensor: tensor_expression,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    if isinstance(arg_type, BaseType):
        base_ty_object = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty_object = arg_type.elem.name
    assert base_ty_object in value_expressions, "not expected type"
    return value_expressions[base_ty_object]
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    """Return C++ `auto <arg><index> = <value>;` definitions for every schema argument.

    Default values come from test_value_expression(); config.override_test_values
    is then given the name -> value mapping (along with the op name and index)
    before the definitions are rendered. `schema` must not be an out variant.
    """
    assert not schema.is_out_fn()
    schema_name = schema.name.name.base
    arg_map = {
        arg.name: test_value_expression(arg.type, index, schema_name)
        for arg in schema.schema_order_arguments()
    }
    config.override_test_values(arg_map, schema_name, index)
    arg_populations = [
        f"auto {arg_name}{index} = {arg_value}"
        for arg_name, arg_value in arg_map.items()
    ]
    return ";\n ".join(arg_populations) + ";"
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    """Return the comma-joined `<arg><index>` variable names for `schema`.

    `schema` must not be an out variant.
    """
    assert not schema.is_out_fn()
    variable_names = [f"{arg.name}{index}" for arg in schema.schema_order_arguments()]
    return ",".join(variable_names)
# Type string that generate_test_ir_arguments() attaches to a graph input of
# each base type. Base types missing from this table get no type string.
generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}
def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]:
    """Return `("%<name>", type_str)` pairs for every argument of `schema`.

    `type_str` is looked up in generate_test_ir_arguments_base_ty_to_type_str_
    and is None when the base type is not listed there. Optional arguments get
    a trailing "?" on their type string. Each argument must be a BaseType or
    an OptionalType wrapping one.
    """
    ir_arguments: List[Tuple[str, Optional[str]]] = []
    for arg in schema.schema_order_arguments():
        add_optional = isinstance(arg.type, OptionalType)
        t = arg.type.elem if add_optional else arg.type
        assert isinstance(t, BaseType)
        type_str = generate_test_ir_arguments_base_ty_to_type_str_.get(t.name)
        if type_str and add_optional:
            type_str = f"{type_str}?"
        ir_arguments.append(("%" + arg.name, type_str))
    return ir_arguments
def generate_arg_extraction(schema: FunctionSchema) -> str:
    """Return C++ statements that read each schema argument from `p_node`'s inputs.

    Argument i is taken from `p_node->Input(i)` using the conversion reported by
    ivalue_type_conversion_method(), which must exist for every argument.
    """
    arg_populations = []
    for i, arg in enumerate(schema.schema_order_arguments()):
        conversion = ivalue_type_conversion_method(arg.type)
        assert conversion
        is_reference, type_conversion_method = conversion
        if is_reference:
            reference = "&"
        else:
            reference = ""
        arg_populations.append(
            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
        )
    return ";\n ".join(arg_populations) + ";"
def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the C++ kernel name to call for the functional variant of `g`.

    The kernel registered in `backend_index` is used only for unstructured
    groups that have one; otherwise the name comes from cpp.name().
    """
    functional = g.functional
    registered = backend_index.get_kernel(functional)
    if not g.structured and registered is not None:
        return registered.kernel
    return cpp.name(functional.func)
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the C++ kernel name to call for the out variant of `g`.

    The kernel registered in `backend_index` is used only for unstructured
    groups that have one; otherwise the name comes from cpp.name().
    """
    out_variant = g.out
    registered = backend_index.get_kernel(out_variant)
    if not g.structured and registered is not None:
        return registered.kernel
    return cpp.name(out_variant.func)
def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call expression for the functional variant of `g`.

    Structured groups are called in the `at::cpu` namespace, others in
    `at::native`. Arguments are passed by name in schema order.
    """
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    arg_names = [arg.name for arg in schema.schema_order_arguments()]
    if g.structured:
        namespace_name = "cpu"
    else:
        namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    """Return the `at::native::...` C++ call expression for the view op of `g`.

    The kernel registered in `backend_index` is preferred; without one the name
    falls back to cpp.name() of the view schema.
    """
    schema = g.view.func
    default_name = cpp.name(schema)
    registered = backend_index.get_kernel(g.view)
    kernel_name = registered.kernel if registered else default_name
    arg_names = [arg.name for arg in schema.schema_order_arguments()]
    namespace_name = "native"
    return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})'
def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call expression for the out variant of `g`.

    Structured groups are called in `at::cpu` with the out arguments first;
    unstructured groups are called in `at::native` with their single out
    argument last.
    """
    schema = g.out.func
    assert schema.is_out_fn()
    kernel_name = get_out_kernel_name(g, backend_index)
    is_structured = g.structured

    # structured op starts with the output tensor argument.
    arg_names = (
        [out_arg.name for out_arg in schema.arguments.out] if is_structured else []
    )
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
            continue
        assert isinstance(arg, Argument)
        arg_names.append(arg.name)
    if not is_structured:
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)

    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if is_structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
# Schema name prefixes (everything before the first "(" of str(schema), so
# overload names such as "isin.Scalar_Tensor" are included) for which
# should_check_resize() returns False.
no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
        "prod",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    """Return False when `schema`'s name is listed in no_memory_resize_ops.

    The name compared is the text of `str(schema)` before its first "(".
    The result is rendered as the `check_resize` argument of the generated
    testStaticRuntime calls.
    """
    # partition() returns the whole string when there is no "(". Slicing with
    # str.find()'s -1 result would instead silently drop the last character
    # and compare a truncated name.
    type_variant_op_name, _, _ = str(schema).partition("(")
    return type_variant_op_name not in no_memory_resize_ops
def op_name_from_group(g: NativeFunctionsGroup) -> str:
    """Return the base operator name of `g`'s functional variant."""
    operator_name = g.functional.func.name
    return operator_name.name.base
class GenOpDispatcher:
    """Renders C++ operator-functor registrations for supported op groups."""

    def out_variant(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        """Return one REGISTER_OPERATOR_FUNCTOR block covering all of `groups`.

        Every group must pass is_supported(). The registration is named after
        the first group's base op name; an empty `groups` yields "".
        """
        if not groups:
            return ""
        generated_type_variants = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsGroup)
                generated_type_variants.append(
                    self.out_variant_op_generator(group, backend_index)
                )
        op_name = op_name_from_group(groups[0])
        body = "\n".join(generated_type_variants)
        return f"""
REGISTER_OPERATOR_FUNCTOR(
aten::{op_name},
aten_{op_name},
[](Node* n) -> SROperator {{
{body}
LogAndDumpSchema(n);
return nullptr;
}});
"""

    def view(
        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
    ) -> str:
        """Return one REGISTER_NATIVE_OPERATOR_FUNCTOR block covering all of `groups`.

        Every group must pass is_supported(). The registration is named via
        config.func_name_base_str() of the first group; an empty `groups`
        yields "".
        """
        if not groups:
            return ""
        generated_type_variants = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsViewGroup)
                generated_type_variants.append(
                    self.view_op_generator(group, backend_index)
                )
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(generated_type_variants)
        return f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::{op_name},
aten_{op_name},
[](Node* n) -> SROperator {{
{body}
LogAndDumpSchema(n);
return nullptr;
}});
"""

    def out_variant_op_generator(
        self, g: NativeFunctionsGroup, backend_index: BackendIndex
    ) -> str:
        """Return the schema-matching `if` block for one out-variant group.

        The emitted lambda calls the functional variant when Output(0) is
        None; otherwise it resizes the existing output tensor to zero and
        calls the out variant. The group must have exactly one out argument.
        """
        functional_schema = g.functional.func
        schema = str(functional_schema)
        populated_argument = generate_arg_extraction(functional_schema)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        out_arguments = g.out.func.arguments.out
        assert len(out_arguments) == 1
        out_variable_name = str(out_arguments[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        return f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
{populated_argument}
if (p_node->Output(0).isNone()) {{
p_node->Output(0) = {functional_variant_call};
return;
}}
auto& {out_variable_name} = p_node->Output(0).toTensor();
fastResizeToZero({out_variable_name});
{out_variant_call};
}};
}}"""

    def view_op_generator(
        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
    ) -> str:
        """Return the schema-matching `if` block for one view group.

        The emitted lambda extracts the arguments and assigns the result of
        the view call to Output(0).
        """
        view_schema = g.view.func
        schema = str(view_schema)
        populated_argument = generate_arg_extraction(view_schema)
        functional_variant_call = generate_call_to_view_ops(g, backend_index)
        return f"""
if (n->matches(torch::schema("aten::{schema}"))) {{
return [](ProcessedNode* p_node) {{
{populated_argument}
p_node->Output(0) = {functional_variant_call};
}};
}}"""
class GenOpTestCase:
    """Renders C++ `TEST(StaticRuntime, autogen_*)` cases for supported op groups."""

    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        """Return the newline-joined test cases for every out-variant group.

        Every group must pass is_supported(); an empty `groups` yields "".
        """
        if not groups:
            return ""
        generated_type_variants = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsGroup)
                generated_type_variants.append(
                    self.out_variant_op_test_case_generator(group)
                )
        return "\n".join(generated_type_variants)

    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
        """Return the newline-joined test cases for every view group.

        Every group must pass is_supported(); an empty `groups` yields "".
        """
        if not groups:
            return ""
        generated_type_variants = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsViewGroup)
                generated_type_variants.append(self.view_op_test_case_generator(group))
        return "\n".join(generated_type_variants)

    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
        """Return one test case for the functional schema of an out-variant group.

        The case calls testStaticRuntime twice: once with the index-0 test
        values only, and once with the index-0 and index-1 values. The
        check_resize flag comes from should_check_resize(). The schema must
        return exactly one Tensor.
        """
        schema = g.functional.func
        schema_str = str(schema)
        paren_at = schema_str.find("(")
        assert paren_at > 0
        # Overload dots are not valid in a C++ test name.
        type_variant_op_name = schema_str[:paren_at].replace(".", "_")
        op_name = op_name_from_group(g)
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        declarations = []
        for arg_name, arg_type in arg_types:
            if arg_type is None:
                declarations.append(arg_name)
            else:
                declarations.append(f"{arg_name}: {arg_type}")
        arg_declarations = ", ".join(declarations)
        arg_names = ", ".join([arg_name for arg_name, _ in arg_types])

        returns = schema.returns
        assert (
            len(returns) == 1
            and isinstance(returns[0].type, BaseType)
            and returns[0].type.name is BaseTy.Tensor
        )

        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        test_value_definitions2 = generate_test_value_definitions(schema, 1)
        test_value_names2 = generate_test_value_names(schema, 1)
        if should_check_resize(schema):
            check_resize = "true"
        else:
            check_resize = "false"
        return f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
graph({arg_declarations}):
%bias: None = prim::Constant()
%ret = aten::{op_name}({arg_names})
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
{test_value_definitions}
std::vector<IValue> args{{{test_value_names}}};
testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
{test_value_definitions2}
std::vector<IValue> args2{{{test_value_names2}}};
testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
}}
"""

    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
        """Return one test case for the view schema of a view group.

        The case calls testStaticRuntime once with the index-0 test values.
        The schema must return exactly one Tensor.
        """
        schema = g.view.func
        schema_str = str(schema)
        paren_at = schema_str.find("(")
        assert paren_at > 0
        # Overload dots are not valid in a C++ test name.
        type_variant_op_name = schema_str[:paren_at].replace(".", "_")
        op_name = g.view.root_name
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        declarations = []
        for arg_name, arg_type in arg_types:
            if arg_type is None:
                declarations.append(arg_name)
            else:
                declarations.append(f"{arg_name}: {arg_type}")
        arg_declarations = ", ".join(declarations)
        arg_names = ", ".join([arg_name for arg_name, _ in arg_types])

        returns = schema.returns
        assert (
            len(returns) == 1
            and isinstance(returns[0].type, BaseType)
            and returns[0].type.name is BaseTy.Tensor
        )

        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        return f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
const std::string script = R"IR(
graph({arg_declarations}):
%bias: None = prim::Constant()
%ret = aten::{op_name}({arg_names})
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
{test_value_definitions}
std::vector<IValue> args{{{test_value_names}}};
testStaticRuntime(script, args);
}}
"""
|