123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- from typing import List, Optional, Sequence, Set, Union
- from torchgen import local
- from torchgen.api.types import (
- ArgName,
- ArrayCType,
- ArrayRefCType,
- BaseCType,
- BaseTypeToCppMapping,
- Binding,
- boolT,
- ConstRefCType,
- CType,
- dimnameListT,
- intArrayRefT,
- iTensorListRefT,
- ListCType,
- longT,
- MutRefCType,
- NamedCType,
- OptionalCType,
- optionalIntArrayRefT,
- optionalSymIntArrayRefT,
- scalarT,
- SpecialArgName,
- symIntArrayRefT,
- SymIntT,
- tensorListT,
- tensorOptionsT,
- tensorT,
- TupleCType,
- VectorCType,
- voidT,
- )
- from torchgen.model import (
- Argument,
- Arguments,
- BaseTy,
- BaseType,
- FunctionSchema,
- ListType,
- NativeFunction,
- OptionalType,
- Return,
- SelfArgument,
- TensorOptionsArguments,
- Type,
- )
- from torchgen.utils import assert_never
- # This file describes the translation of JIT schema to the public C++
- # API, which is what people use when they call functions like at::add.
- #
- # Prominent characteristics of the C++ API:
- #
- # - dtype, layout, device and pin_memory are collected into
- # a single C++ type TensorOptions (the native functions API
- # also has this, but tensor options is really most relevant
- # for the C++ API; it makes calling kwarg factory functions
- # pleasant)
- #
- # - defaulting lives here (in fact, the dispatcher is completely
- # oblivious of defaults!)
- #
- # BTW: policy on name collisions: we try not to have types with
- # collisions, but functions are fair game to collide
def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Return the C++ API function name for ``func``.

    The base operator name is extended with ``_symint`` when
    ``symint_overload`` is set. Out= overloads are then suffixed with
    ``_outf`` when ``faithful_name_for_out_overloads`` is set, or ``_out``
    otherwise.
    """
    suffixes = []
    if symint_overload:
        suffixes.append("_symint")
    if func.is_out_fn():
        suffixes.append("_outf" if faithful_name_for_out_overloads else "_out")
    return str(func.name.name) + "".join(suffixes)
- # Translation of "value types" in JIT schema to C++ API type. Value
- # types look the same no matter if they are argument types or return
- # types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> Optional[NamedCType]:
    """Translate a JIT "value type" into its C++ type, bound to ``binds``.

    Value types translate identically whether they appear as arguments or
    as returns. ``None`` is returned when ``t`` is not a value type: Tensor,
    Scalar, optionals of non-value types, and lists other than fixed-size
    bool lists. The caller is then left to translate ``t`` itself.

    Raises:
        AssertionError: for ``str`` when ``remove_non_owning_ref_types`` is
            set (not implemented yet), for a bool list without a size, and
            for unrecognized type kinds.
    """
    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            # Not value types; argument/return translation handles them.
            return None
        if str(t) == "SymInt":
            return NamedCType(binds, BaseCType(SymIntT if symint else longT))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError(
                "string ref->value conversion: not implemented yet"
            )
        # Every remaining base type maps directly onto a C++ base type.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        # NOTE(review): remove_non_owning_ref_types is not forwarded to the
        # element, so an optional str is not rejected the way a bare str is.
        inner = valuetype_type(t.elem, binds=binds, symint=symint)
        if inner is None:
            return None
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # Only bool lists are value types, and they must carry a fixed size.
        if str(t.elem) != "bool":
            return None
        assert t.size is not None
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")
- # Translation of types occurring in JIT arguments to a C++ argument type.
- # If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
- # For example, we'll return std::vector<int> instead of IntArrayRef.
- # See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate the JIT type of an argument into its C++ argument type.

    Value types are delegated to ``valuetype_type``. Tensors become a
    mutable reference (``MutRefCType``) when ``mutable`` is set and the
    local codegen setting does not force const refs; otherwise they become
    a ``ConstRefCType``. Scalars are passed by const reference. Lists mostly
    map to ArrayRef-style types. With ``remove_non_owning_ref_types`` set,
    int/SymInt lists become ``VectorCType`` instead of ArrayRefs.

    Raises:
        AssertionError: for types that have no C++ argument translation.
    """
    as_value = valuetype_type(
        t,
        binds=binds,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if as_value is not None:
        return as_value

    if isinstance(t, BaseType):
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        if t.name != BaseTy.Tensor:
            raise AssertionError(f"base type should have been value type {t}")
        # The local setting is only consulted for mutable tensors.
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))

    if isinstance(t, OptionalType):
        inner = t.elem
        inner_name = str(inner)
        if inner_name == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                # TODO: fix this discrepancy (a mutable optional Tensor is
                # bound as a plain mutable Tensor reference, not an optional)
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
            )
        if inner_name == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        if isinstance(inner, ListType):
            list_elem_name = str(inner.elem)
            if list_elem_name == "SymInt" and symint:
                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
            if list_elem_name in ("int", "SymInt"):
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Everything else: wrap the element's argument type in an optional.
        # NOTE(review): remove_non_owning_ref_types is not forwarded here.
        inner_ctype = argumenttype_type(
            inner, mutable=mutable, binds=binds, symint=symint
        ).type
        return NamedCType(binds, OptionalCType(inner_ctype))

    if isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        elem_name = str(t.elem)
        if elem_name == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            return NamedCType(binds, BaseCType(intArrayRefT))
        if elem_name == "SymInt":
            if remove_non_owning_ref_types:
                owned_elem = SymIntT if symint else longT
                return NamedCType(binds, VectorCType(BaseCType(owned_elem)))
            view_type = symIntArrayRefT if symint else intArrayRefT
            return NamedCType(binds, BaseCType(view_type))
        if elem_name == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            return NamedCType(binds, BaseCType(tensorListT))
        if elem_name == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if elem_name == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        if elem_name == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        # Generic fallthrough: an ArrayRef of the element's argument type.
        elem_ctype = argumenttype_type(
            t.elem, mutable=mutable, binds=binds, symint=symint
        ).type
        return NamedCType(binds, ArrayRefCType(elem_ctype))

    raise AssertionError(f"unrecognized type {repr(t)}")
- # Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    """Translate a JIT argument into its C++ type bound to ``binds``.

    Whether the argument is treated as mutable follows ``a.is_write``.
    """
    return argumenttype_type(
        a.type,
        mutable=a.is_write,
        binds=binds,
        symint=symint,
    )
- # Translation of a (non-multi) return type from JIT to C++
- # N.B: returntype_type returns a CType, not a NamedCType.
- # This is mostly because of the mismatch between return types and return names.
- # e.g. a function with a return type of 'void' has 0 return names,
- # and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    """Translate a single (non-multi) JIT return type into a C++ CType.

    A bare CType is returned, not a NamedCType. Value types are delegated
    to ``valuetype_type``. A Tensor return is a by-value ``tensorT`` unless
    ``mutable`` is set. In that case it is a const or mutable reference,
    depending on the local codegen setting. Lists become vectors of the
    element return type.

    Raises:
        AssertionError: for mutable list returns, fixed-size list returns,
            and any other unrecognized return type.
    """
    # The binding name is irrelevant for returns; only the CType is kept.
    as_value = valuetype_type(t, binds="__placeholder__", symint=symint)
    if as_value is not None:
        return as_value.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
        if t.name == BaseTy.Tensor:
            if not mutable:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        element_ctype = returntype_type(t.elem, mutable=False, symint=symint)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(element_ctype)

    raise AssertionError(f"unrecognized return type {t}")
- # Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    """Translate a single JIT return into its C++ type.

    Whether a Tensor return is a reference or a copy follows ``r.is_write``.
    """
    return returntype_type(r.type, symint=symint, mutable=r.is_write)
- # Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    """Translate a full (possibly multi) return list into one C++ type.

    No returns map to ``void``, and a single return maps to its own type.
    Several returns map to a tuple of their types.
    """
    if not rs:
        return BaseCType(voidT)
    if len(rs) == 1:
        return return_type(rs[0], symint=symint)
    return TupleCType([return_type(each, symint=symint) for each in rs])
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Return the C++ name for each of ``f``'s returns.

    Naming rules, in priority order:

    * in-place functions: the (single) return is implicitly named ``self``;
    * out= functions: each return takes the name of the corresponding out
      argument (``r.name`` gets recorded in field_name later);
    * explicitly named returns keep their name, with ``_return`` appended
      when it collides with the name of any schema argument;
    * anything else is named ``fallback_name`` (default ``"result"``). For
      multi-return functions a zero-based index is appended
      (``result0``, ``result1``, ...).

    Raises:
        AssertionError: for an in-place function with more than one return.
    """
    func = f.func
    returns = func.returns
    is_inplace = func.name.name.inplace
    is_out = func.is_out_fn()
    # Only needed to detect collisions with explicitly named returns; build
    # it once instead of rescanning every argument for every return.
    arg_names: Set[str] = {a.name for a in func.schema_order_arguments()}

    names: List[str] = []
    for i, r in enumerate(returns):
        if is_inplace:
            # TODO: Consider incorporating this into the data model
            assert i == 0, "illegal inplace function with multiple returns"
            names.append("self")
        elif is_out:
            names.append(func.arguments.out[i].name)
        elif r.name:
            # Out functions never reach here, so only argument collisions matter.
            names.append(f"{r.name}_return" if r.name in arg_names else r.name)
        elif len(returns) == 1:
            names.append(fallback_name)
        else:
            names.append(f"{fallback_name}{i}")
    return names
# Literal JIT schema defaults and their C++ spellings. default_expr falls
# back to this table and emits any default not listed here verbatim.
JIT_TO_CPP_DEFAULT = {
    # Python-style booleans become C++ booleans.
    "False": "false",
    "True": "true",
    # UGH this one is type directed: default_expr special-cases some
    # types (e.g. optional Tensor) before consulting this table.
    "None": "c10::nullopt",
    "Mean": "at::Reduction::Mean",
    # An empty list default becomes an empty brace initializer.
    "[]": "{}",
    "contiguous_format": "MemoryFormat::Contiguous",
    "long": "at::kLong",
}
- # Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Convert the JIT schema default ``d`` of an argument typed ``t`` into
    a C++ expression.

    Optional defaults of ``None`` become ``c10::nullopt`` (or ``{}`` for
    ``Tensor?``). Single-quoted strings are re-quoted with double quotes.
    ``[...]`` list literals become brace initializers. Anything else goes
    through ``JIT_TO_CPP_DEFAULT`` and is emitted verbatim if not listed.

    Raises:
        ValueError: when an unsized list argument has a default that is not
            a ``[...]`` literal (and not the digit-only SymInt case).
    """

    def requote(literal: str) -> str:
        # Turn a single-quoted schema string into a C++ double-quoted one:
        # escape bare `"`, drop the backslash from `\'`, and copy any other
        # backslash escape through unchanged.
        pieces = []
        pos, stop = 1, len(literal) - 1
        while pos < stop:
            ch = literal[pos]
            if ch == "\\":
                follower = literal[pos + 1]
                pieces.append("'" if follower == "'" else literal[pos : pos + 2])
                pos += 2
            else:
                pieces.append('\\"' if ch == '"' else ch)
                pos += 1
        return '"' + "".join(pieces) + '"'

    if d == "None" and str(t) == "Tensor?":
        return "{}"

    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            return requote(d)

    if isinstance(t, OptionalType):
        if d == "None":
            return "c10::nullopt"
        return default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        if symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
- # Convert an argument into its C++ API form
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> List[Binding]:
    """Translate one schema argument into its C++ API bindings.

    The result is a list for several reasons. A TensorOptionsArguments
    expands to four bindings in faithful mode and to a single ``options``
    binding otherwise. A SelfArgument of a method produces no binding,
    because the caller supplies the implicit ``this``.
    """

    def sub_argument(
        inner: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        # Recurse with exactly the same translation settings.
        return argument(
            inner,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        # Alongside TensorOptions, memory_format is bound under the special
        # SpecialArgName.possibly_redundant_memory_format name.
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        cpp_default: Optional[str] = None
        if a.default is not None and a.name not in cpp_no_default_args:
            cpp_default = default_expr(a.default, a.type, symint=symint)
        nctype = argument_type(a, binds=binds, symint=symint)
        return [Binding(nctype=nctype, name=a.name, default=cpp_default, argument=a)]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            # Faithful signatures spell out each option as its own argument.
            return [
                binding
                for field in (a.dtype, a.layout, a.device, a.pin_memory)
                for binding in sub_argument(field)
            ]
        # Enforced by NativeFunction.__post_init__
        assert "options" not in cpp_no_default_args
        options_default: Optional[str] = None
        if all(opt.default == "None" for opt in a.all()):
            options_default = "{}"
        elif a.dtype.default == "long":
            options_default = "at::kLong"  # TODO: this is wrong
        return [
            Binding(
                nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                name="options",
                default=options_default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Caller is responsible for installing implicit this in context!
        return [] if method else sub_argument(a.argument)

    assert_never(a)
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: Set[str],
) -> List[Binding]:
    """Translate a full schema argument list into C++ API bindings.

    Faithful signatures list non-out arguments before out arguments, and
    every binding passes through ``Binding.no_default()``. Non-faithful
    signatures put out arguments first and keep their defaults.
    """
    if faithful:
        ordered = [*arguments.non_out, *arguments.out]
    else:
        ordered = [*arguments.out, *arguments.non_out]

    has_options = arguments.tensor_options is not None
    bindings: List[Binding] = []
    for arg in ordered:
        for binding in argument(
            arg,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=has_options,
            cpp_no_default_args=cpp_no_default_args,
        ):
            bindings.append(binding.no_default() if faithful else binding)
    return bindings
|