- from dataclasses import dataclass
- from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
- from torchgen.api import cpp
- from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
- from torchgen.gen import pythonify_default
- from torchgen.model import (
- Argument,
- BaseTy,
- BaseType,
- FunctionSchema,
- ListType,
- NativeFunction,
- OptionalType,
- Return,
- Type,
- Variant,
- )
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Data Models
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # [Notes] python binding codegen
- #
- # The Python binding codegen produces code that takes the input list of
- # PyObjects, finds the matching ATen C++ function using PythonArgParser,
- # converts the PyObjects into C++ types and calls the ATen C++ function:
- #
- # +--------+  parsing   +------------------------+  binding   +-----------------------+
- # | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
- # +--------+            +------------------------+            +-----------------------+
- #
- # The following examples demonstrate the data models the Python binding
- # codegen needs to deal with and the tasks it needs to accomplish. They
- # help explain the purpose of the new data types introduced below.
- #
- # - Function Schema (source of truth)
- #
- # aten::empty.names(int[] size, *, Dimname[]? names,
- # ScalarType? dtype=None, Layout? layout=None,
- # Device? device=None, bool? pin_memory=None,
- # MemoryFormat? memory_format=None) -> Tensor
- #
- # - Python Signature
- #
- # It's used to generate the input schema string for PythonArgParser.
- # Note: TensorOptions fields are reordered and the additional
- # 'requires_grad' field is added:
- #
- # empty(IntArrayRef size, *, DimnameList? names,
- # MemoryFormat? memory_format=None, ScalarType dtype=None,
- # Layout layout=torch.strided, Device device=None,
- # bool pin_memory=False, bool requires_grad=False)
- #
- # - C++ Signature
- #
- # It's used to generate the C++ lambda formals & the dispatch call.
- # Note: the scattered TensorOptions fields are packed into 'options'.
- #
- # auto dispatch_empty =
- # [](IntArrayRef size, c10::optional<DimnameList> names,
- # const TensorOptions & options,
- # c10::optional<MemoryFormat> memory_format) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return torch::empty(size, names, options, memory_format);
- # };
- #
- # - Binding between Python Arguments and C++ Arguments
- #
- # Given a set of Python Arguments in scope, we need to produce the
- # binding expressions that translate the Python API into the C++ API:
- #
- # Python Args                  Cpp Args         Binding Exprs
- # -----------------------------------------------------------------
- # 0: size                      size             '_r.intlist(0)'
- # 1: names                     names            'names' [special init]
- # 2: memory_format -------+
- # 3: dtype         -----+-|--> options          'options' [special packing]
- # 4: layout            /  |
- # 5: device           /   +--> memory_format    '_r.memoryformatOptional(2)'
- # 6: pin_memory      /
- # 7: requires_grad -+
- #
- # So the full dispatch expression would look like:
- #
- # dispatch_empty(_r.intlist(0), names, options,
- # _r.memoryformatOptional(2))
- #
- # Where does 'names' come from? It involves a special local init:
- #
- # auto __names = _r.toDimnameListOptional(1);
- # c10::optional<DimnameList> names =
- # __names ? c10::make_optional(DimnameList(__names.value()))
- # : c10::nullopt;
- #
- # Where does 'options' come from? It involves a special local init
- # for TensorOptions. Note that the Python side has the additional
- # 'requires_grad' field:
- #
- # const auto options = TensorOptions()
- # .dtype(_r.scalartype(3))
- # .device(_r.device(5))
- # .layout(_r.layoutOptional(4))
- # .requires_grad(_r.toBool(7))
- # .pinned_memory(_r.toBool(6));
- #
- # In some other cases one Python Argument can map to multiple C++
- # Arguments. For example:
- #
- # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
- # -> (Tensor values, Tensor indices)
- #
- # Python Args                      Cpp Args           Binding Exprs
- # ------------------------------------------------------------------------
- #                           +----> max                 'out[0]'
- #                          /-----> max_values          'out[1]'
- # 0: input                /        self                '_r.tensor(0)'
- # 1: dim                 /         dim                 '_r.dimname(1)'
- # 2: keepdim            /          keepdim             '_r.toBool(2)'
- # 3: out    -----------+           [local init] out    '_r.tensorlist_n<2>(3)'
- #
- # As demonstrated above, the binding can involve reordering,
- # packing, unpacking and special local inits.
- #
- #
- # Let's look at a concrete example:
- #
- # static PythonArgParser parser({
- # "abs(Tensor input, *, Tensor out=None)",
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- Python Schema, represented by PythonSignature and PythonArgument
- #
- # }, /*traceable=*/true);
- #
- # ParsedArgs<2> parsed_args;
- # auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
- #
- # ...
- #
- # if (_r.isNone(1)) {
- # ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
- # represented by PythonArgParserOutputExpr
- #
- # // aten::abs(Tensor self) -> Tensor
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- NativeFunction schema, base version
- #
- # auto dispatch_abs = [](const Tensor & self) -> Tensor {
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- dispatch_lambda_args / dispatch_lambda_return_str
- # generated from NativeFunction / CppSignature
- # (deprecated PythonSignature is special)
- # arguments are represented by DispatchLambdaArgument
- #
- # pybind11::gil_scoped_release no_gil;
- # return self.abs();
- # ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
- # generated from NativeFunction / CppSignature
- # };
- # return wrap(dispatch_abs(_r.tensor(0)));
- # ~~~~~~~~~~~~~
- # ^
- # +--- dispatch_lambda_exprs
- # binding PythonArgParserOutputExpr (python args)
- # and DispatchLambdaArgument (c++ args)
- #
- # } else {
- # // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- NativeFunction schema, out-variant
- #
- # auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return at::abs_out(out, self);
- # };
- # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
- # }
- #
- #
- # [Notes] python interface codegen
- # The python dataclasses below are used to generate both python binding code
- # and pyi type hint signatures.
- # In theory these two should look very similar, but there are a number of differences
- # in how pyi signatures vs. python_arg_parser signatures are generated.
- # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
- # to display the full signatures, and argument_str() vs. argument_str_pyi() to display arguments.
- # For example, only pyi signatures include return types.
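- #
- # As a rough illustration (hand-written here, not verbatim codegen output),
- # the 'abs' overload from the notes above might render as:
- #
- #   signature_str():     "abs(Tensor input, *, Tensor out=None)"
- #   signature_str_pyi(): "def abs(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ..."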
- @dataclass(frozen=True)
- class PythonReturns:
- returns: Tuple[Return, ...]
- @dataclass(frozen=True)
- class PythonArgument:
- name: str
- type: Type
- default: Optional[str]
- # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
- #
- # _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- default_init str
- default_init: Optional[str]
- # Compute argument formal for python argument parsing.
- # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
- def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
- type_str = (
- argument_type_str(self.type, symint=symint)
- .replace("const ", "")
- .replace(" &", "")
- )
- name = self.name
- # s/self/input/ outside method bindings
- # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
- # for the parse string
- if name == "self" and type_str in ["Tensor", "Number"] and not method:
- name = "input"
- # add default
- if self.default is not None:
- default = {
- "nullptr": "None",
- "c10::nullopt": "None",
- "{}": "None",
- }.get(self.default, self.default)
- return f"{type_str} {name}={default}"
- else:
- return f"{type_str} {name}"
- def argument_str_pyi(
- self, *, method: bool = False, deprecated: bool = False
- ) -> str:
- type_str = argument_type_str_pyi(self.type)
- name = self.name
- # s/self/input/ outside method bindings
- # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
- # for the parse string
- if name == "self" and type_str == "Tensor" and not method and not deprecated:
- name = "input"
- if name == "from": # from is a Python keyword...
- name += "_"
- # pyi merges the _out and functional variants into the same signature, with an optional out arg
- if name == "out" and type_str == "Tensor" and not deprecated:
- type_str = "Optional[" + type_str + "]"
- # pyi deprecated signatures don't get defaults for their out arg
- treat_as_no_default = (
- deprecated
- and isinstance(self, PythonOutArgument)
- and self.default == "None"
- )
- # add default
- if self.default is not None and not treat_as_no_default:
- if (
- isinstance(self.type, ListType)
- and self.type.elem == BaseType(BaseTy.int)
- and self.default.startswith("{")
- and self.default.endswith("}")
- ):
- default = "(" + self.default[1:-1] + ")"
- else:
- default = {
- "nullptr": "None",
- "c10::nullopt": "None",
- "{}": "None",
- "MemoryFormat::Contiguous": "contiguous_format",
- "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
- }.get(self.default, self.default)
- return f"{name}: {type_str}={default}"
- else:
- return f"{name}: {type_str}"
- @dataclass(frozen=True)
- class PythonOutArgument(PythonArgument):
- # In the Python signature, multiple output fields are packed into one 'out' argument.
- # When binding to C++, it's first bound to a local 'out' variable:
- # 'auto out = _r.tensorlist_n<2>(2);',
- # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', etc.
- # TODO: maybe we don't need to keep the scattered out fields for the python signature?
- outputs: Tuple[PythonArgument, ...]
- @staticmethod
- def from_outputs(
- outputs: Tuple[PythonArgument, ...]
- ) -> Optional["PythonOutArgument"]:
- if not outputs:
- return None
- size = len(outputs)
- if size == 1:
- return PythonOutArgument(
- name=outputs[0].name,
- type=outputs[0].type,
- default="None",
- default_init=None,
- outputs=outputs,
- )
- elif size > 1:
- if any(map(lambda a: not a.type.is_tensor_like(), outputs)):
- raise RuntimeError(f"Unsupported output type: {outputs}")
- return PythonOutArgument(
- name="out",
- # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
- type=ListType(BaseType(BaseTy.Tensor), size),
- default="None",
- default_init=None,
- outputs=outputs,
- )
- raise AssertionError(r"Unexpected PythonOutArgument size")
- @dataclass(frozen=True)
- class PythonSignature:
- # Base operator name, without inplace/outplace suffix.
- name: str
- # Positional arguments.
- # TODO: create a dedicated SelfArgument type for 'self'?
- input_args: Tuple[PythonArgument, ...]
- # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
- # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
- input_kwargs: Tuple[PythonArgument, ...]
- output_args: Optional[PythonOutArgument]
- # Return types, which are only used by pyi
- returns: PythonReturns
- # These are scattered kwargs arguments belonging to TensorOptions.
- # When binding to C++, they are packed into a TensorOptions object 'options'.
- # It's possible that the C++ signature doesn't take a TensorOptions object (e.g.
- # for the out variant), in which case they will be used as scattered fields without
- # being packed into 'options'.
- # TODO: maybe create a PythonTensorOptionsArgument?
- tensor_options_args: Tuple[PythonArgument, ...]
- # method or function signature?
- method: bool
- @property
- def deprecated(self) -> bool:
- return False
- def arguments(
- self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
- ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
- result: List[Union[PythonArgument, PythonOutArgument]] = []
- result.extend(self.input_args)
- result.extend(self.input_kwargs)
- if self.output_args is not None and not skip_outputs:
- result.append(self.output_args)
- if not skip_tensor_options:
- result.extend(self.tensor_options_args)
- return tuple(result)
- def arguments_count(self) -> int:
- return len(self.arguments())
- def output_idx(self) -> int:
- return len(self.input_args) + len(self.input_kwargs)
- # [old codegen] Compute the Python function signature for argument parsing,
- # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
- # this is NOT the same type signature as specified by PEP 484
- # as understood by mypy; our format was independently developed
- # and has some quirks to make it more suitable specifically
- # for error parsing.
- #
- # For a translation to mypy-valid type signatures, see
- # signature_str_pyi().
- def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
- args = self.arguments(skip_outputs=skip_outputs)
- schema_formals: List[str] = list(
- map(lambda a: a.argument_str(method=self.method, symint=symint), args)
- )
- positional_argc = len(self.input_args)
- if len(schema_formals) > positional_argc:
- schema_formals.insert(positional_argc, "*")
- return f'{self.name}({", ".join(schema_formals)})'
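- # E.g. for the 'empty.names' overload in the notes at the top of this file,
- # this yields (roughly) the parser string shown there:
- #   "empty(IntArrayRef size, *, DimnameList? names, MemoryFormat? memory_format=None, ...)"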
- def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
- args = self.arguments(skip_outputs=skip_outputs)
- schema_formals: List[str] = list(
- map(lambda a: a.argument_str_pyi(method=self.method), args)
- )
- positional_argc = len(self.input_args)
- if len(schema_formals) > positional_argc:
- schema_formals.insert(positional_argc, "*")
- # only pyi signatures include returns
- returns_str = returns_str_pyi(self)
- # pyi also includes self (with no typing/defaults) for methods
- if self.method:
- schema_formals.insert(0, "self")
- return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
- def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
- # only pyi uses vararg signatures
- args = self.arguments(skip_outputs=skip_outputs)
- schema_formals: List[str] = list(
- map(lambda a: a.argument_str_pyi(method=self.method), args)
- )
- # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
- num_args = self.arguments_count()
- num_positionalargs = len(self.input_args)
- have_vararg_version = False
- if num_args > 0:
- vararg_type = args[0].type
- if (
- isinstance(vararg_type, ListType)
- and str(vararg_type.elem) in ["int", "SymInt"]
- and num_positionalargs == 1
- ):
- have_vararg_version = True
- if not have_vararg_version:
- return None
- # Below are the major changes in vararg vs. regular pyi signatures
- # vararg signatures also omit the asterisk
- schema_formals[0] = "*" + args[0].name + ": _int"
- returns_str = returns_str_pyi(self)
- # pyi also includes self (with no typing/defaults) for methods
- if self.method:
- schema_formals.insert(0, "self")
- return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
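- # As a rough illustration (hypothetical, for an op like 'empty' above whose
- # only positional arg is 'int[] size'), the vararg variant looks like:
- #
- #   def empty(*size: _int, memory_format: Optional[memory_format]=None, ...) -> Tensor: ...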
- # The deprecated python signature involves some special logic, so create a
- # dedicated data model to store these extra properties.
- @dataclass(frozen=True)
- class PythonSignatureDeprecated(PythonSignature):
- # Schema for the deprecated function
- deprecated_schema: FunctionSchema
- # The deprecated signature might miss some arguments that the corresponding
- # C++ signature expects. We need to store the constant default values to pass in.
- # For example:
- # [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
- # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
- # [func call]: self.addmm(mat1, mat2, beta, 1)
- # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
- deprecated_args_exprs: Tuple[str, ...]
- @property
- def deprecated(self) -> bool:
- return True
- def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
- return (
- PythonSignature.signature_str(
- self, skip_outputs=skip_outputs, symint=symint
- )
- + "|deprecated"
- )
- def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
- args = self.arguments(skip_outputs=skip_outputs)
- schema_formals: List[str] = list(
- map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)
- )
- positional_argc = len(self.input_args)
- if len(schema_formals) > positional_argc:
- schema_formals.insert(positional_argc, "*")
- returns_str = returns_str_pyi(self)
- return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
- def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
- # the codegen doesn't include vararg variants for deprecated signatures
- return None
- # This struct is used to hold the PythonSignature and its corresponding
- # NativeFunction BEFORE grouping base and out-variant functions.
- # Why not store NativeFunction in PythonSignature or construct PythonSignature
- # from NativeFunction? Because they are not 1-1 mapped.
- # One native function could have both deprecated and non-deprecated python
- # signatures - NativeFunction doesn't contain information to construct the
- # deprecated python signature.
- # One python signature is used to handle both the base and the out-variant
- # function - see 'PythonSignatureGroup'.
- @dataclass(frozen=True)
- class PythonSignatureNativeFunctionPair:
- signature: PythonSignature
- function: NativeFunction
- # We merge pairs of functions with signatures that are equivalent mod
- # output arguments, and use a single entry in the python_arg_parser sig
- # list for both (output arguments become optional).
- @dataclass(frozen=True)
- class PythonSignatureGroup:
- # The signature used for Python argument parsing. The outplace signature
- # is preferred if it exists, because it can be used to parse inputs for both
- # the out-place variant and the base version (with output omitted).
- signature: PythonSignature
- # The regular ATen declaration (e.g. conv2d)
- base: NativeFunction
- # The out variant (e.g. conv2d_out)
- outplace: Optional[NativeFunction]
- @classmethod
- def from_pairs(
- cls,
- functional: PythonSignatureNativeFunctionPair,
- out: Optional[PythonSignatureNativeFunctionPair],
- ) -> "PythonSignatureGroup":
- if out is None:
- return PythonSignatureGroup(
- signature=functional.signature,
- base=functional.function,
- outplace=None,
- )
- # prefer the signature with optional out=... arguments because it's the
- # superset that can be used to parse input for both base and outplace.
- signature_kwargs = out.signature.__dict__.copy()
- # Out overloads in C++ don't have TensorOptions arguments,
- # so take these from the functional variant
- signature_kwargs[
- "tensor_options_args"
- ] = functional.signature.tensor_options_args
- return PythonSignatureGroup(
- signature=type(out.signature)(**signature_kwargs),
- base=functional.function,
- outplace=out.function,
- )
- # C++ function dispatch is wrapped in a lambda function. The lambda function
- # has almost the same signature as the C++ function, with only some small
- # variations - see details below.
- # This data model is used to represent arguments of the lambda function
- # signature.
- @dataclass(frozen=True)
- class DispatchLambdaArgument:
- name: str
- type_str: str
- is_out_arg: bool
- # To pass PyObject arguments to the C++ function (via the lambda wrapper),
- # we first need to convert the PyObjects into simple C++ objects. This work
- # is done by PythonArgParser.
- # This data model is used to represent the output of PythonArgParser.
- # It has 1-1 mapping with PythonArgument in PythonSignature.
- @dataclass(frozen=True)
- class PythonArgParserOutputExpr:
- # argument name
- name: str
- # RHS expression to reference PythonArgParser output.
- expr: str
- # In some special cases we need to create a different expr, e.g.:
- # '_r.isNone(1)' instead of '_r.tensor(1)'.
- index: int
- # The python argument it maps to.
- argument: PythonArgument
- @property
- def is_none_expr(self) -> str:
- return f"_r.isNone({self.index})"
- # To pass PythonArgParser output to the lambda wrapper, we need to bind
- # PythonArgParserOutputExpr to DispatchLambdaArgument.
- # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
- # need to be packed into a TensorOptions object, which is the argument
- # that the lambda function wrapper takes.
- @dataclass(frozen=True)
- class DispatchLambdaArgumentExprs:
- # The exprs that provide the binding for lambda arguments, e.g.:
- #
- # 'self' -> '_r.tensor(0)'
- # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
- # 'options' -> 'options'
- #
- # It has 1-1 mapping with DispatchLambdaArgument.
- exprs: Sequence[str]
- # Special local inits, which might introduce new variables that
- # the 'exprs' above reference, e.g.:
- #
- # 'auto out = _r.tensorlist_n<2>(2);'
- #
- inits: Sequence[str]
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Helper Functions
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
- return CppSignatureGroup.from_native_function(f, method=method).signature
- def has_tensor_options(f: NativeFunction) -> bool:
- return f.func.arguments.tensor_options is not None
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python Signature
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # 'simple_type' was introduced by the old codegen. It is slightly different
- # from the python schema type, e.g.: it has no '?' suffix for optional
- # Tensor/TensorList and no '[size]' suffix for list types.
- def argument_type_str(
- t: Type, *, simple_type: bool = False, symint: bool = True
- ) -> str:
- if isinstance(t, BaseType):
- if t.name == BaseTy.Tensor:
- return "Tensor"
- elif t.name == BaseTy.int:
- return "int64_t"
- elif t.name == BaseTy.float:
- return "double"
- elif t.name == BaseTy.str:
- return "c10::string_view"
- elif t.name in [
- BaseTy.bool,
- BaseTy.QScheme,
- BaseTy.Scalar,
- BaseTy.ScalarType,
- BaseTy.Generator,
- BaseTy.Storage,
- BaseTy.Layout,
- BaseTy.Device,
- BaseTy.MemoryFormat,
- BaseTy.Dimname,
- BaseTy.Stream,
- BaseTy.ConstQuantizerPtr,
- BaseTy.SymInt,
- ]:
- # These python schema type names line up with their function schema names
- return t.name.name
- elif isinstance(t, OptionalType):
- if str(t.elem) == "Tensor":
- # Is it desired to keep '?' for simple_type with new style dispatcher?
- return "Tensor?"
- elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
- return f"{elem}?"
- elif isinstance(t, ListType):
- size = t.size if not simple_type else None
- if str(t.elem) == "bool":
- assert t.size is not None
- return f"::std::array<bool,{t.size}>"
- elif str(t.elem) == "int":
- return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
- elif str(t.elem) == "SymInt":
- if symint:
- return (
- f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
- )
- else:
- return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
- elif str(t.elem) == "Tensor":
- return f"TensorList[{size}]" if size is not None else "TensorList"
- elif str(t.elem) == "Scalar":
- return f"ScalarList[{size}]" if size is not None else "ScalarList"
- elif str(t.elem) == "Tensor?":
- if simple_type:
- return "c10::List<c10::optional<Tensor>>"
- else:
- return "const c10::List<c10::optional<Tensor>> &"
- elif str(t.elem) == "Dimname":
- return f"DimnameList[{size}]" if size is not None else "DimnameList"
- elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
- return f"ArrayRef<{elem}>"
- raise RuntimeError(f"unrecognized type {repr(t)}")
- def argument_type_size(t: Type) -> Optional[int]:
- l = t.is_list_like()
- if l is not None and str(l.elem) != "bool":
- return l.size
- else:
- return None
- def argument(a: Argument) -> PythonArgument:
- return PythonArgument(
- name=a.name,
- type=a.type,
- # TODO: directly translate a.default to python default
- default=str(
- pythonify_default(cpp.default_expr(a.default, a.type, symint=False))
- )
- if a.default is not None
- else None,
- default_init=None,
- )
- # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
- def signature(
- f: NativeFunction, *, method: bool = False, pyi: bool = False
- ) -> PythonSignature:
- return signature_from_schema(
- f.func, category_override=f.category_override, method=method, pyi=pyi
- )
- def signature_from_schema(
- func: FunctionSchema,
- *,
- category_override: Optional[str],
- method: bool = False,
- pyi: bool = False,
- ) -> PythonSignature:
- args: List[Argument] = []
- args.extend(func.arguments.pre_self_positional)
- # Skip SelfArgument if this is a method.
- if not method and func.arguments.self_arg is not None:
- args.append(func.arguments.self_arg.argument)
- args.extend(func.arguments.post_self_positional)
- args.extend(func.arguments.pre_tensor_options_kwarg_only)
- # Skip TensorOptionsArguments. Python side TensorOptions
- # arguments are created based on different rules - see below.
- args.extend(func.arguments.post_tensor_options_kwarg_only)
- args.extend(func.arguments.out)
- input_arg_set = {a.name for a in func.arguments.flat_positional}
- kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
- out_arg_set = {a.name for a in func.arguments.out}
- input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
- input_kwargs = tuple(
- map(argument, filter(lambda a: a.name in kwarg_only_set, args))
- )
- outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
- # Reintroduce the scattered fields of TensorOptions for Python.
- # Compared to the cpp counterpart, the python arguments have a new property
- # (default_init) and a new argument 'requires_grad', which require some
- # special handling.
- # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
- # to the original versions in the yaml, this recreation is a potential
- # source of drift between eager and JIT. Pull this logic out to a shared place.
- has_tensor_input_arg = any(
- a.type.is_tensor_like() for a in func.arguments.flat_non_out
- )
- if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
- raise ValueError(
- "argument named requires_grad is reserved, should not explicitly add it in the schema"
- )
- # [old codegen] this probably won't work if one of the returns is not a tensor,
- # but it will produce a compile-time error that is obvious.
- has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)
- name: str = cpp.name(func)
- is_factory_function = category_override == "factory" or (
- has_tensor_return and not has_tensor_input_arg
- )
- is_like_or_new_function = (
- category_override in ("new", "like")
- or name.startswith("new_")
- or name.endswith("_like")
- )
- tensor_options_args: List[PythonArgument] = []
- if is_factory_function or is_like_or_new_function:
- def topt_default_init(name: str) -> Optional[str]:
- topt_args = func.arguments.tensor_options
- if topt_args is None:
- return None
- a = getattr(topt_args, name)
- if a.default is None or a.default == "None":
- return None
- return cpp.default_expr(a.default, a.type, symint=False)
- tensor_options_args.append(
- PythonArgument(
- name="dtype",
- type=OptionalType(BaseType(BaseTy.ScalarType)),
- default="None",
- default_init=(
- "self.scalar_type()"
- if is_like_or_new_function
- else topt_default_init("dtype")
- ),
- )
- )
- tensor_options_args.append(
- PythonArgument(
- name="layout",
- type=OptionalType(BaseType(BaseTy.Layout)),
- default="None",
- default_init=(
- "self.layout()"
- if is_like_or_new_function
- else topt_default_init("layout")
- ),
- )
- )
- tensor_options_args.append(
- PythonArgument(
- name="device",
- type=OptionalType(BaseType(BaseTy.Device)),
- default="None",
- default_init=(
- "self.device()"
- if is_like_or_new_function
- else (
- topt_default_init("device")
- or "torch::tensors::get_default_device()"
- )
- ),
- )
- )
- tensor_options_args.append(
- PythonArgument(
- name="pin_memory",
- type=OptionalType(BaseType(BaseTy.bool)),
- default="False",
- default_init=None,
- )
- )
- tensor_options_args.append(
- PythonArgument(
- name="requires_grad",
- type=OptionalType(BaseType(BaseTy.bool)),
- default="False",
- default_init=None,
- )
- )
- returns = PythonReturns(returns=func.returns)
- return PythonSignature(
- name=str(func.name.name),
- input_args=input_args,
- input_kwargs=input_kwargs,
- output_args=PythonOutArgument.from_outputs(outputs),
- tensor_options_args=tuple(tensor_options_args),
- returns=returns,
- method=method,
- )
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python Interface
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
- if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
- return []
- else:
- if any(map(lambda r: r.name is None, returns)):
- # When building on Windows, `PyStructSequence_UnnamedField` could not be
- # resolved by the linker for some reason, which causes an error in building:
- #
- # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
- # PyStructSequence_UnnamedField
- #
- # Thus, at this point in time, we do not support unnamed
- # fields in namedtuple; you must either name all fields,
- # or none of them.
- raise ValueError("Unnamed field is not supported by codegen")
- return list(map(lambda r: str(r.name), returns))
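- # E.g. for 'aten::max.names_dim(...) -> (Tensor values, Tensor indices)'
- # from the notes above, this returns ['values', 'indices'].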
- def argument_type_str_pyi(t: Type) -> str:
- add_optional = False
- if isinstance(t, OptionalType):
- t = t.elem
- add_optional = True
- if isinstance(t, BaseType):
- if t.name == BaseTy.int:
- ret = "_int"
- elif t.name == BaseTy.SymInt:
- ret = "Union[_int, SymInt]"
- elif t.name == BaseTy.float:
- ret = "_float"
- elif t.name == BaseTy.str:
- ret = "str"
- elif t.name == BaseTy.Scalar:
- ret = "Number"
- elif t.name == BaseTy.ScalarType:
- ret = "_dtype"
- elif t.name == BaseTy.bool:
- ret = "_bool"
- elif t.name == BaseTy.QScheme:
- ret = "_qscheme"
- elif t.name == BaseTy.Layout:
- ret = "_layout"
- elif t.name == BaseTy.Device:
- ret = "Union[_device, str, None]"
- elif t.name == BaseTy.MemoryFormat:
- ret = "memory_format"
- elif t.name == BaseTy.Dimname:
- ret = "Union[str, ellipsis, None]"
- elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Storage, BaseTy.Stream]:
- # These python schema type names line up with their function schema names
- ret = t.name.name
- elif isinstance(t, ListType):
- if str(t.elem) == "int":
- ret = "Union[_int, _size]" if t.size is not None else "_size"
- elif t.is_tensor_like():
- # TODO: this doesn't seem right...
- # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
- # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
- if isinstance(t.elem, OptionalType):
- add_optional = True
- ret = (
- "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
- if t.size is not None
- else "Union[Tuple[Tensor, ...], List[Tensor]]"
- )
- elif str(t.elem) == "float":
- ret = "Sequence[_float]"
- else:
- elem = argument_type_str_pyi(t.elem)
- ret = f"Sequence[{elem}]"
- else:
- raise RuntimeError(f"unrecognized type {repr(t)}")
- if add_optional:
- ret = "Optional[" + ret + "]"
- return ret
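- # Spot checks (hand-evaluated against the mapping above):
- #   'ScalarType?' -> "Optional[_dtype]"
- #   'Tensor[]'    -> "Union[Tuple[Tensor, ...], List[Tensor]]"
- #   'float[]'     -> "Sequence[_float]"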
- def return_type_str_pyi(t: Type) -> str:
- # Where arguments are open to accepting Union, return types should be
- # concrete types
- if isinstance(t, OptionalType):
- inner = return_type_str_pyi(t.elem)
- return f"Optional[{inner}]"
- if isinstance(t, BaseType):
- if t.name == BaseTy.Device:
- return "_device"
- elif t.name == BaseTy.Dimname:
- return "Optional[str]"
- else:
- return argument_type_str_pyi(t)
- if isinstance(t, ListType):
- inner = return_type_str_pyi(t.elem)
- return f"List[{inner}]"
- return argument_type_str_pyi(t)
- def returns_named_tuple_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
- python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
- namedtuple_name = signature.name
- field_names = namedtuple_fieldnames(signature.returns.returns)
- if field_names:
- tuple_args = [
- f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)
- ]
- namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])'
- return namedtuple_name, namedtuple_def
- return None
- def returns_str_pyi(signature: PythonSignature) -> str:
- field_names = namedtuple_fieldnames(signature.returns.returns)
- if field_names:
- return f"torch.return_types.{signature.name}"
- python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
- if len(python_returns) > 1:
- return "Tuple[" + ", ".join(python_returns) + "]"
- if len(python_returns) == 1:
- return python_returns[0]
- return "None"
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # C++ Function Dispatch
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # This section provides APIs to generate the code that does C++ function
- # dispatch. The C++ function call is wrapped by a lambda function.
- # For example:
- #
- # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
- # auto dispatch_selu_ = [](Tensor self) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return at::selu_(self);
- # };
- #
- # The lambda function's signature follows the C++ signature in common
- # cases, e.g.:
- #
- # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
- # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
- #
- # For the out variant, the 'out' argument's type is changed from 'Tensor &'
- # to 'Tensor'. This is because when the lambda is called it is passed the
- # PythonArgParser output '_r.tensor(3)', which is a stack-allocated object
- # and needs to be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
- #
- # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
- # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
- #
- # For the multi-output case it can keep using reference types, because the
- # PythonArgParser output has been unpacked to local variables, e.g.:
- #
- # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
- # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
- # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
- #
- # For deprecated python signatures, the lambda follows the deprecated python arg order.
- # TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?
- def dispatch_lambda_args(
- ps: PythonSignature, f: NativeFunction, symint: bool = True
- ) -> Tuple[DispatchLambdaArgument, ...]:
- if isinstance(ps, PythonSignatureDeprecated):
- schema = ps.deprecated_schema
- else:
- schema = f.func
- # Start with cpp arguments - the dispatch lambda signature always includes 'self'
- cpp_args = cpp.arguments(
- arguments=schema.arguments,
- faithful=False,
- symint=symint,
- method=False,
- cpp_no_default_args=f.cpp_no_default_args,
- )
- out_args: Set[str] = {a.name for a in schema.arguments.out}
- # Convert from cpp argument to lambda argument
- def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
- type_str = cpp_arg.type
- is_out_arg = cpp_arg.name in out_args
- if ps.method and cpp_arg.name == "self":
- # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
- type_str = "const at::Tensor &"
- else:
- # For other cases we need to prevent dangling refs to temps (unless it's
- # an unpacked scattered output)
- # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
- # TODO: avoid this special handling?
- ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
- if ensure_temp_safe:
- type_str = {
- "at::Tensor &": "at::Tensor",
- }.get(type_str, type_str)
- return DispatchLambdaArgument(
- name=cpp_arg.name,
- type_str=type_str,
- is_out_arg=is_out_arg,
- )
- return tuple(map(dispatch_lambda_arg, cpp_args))
- # [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
- # it's enough to just extend the list here. Before you do this, make sure
- # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
- SUPPORTED_RETURN_TYPES = {
- "at::Tensor",
- "::std::tuple<at::Tensor,at::Tensor>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
- "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
- "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
- "::std::tuple<double,int64_t>",
- "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
- "::std::vector<at::Tensor>",
- # Needed for flash attention forw/backward
- "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t,int64_t,int64_t,int64_t,at::Tensor>",
- "at::Scalar",
- "bool",
- "int64_t",
- "void*",
- "void",
- "at::QScheme",
- "double",
- "at::IntArrayRef",
- "at::ScalarType",
- }
- def dispatch_lambda_return_str(f: NativeFunction) -> str:
- # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
- # because the dispatch lambdas take mutable arguments *by value*, not
- # by reference. If you then return a reference to such an argument, you
- # will now have a pointer to a dangling stack entry. Not good.
- #
- # You want:
- #
- # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
- # ^^^^^^
- #
- # *not*
- #
- # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
- # ^^^^^^^
- #
- # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
- # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
- # mutable reference to temporary. Maybe we could assign it to a
- # variable itself.)
- returns_without_annotation = tuple(
- map(lambda r: Return(r.name, r.type, None), f.func.returns)
- )
- return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
- if return_str not in SUPPORTED_RETURN_TYPES:
- raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
- return return_str
- def cpp_dispatch_target(f: NativeFunction) -> str:
- symint = f.func.has_symint()
- name = cpp.name(f.func, symint_overload=symint)
- if Variant.method in f.variants:
- return f"self.{name}"
- if Variant.function in f.variants:
- if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
- namespace = "torch"
- else:
- namespace = "at"
- return f"{namespace}::{name}"
- raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
- def cpp_dispatch_exprs(
- f: NativeFunction,
- *,
- python_signature: Optional[PythonSignature] = None,
- ) -> Tuple[str, ...]:
- cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
- exprs: Tuple[str, ...] = tuple()
- if not isinstance(python_signature, PythonSignatureDeprecated):
- # By default the exprs are consistent with the C++ signature.
- exprs = tuple(map(lambda a: a.name, cpp_args))
- else:
- # For deprecated python signatures we may need to fill in some constants.
- exprs = tuple(
- filter(
- lambda n: n != "out" or f.func.is_out_fn(),
- python_signature.deprecated_args_exprs,
- )
- )
- if Variant.method in f.variants:
- exprs = tuple(filter("self".__ne__, exprs))
- return exprs
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python / C++ Args Binding
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # We explicitly enumerate the PythonArgParser unpacking methods for all
- # supported types. This might be more verbose than necessary, partially
- # because of the irregularity of unpacking method naming, partially
- # because we want to mimic the old codegen behavior - to reject
- # unexpected and/or unsupported cases which the old codegen rejects.
- # For certain cases it is intentionally more restrictive than necessary,
- # e.g.: it doesn't accept a doublelist with definite size.
- def arg_parser_unpack_method(
- t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
- ) -> str:
- has_default_init = default_init is not None
- if has_default_init and str(t) not in (
- "ScalarType?",
- "ScalarType",
- "Device",
- "Device?",
- "Layout",
- "Layout?",
- "bool",
- "bool?",
- ):
- raise RuntimeError(f"type '{t}' does not supported unpacking with default")
- if isinstance(t, BaseType):
- if t.name in [
- BaseTy.Tensor,
- BaseTy.Stream,
- BaseTy.Storage,
- BaseTy.Scalar,
- BaseTy.Dimname,
- ]:
- # These unpack methods line up with their schema names
- return t.name.name.lower()
- elif t.name == BaseTy.ScalarType:
- return "scalartypeWithDefault" if has_default_init else "scalartype"
- elif t.name == BaseTy.Device:
- return "deviceWithDefault" if has_default_init else "device"
- elif t.name == BaseTy.int:
- return "toInt64"
- elif t.name == BaseTy.SymInt:
- if symint:
- return "toSymInt"
- else:
- return "toInt64"
- elif t.name == BaseTy.bool:
- return "toBoolWithDefault" if has_default_init else "toBool"
- elif t.name == BaseTy.float:
- return "toDouble"
- elif t.name == BaseTy.str:
- return "stringView"
- elif t.name == BaseTy.Layout:
- return "layoutWithDefault" if has_default_init else "layout"
- elif t.name == BaseTy.MemoryFormat:
- return "memoryformat"
- elif isinstance(t, OptionalType):
- if str(t.elem) == "Tensor":
- return "optionalTensor"
- elif str(t.elem) == "Generator":
- return "generator"
- elif str(t.elem) == "Dimname[]":
- return "toDimnameListOptional"
- elif not has_default_init and default in (None, "None", "c10::nullopt"):
- # If default is None: append 'Optional' to elem's unpacking method
- return (
- arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
- )
- else:
- # Otherwise, load as underlying type with default
- return arg_parser_unpack_method(
- t.elem, default, default_init, symint=symint
- )
- elif isinstance(t, ListType):
- if str(t.elem) == "Tensor":
- # accept and use definite size
- if t.size is not None:
- return f"tensorlist_n<{t.size}>"
- else:
- return "tensorlist"
- elif str(t.elem) == "Tensor?":
- return "list_of_optional_tensors"
- elif str(t.elem) == "Dimname":
- # accept definite size
- return "dimnamelist"
- elif str(t.elem) == "int":
- # accept definite size
- return "intlist"
- elif str(t) == "float[]":
- return "doublelist"
- elif str(t.elem) == "SymInt":
- # accept definite size
- if symint:
- return "symintlist"
- else:
- return "intlist"
- elif str(t) == "Scalar[]":
- return "scalarlist"
- raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
- # Return RHS expression for python argument using PythonArgParser output.
- # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
- def arg_parser_output_expr(
- arg_index: int, a: PythonArgument, *, symint: bool = True
- ) -> PythonArgParserOutputExpr:
- has_default = a.default_init is not None
- unpack_method = arg_parser_unpack_method(
- t=a.type, default=a.default, default_init=a.default_init, symint=symint
- )
- default = f", {a.default_init}" if has_default else ""
- expr = f"_r.{unpack_method}({arg_index}{default})"
- return PythonArgParserOutputExpr(
- name=a.name,
- expr=expr,
- index=arg_index,
- argument=a,
- )
- # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
- def arg_parser_output_exprs(
- ps: PythonSignature, f: NativeFunction, *, symint: bool = True
- ) -> Dict[str, PythonArgParserOutputExpr]:
- return {
- e.name: e
- for i, a in enumerate(ps.arguments())
- for e in (arg_parser_output_expr(i, a, symint=symint),)
- }
- # argument name to type for scattered tensor options fields
- TENSOR_OPTIONS_FIELDS = {
- "dtype": "ScalarType?",
- "device": "Device?",
- "layout": "Layout?",
- "pin_memory": "bool?",
- "requires_grad": "bool?",
- }
- # bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
- def dispatch_lambda_exprs(
- ps: PythonSignature, f: NativeFunction, *, symint: bool = True
- ) -> DispatchLambdaArgumentExprs:
- # This method binds 'arg_parser_outputs' and 'lambda_args' by producing
- # 'inits' and 'lambda_args_exprs' for each lambda argument using the arg
- # parser outputs.
- arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
- lambda_args = dispatch_lambda_args(ps, f, symint=symint)
- inits: List[str] = []
- lambda_args_exprs: Dict[str, str] = {}
- has_toptions = has_tensor_options(f)
- # 1. special inits/unpacking to provide binding exprs for lambda arguments.
- for a in ps.arguments(skip_tensor_options=True):
- name = a.name
- arg_parser_expr = arg_parser_outputs[a.name].expr
- if has_toptions and name == "self":
- # TODO: why does this need to be a special case?
- inits.extend(
- [
- f"auto self = {arg_parser_expr};",
- ]
- )
- lambda_args_exprs[name] = name
- elif (
- isinstance(a, PythonOutArgument)
- and len(a.outputs) > 1
- and f.func.is_out_fn()
- ):
- inits.extend(
- [
- f"auto out = {arg_parser_expr};",
- ]
- )
- for i, out_arg in enumerate(a.outputs):
- lambda_args_exprs[out_arg.name] = f"out[{i}]"
- elif str(a.type) == "Dimname[]?":
- # [old codegen]
- # TODO: make this part of something more general, or get rid of it.
- # optional<ArrayRef<T>> are special. The PythonArgParser returns an
- # optional<vector<T>>, which cannot be implicitly converted to
- # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
- inits.extend(
- [
- f"auto __{name} = {arg_parser_expr};",
- f"c10::optional<DimnameList> {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;", # noqa: B950
- ]
- )
- lambda_args_exprs[name] = name
- else:
- # default case - directly using PythonArgParser output expr
- lambda_args_exprs[name] = arg_parser_expr
- # a method's 'self' is passed directly to the python binding, rather than parsed
- if ps.method:
- lambda_args_exprs["self"] = "self"
- # 2. special packing/checking for TensorOptions.
- tensor_options_args_names = list(map(lambda a: a.name, ps.tensor_options_args))
- if has_toptions:
- if f.func.is_out_fn():
- raise RuntimeError(f"{f.func}: tensor options with output arg")
- for a in ps.tensor_options_args:
- if a.name not in TENSOR_OPTIONS_FIELDS:
- raise RuntimeError(
- f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
- )
- if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
- raise RuntimeError(
- f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
- )
- if not all(
- map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())
- ):
- raise RuntimeError(
- f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
- )
- inits.append(
- f"""\
- const auto options = TensorOptions()
- .dtype({arg_parser_outputs['dtype'].expr})
- .device({arg_parser_outputs['device'].expr})
- .layout({arg_parser_outputs['layout'].expr})
- .requires_grad({arg_parser_outputs['requires_grad'].expr})
- .pinned_memory({arg_parser_outputs['pin_memory'].expr});
- torch::utils::maybe_initialize_cuda(options);
- """
- )
- lambda_args_exprs["options"] = "options"
- # 3. special case - access scattered TensorOptions fields without packing
- # TODO: maybe move to the generator side as it's not related to binding.
- if not has_toptions and tensor_options_args_names:
- if "dtype" in tensor_options_args_names:
- # we're an output-arg variant, check these args against output tensor
- if not f.func.is_out_fn():
- raise RuntimeError(
- f"{f.func}: dtype in tensor_options_args without output arg"
- )
- if not all(
- map(lambda a: a in tensor_options_args_names, ("layout", "device"))
- ):
- raise RuntimeError(
- f"{f.func}: incomplete tensor options for output check"
- )
- inits.append(
- f"""\
- check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
- {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
- {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
- """
- )
- # we'll set requires_grad on outgoing tensor
- if "requires_grad" not in tensor_options_args_names:
- raise RuntimeError(
- f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
- )
- return DispatchLambdaArgumentExprs(
- exprs=tuple(map(lambda a: lambda_args_exprs[a.name], lambda_args)),
- inits=inits,
- )
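- # Tying it together: for the 'aten::empty.names' example in the notes at the
- # top of this file, this function would produce (roughly) the special 'names'
- # init and the packed TensorOptions init shown there, plus the exprs
- # ('_r.intlist(0)', 'names', 'options', '_r.memoryformatOptional(2)').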