import textwrap
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SchemaKind,
    Type,
)
from torchgen.utils import mapMaybe


def is_tensor(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor


def is_optional_tensor(typ: Type) -> bool:
    return isinstance(typ, OptionalType) and is_tensor(typ.elem)


def is_tensor_list(typ: Type) -> bool:
    return isinstance(typ, ListType) and is_tensor(typ.elem)


def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
    # Emit C++ that unwraps a (possibly) BatchedTensor argument into its
    # underlying value plus an optional batch dimension at the current level.
    result = f"""\
    Tensor {name}_value;
    optional<int64_t> {name}_bdim;
    std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}, {cur_level_var});"""
    return textwrap.dedent(result).split("\n")
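
# For illustration (a sketch, not part of codegen): for a Tensor argument
# hypothetically named "self" with cur_level_var == "cur_level", the lines
# returned above read:
#
#   Tensor self_value;
#   optional<int64_t> self_bdim;
#   std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);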


def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
    # Same as unwrap_tensor, except the unwrap is guarded behind a check
    # that the optional argument is present.
    result = f"""\
    optional<Tensor> {name}_value;
    optional<int64_t> {name}_bdim;
    if ({name}) {{
        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
    }}"""
    return textwrap.dedent(result).split("\n")
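
# Likewise, for a hypothetical optional<Tensor> argument named "weight",
# the emitted C++ is:
#
#   optional<Tensor> weight_value;
#   optional<int64_t> weight_bdim;
#   if (weight) {
#       std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level);
#   }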


def gen_unwraps(
    flat_arguments: Sequence[Argument], cur_level_var: str
) -> Tuple[str, List[str]]:
    # Returns (C++ unwrap code, argument list for the batch_rule call).
    # Every Tensor or optional<Tensor> argument is split into a `{name}_value`
    # and a `{name}_bdim`; all other arguments are passed through unchanged.
    arg_names = [a.name for a in flat_arguments]
    arg_types = [a.type for a in flat_arguments]

    tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
    optional_tensors = [
        name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
    ]

    unwraps = []
    for tensor in tensors:
        unwraps += unwrap_tensor(tensor, cur_level_var)
    for opt_tensor in optional_tensors:
        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
    unwrap_code = "\n".join(unwraps)

    unwrapped_arg_list = []
    for arg in arg_names:
        if arg in tensors or arg in optional_tensors:
            unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
        else:
            unwrapped_arg_list.append(arg)
    return unwrap_code, unwrapped_arg_list
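
# Example (hypothetical schema): for arguments (self, other, alpha) where
# self and other are Tensors and alpha is a Scalar, the returned
# unwrapped_arg_list is:
#
#   ["self_value", "self_bdim", "other_value", "other_bdim", "alpha"]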


def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    # Fast path: if no tensor-like argument is batched at the current vmap
    # level, dispatch straight through to the underlying operator.
    conditions = []
    flat_args = schema.arguments.flat_all
    for arg in flat_args:
        if not arg.type.is_tensor_like():
            continue
        conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ", ".join(
        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
    )
    return f"""\
if ({' && '.join(conditions)}) {{
  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""


def gen_returns(
    returns: Tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
    # The batch rule returns a flat tuple: each Tensor (or TensorList) output
    # arrives as a (value, bdim) pair that must be rewrapped into a
    # BatchedTensor; everything else is passed through as-is.
    idx = 0
    wrapped_returns = []
    for ret in returns:
        if is_tensor(ret.type):
            wrapped_returns.append(
                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        elif is_tensor_list(ret.type):
            wrapped_returns.append(
                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        else:
            wrapped_returns.append(f"std::get<{idx}>({results_var})")
            idx += 1
    if len(wrapped_returns) == 1:
        return f"return {wrapped_returns[0]};"
    return f'return std::make_tuple({", ".join(wrapped_returns)});'
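
# For a single Tensor return this produces (sketch):
#
#   return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
#
# and for, say, two Tensor returns (wrapped here for readability; the real
# output is one line):
#
#   return std::make_tuple(
#       makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
#       makeBatched(std::get<2>(results), std::get<3>(results), cur_level));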


def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)


def is_mutated_arg(argument: Argument) -> bool:
    return argument.annotation is not None and argument.annotation.is_write


def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
    # Assumptions:
    # - exactly one argument is modified in-place
    # - the argument that is modified in-place is the first argument
    # - all returns are either Tensor, tuple of Tensor, or TensorList
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If any of them fail, return None and punt handling
    # that case to the future.
    assert schema.kind() == SchemaKind.inplace
    if not is_mutated_arg(schema.arguments.flat_all[0]):
        return None
    if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"
    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {schema.arguments.flat_all[0].name};
}}"""


def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = "cur_level"
    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
}}"""


def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support mutable, out, or scratch variants.
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"
    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
    wrapped_returns = gen_returns(returns, cur_level_var, results_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""


@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        return gen_vmap_plumbing(f)


def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    # Concatenate the plumbing for every supported operator into one header;
    # mapMaybe drops the operators for which gen_vmap_plumbing returned None.
    body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
    return f"""
#pragma once

#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""