import argparse
import functools
import json
import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import yaml
from typing_extensions import Literal  # Python 3.8+

import torchgen.api.dispatcher as dispatcher
import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
    Binding,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    NamedCType,
    NativeSignature,
    SpecialArgName,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and_indices,
)
from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseOperatorName,
    DEFAULT_KERNEL_NAMESPACE,
    DispatchKey,
    FRAGMENT_NAMESPACES,
    FunctionSchema,
    is_cuda_dispatch_key,
    is_generic_dispatch_key,
    is_ufunc_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OperatorName,
    OptionalType,
    SchemaKind,
    SelfArgument,
    STRUCTURED_DISPATCH_KEYS,
    TensorOptionsArguments,
    Type,
    Variant,
    ViewSchemaKind,
)
from torchgen.native_function_generation import (
    add_generated_native_functions,
    gen_composite_functional_kernel,
    gen_composite_out_kernel,
    pre_group_native_functions,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    assert_never,
    concatMap,
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
    Target,
    YamlDumper,
    YamlLoader,
)

T = TypeVar("T")

# Welcome to the ATen code generator v2!  The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file.  This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking.  Typecheck it with
#   `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml.  The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with.  There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API.  See each
#     of these respective files for more information

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        return mapping
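
# Illustrative sketch (not executed by the codegen): every mapping parsed with
# LineLoader carries a "__line__" key recording where the entry started, e.g.
#
#   >>> yaml.load("- func: foo\n- func: bar", Loader=LineLoader)
#   [{'func': 'foo', '__line__': 1}, {'func': 'bar', '__line__': 2}]
#
# which is what lets the error context below point at the offending entry.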

# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])

_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: Set[str],
    ignore_keys: Optional[Set[DispatchKey]] = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    assert isinstance(es, list)
    rs: List[NativeFunction] = []
    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        funcs = e.get("func")
        with context(lambda: f"in {loc}:\n  {funcs}"):
            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
            rs.append(func)
            BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )
    )
    if not skip_native_fns_gen:
        add_generated_native_functions(rs, bs)
    for k, v in bs.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k),
            index=v,
        )
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
    assert isinstance(es, list)
    rs: Set[str] = set()
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n  {tags}"):
            e_i = e.copy()
            name = e_i.pop("tag")
            desc = e_i.pop("desc", "")
            # Ensure that each tag has a non-empty description
            assert desc != ""
            rs.add(name)
    return rs


@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> Set[str]:
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        with open(path, "r") as f:
            es = yaml.load(f, Loader=LineLoader)
            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]


def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    *,
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        valid_tags = parse_tags_yaml(tags_yaml_path)
        with open(path, "r") as f:
            es = yaml.load(f, Loader=LineLoader)
        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
            es,
            valid_tags,
            ignore_keys,
            path=path,
            skip_native_fns_gen=skip_native_fns_gen,
        )
    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
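
# Typical usage (a sketch; the paths are hypothetical and normally come from
# command-line arguments in main()):
#
#   parsed = parse_native_yaml(
#       "aten/src/ATen/native/native_functions.yaml",
#       "aten/src/ATen/native/tags.yaml",
#   )
#   native_functions = parsed.native_functions   # List[NativeFunction]
#   backend_indices = parsed.backend_indices     # Dict[DispatchKey, BackendIndex]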


# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: Dict[OperatorName, NativeFunction] = {}
    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map[f.structured_delegate]
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if "inplace_view" in f.tags and str(f.func.name) != "resize_":
            base_name = f.func.name.name
            overload_name = f.func.name.overload_name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            assert len(base_func_map[out_of_place_base_name]) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
            )


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal"""
    s = s.replace("\\", "\\\\")
    s = s.replace('"', '\\"')
    s = s.replace("\a", "\\a")
    s = s.replace("\b", "\\b")
    s = s.replace("\f", "\\f")
    s = s.replace("\n", "\\n")
    s = s.replace("\v", "\\v")
    s = s.replace("\t", "\\t")
    return f'"{s}"'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated.  This pattern makes it convenient to use map, concatMap
# and similar functional combinators.
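#
# For instance (a sketch; `native_functions` is an illustrative variable, and
# ComputeFunction is defined later in this file), instantiating a dataclass
# yields a callable that is then mapped over all operators:
#
#   gen = ComputeFunction()
#   snippets = list(mapMaybe(gen, native_functions))
#
# where mapMaybe (from torchgen.utils) drops the operators for which the
# callable returns None.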


def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
    if len(backends) == 0:
        return []
    else:
        return [backend.dispatch_key for backend in backends] + [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]


def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> Optional[DispatchKey]:
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    elif f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: List[BackendIndex]
) -> Optional[str]:
    if backend_index is None or f.manual_kernel_registration:
        return None

    output = []
    for index in backend_index:
        dispatch_key = get_static_dispatch_backend(f, index)
        if dispatch_key is not None:
            output.append(
                f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
            )
    return "\n".join(output)


def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
    ]


# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: Union[CppSignature, DispatcherSignature],
    cpp_sig: CppSignature,
) -> str:
    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
        output_bindings: List[Binding] = []
        for binding in input_bindings:
            if binding.name == "memory_format":
                spl_mem_format_binding = Binding(
                    nctype=NamedCType(
                        SpecialArgName.possibly_redundant_memory_format,
                        binding.nctype.type,
                    ),
                    name=binding.name,
                    default=binding.default,
                    argument=binding.argument,
                )
                output_bindings.append(spl_mem_format_binding)
            else:
                output_bindings.append(binding)
        return output_bindings

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
    # get memory_format bindings of dispatcher signature to have the same NCType as well
    for arg in goal_bindings:
        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
            src_bindings = add_spl_memory_format_binding(src_bindings)
            break
    exprs = translate(src_bindings, goal_bindings)
    return ", ".join(a.expr for a in exprs)


def generate_static_dispatch_backend_call(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    backend_metadata = backend_index.get_kernel(f)
    kernel_ns = (
        backend_metadata.cpp_namespace
        if backend_metadata and backend_metadata.cpp_namespace
        else DEFAULT_KERNEL_NAMESPACE
    )
    ns = kernel_ns.replace("::native", "")
    return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"


def generate_static_dispatch_fallback_call(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
    if f.has_composite_explicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
    else:
        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""


def static_dispatch(
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than one
    backend exists, fall back to static dispatch by determining the dispatch key from the inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch for.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: List[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """


# Generates RegisterSchema.cpp.  Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_native_function_selected(f):
            return None
        tags = "{" + ", ".join([f"at::Tag::{tag}" for tag in f.tags]) + "}"
        return f"m.def({cpp_string(str(f.func))}, {tags});\n"


# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)


# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()

        result = ""
        for sig in sig_group.signatures():
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])

            if sig.symint:
                intlike_t = "c10::SymInt"
            else:
                intlike_t = "int64_t"

            if Variant.function in f.variants:
                result += f"""
// aten::{f.func}
inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""

            # The template function can be used from template situations
            # where you want to switch between the symint or not version
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if has_symint:
                result += f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  }}
}}
"""
        return result


# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None
        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            result = ""
            for sig in sig_group.signatures():
                result += f"{sig.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])

            result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
        return result


# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])

            result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""
        return result


# Generates ATenOpList.cpp, a runtime-accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another; we should reevaluate whether this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'


# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]
            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{precomputed_elements[i].name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )
                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""


def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    name = str(f.func.name.name)
    if name.endswith("_like") or name.startswith("new_"):
        return False
    if f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
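
# Intuition (illustrative): `at::empty(sizes, options)` takes TensorOptions and
# has no tensor inputs to dispatch on, so it needs a BackendSelect kernel;
# `self.new_empty(...)` and `at::empty_like(self)` can infer the backend from
# `self`, so they are excluded by the name checks above.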


# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently.
            # The first case could probably be improved, though: it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys; there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def format_yaml(data: object) -> str:
    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]

    # Support serializing OrderedDict
    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())

    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the output YAML.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]


# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than uniformly as strings. This function detects
# those cases and converts them into native Python objects.
def pythonify_default(s: str) -> object:
    if s == "true":
        return True
    elif s == "false":
        return False

    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s
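
# For example (illustrative):
#
#   >>> pythonify_default("true")
#   True
#   >>> pythonify_default("1")
#   1
#   >>> pythonify_default("1e-5")
#   1e-05
#   >>> pythonify_default("[0, 1]")
#   '[0, 1]'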


# What is a dynamic type?  Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH).  These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    if isinstance(t, OptionalType):
        return dynamic_type(t.elem)
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so never report SymInt
    return cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    ).cpp_type()
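
# For example (illustrative): a `Tensor` or `Tensor?` argument both report
# "at::Tensor", while `int[2]` falls through to the C++ argument type
# (roughly at::IntArrayRef).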


def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ["Type"]
    if Variant.method in variants:
        method_of.append("Tensor")
    if Variant.function in variants:
        method_of.append("namespace")
    return method_of


def compute_returns_yaml(
    f: NativeFunction,
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #     - field_name: solution
    #       name: X
    #     - field_name: QR
    #       name: qr
    #     ...
    #
    #   returns:
    #     - field_name: solution
    #       name: X
    #     - field_name: QR
    #       name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            "dynamic_type": dynamic_type(r.type),
            "name": name,
            # legacy, report ints
            "type": cpp.return_type(r, symint=False).cpp_type(),
        }
        if r.name:
            # See Note [name and field_name]
            ret["field_name"] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name
        returns.append(ret)

    return returns, name_to_field_name


# Arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: Dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(cpp_a.argument, SelfArgument):
        raise AssertionError()
    elif isinstance(cpp_a.argument, Argument):
        return compute_argument_yaml(
            cpp_a.argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )


def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    arg: Dict[str, object] = {
        "annotation": str(a.annotation) if a.annotation else None,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }
    if a.default is not None:
        arg["default"] = pythonify_default(
            cpp.default_expr(a.default, a.type, symint=False)
        )
    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True
    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    l = a.type.is_list_like()
    if l is not None and l.size is not None and str(l.elem) != "bool":
        arg["size"] = l.size
    return arg


@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type
        for a in schema_order_jit_arguments
        for r in cpp.argument(
            a,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = (
        any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
        and Variant.method not in f.variants
    )

    return OrderedDict(
        [
            ("name", cpp.name(f.func)),
            ("operator_name", str(f.func.name.name)),
            ("overload_name", str(f.func.name.overload_name)),
            ("manual_kernel_registration", f.manual_kernel_registration),
            (
                "category_override",
                f.category_override if f.category_override is not None else "",
            ),
            ("schema_string", f"aten::{f.func}"),
            ("arguments", arguments),
            ("schema_order_cpp_signature", schema_order_cpp_signature),
            ("schema_order_arguments", schema_order_arguments),
            ("method_of", compute_method_of_yaml(f.variants)),
            ("mode", "native"),
            ("python_module", "" if f.python_module is None else f.python_module),
            ("returns", returns),
            ("inplace", f.func.name.name.inplace),
            ("is_factory_method", is_factory_method),
            ("abstract", f.is_abstract),
            ("device_guard", f.device_guard),
            ("with_gil", False),
            ("deprecated", False),
            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
        ]
    )


# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    return (f.structured or f.structured_delegate is not None) and (
        f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
    )
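
# Illustrative example: for a structured group {add.Tensor, add_.Tensor, add.out},
# only add.out has a hand-written kernel; the functional and inplace variants get
# auto-generated composite kernels that call into it, so this returns True for
# add.Tensor and add_.Tensor and False for add.out.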


@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: Dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(
            {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {DispatchKey.CompositeImplicitAutograd}
            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {
                DispatchKey.CompositeImplicitAutograd,
                DispatchKey.CompositeImplicitAutogradNestedTensor,
            }
        ),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def get_custom_build_selector(
    provided_op_registration_allowlist: Optional[List[str]],
    op_selection_yaml_path: Optional[str],
) -> SelectiveBuilder:
    assert not (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    ), (
        "Both provided_op_registration_allowlist and "
        + "op_selection_yaml_path can NOT be provided at the "
        + "same time."
    )

    op_registration_allowlist: Optional[Set[str]] = None
    if provided_op_registration_allowlist is not None:
        op_registration_allowlist = set(provided_op_registration_allowlist)

    if op_registration_allowlist is not None:
        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
            op_registration_allowlist,
            True,
            False,
        )
    elif op_selection_yaml_path is not None:
        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    else:
        selector = SelectiveBuilder.get_nop_selector()

    return selector


def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
    def maybe_create_view_group(
        d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
    ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
        funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
        if ViewSchemaKind.aliasing in d:
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = d.pop(SchemaKind.functional, None)

            funcs.append(
                NativeFunctionsViewGroup(
                    view=view,
                    view_copy=view_copy,
                    view_inplace=view_inplace,
                )
            )

        # Take the remaining functions that weren't part of the view group
        # and emit them separately
        for func in d.values():
            funcs.append(func)
        return funcs

    grouped_by_views: Dict[
        FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
        view_kind: ViewSchemaKind = f.view_schema_kind
        # We need to group up ops relevant to the same "view", consisting of:
        # view op (ViewSchemaKind.aliasing)
        # view_inplace op (ViewSchemaKind.aliasing_inplace)
        # view_copy op (SchemaKind.functional)
        if view_kind == ViewSchemaKind.non_aliasing:
            kind = f.func.kind()
            assert kind not in grouped_by_views[schema]
            grouped_by_views[schema][kind] = f
        else:
            assert view_kind not in grouped_by_views[schema]
            grouped_by_views[schema][view_kind] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
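
# For example, assuming native_functions.yaml defines all three variants,
# `transpose` (aliasing view), `transpose_` (inplace view) and `transpose_copy`
# (functional) share a view signature and collapse into a single group:
#
#   NativeFunctionsViewGroup(view=transpose, view_copy=transpose_copy,
#                            view_inplace=transpose_)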


def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    def flatten_pre_group(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            # Invariant: any NativeFunctions that are code-generated
            # should have been grouped into NativeFunctionsGroup objects
            assert not any("generated" in f.tags for f in d.values())
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    return list(
        concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
    )
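
# For example, `add.Tensor` (functional), `add_.Tensor` (inplace) and `add.out`
# (out variant) are grouped into one NativeFunctionsGroup, while an op with no
# grouped counterparts is passed through as a bare NativeFunction.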


# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
    ] = dest.compute_native_function_declaration,
) -> List[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declarations for each `NativeFunction`.
    :return: a list of strings: all declarations, grouped by namespace and split by newline.
    """
    declarations: List[str] = []
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    newline = "\n"
    for f in grouped_native_functions:
        native_function_namespaces = set()
        dispatch_keys = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                dispatch_keys.add(dispatch_key)
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert (
                len(native_function_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, backend_idx)
            )

    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=3,
        )
        # Convert to a set first to remove duplicate kernel names. Backends are
        # allowed to repeat kernel names; only generate the declaration once!
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations
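
# The returned declarations form per-namespace blocks shaped roughly like this
# (illustrative; the actual signatures come from native_function_decl_gen):
#
#   namespace at {
#   namespace native {
#   TORCH_API at::Tensor abs(const at::Tensor & self);
#   } // namespace native
#   } // namespace at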


def get_kernel_namespace(
    *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
    backend_metadata = backend_idx.get_kernel(f)
    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has a namespace {backend_metadata.cpp_namespace} that does not end with '::native'."
    )
    return (
        backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
    )


# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp, etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> List[str]:
    definitions: List[str] = []
    ns_definitions: Dict[str, List[str]] = defaultdict(list)
    anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
    registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    anonymous_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.ANONYMOUS_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    reg_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.REGISTRATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(
            ns_gen(f),
        )
        anonymous_definitions[kernel_namespace].extend(
            anonymous_gen(f),
        )
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        if namespace not in registrations[kernel_namespace]:
            registrations[kernel_namespace] = defaultdict(list)
        registrations[kernel_namespace][namespace].extend(
            reg_gen(f),
        )

    for kernel_namespace in ns_definitions:
        if len(ns_definitions[kernel_namespace]) == 0:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace in registrations[kernel_namespace]:
            if not registrations[kernel_namespace][namespace]:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions


# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h, etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> List[str]:
    declarations: List[str] = []
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    newline = "\n"
    func = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )
    for f in grouped_native_functions:
        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "native", dispatch_key.lower()
        )

        ns_grouped_kernels[namespace].extend(
            func(f),
        )

    for namespace, kernels in ns_grouped_kernels.items():
        if len(kernels) == 0:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=namespace, entity_name="", max_level=3
        )
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations
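
# The declarations end up in per-backend headers such as CPUFunctions_inl.h,
# shaped roughly like this (illustrative):
#
#   namespace at {
#   namespace cpu {
#   TORCH_API at::Tensor abs(const at::Tensor & self);
#   } // namespace cpu
#   } // namespace at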


# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> Tuple[List[str], str]:
    ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_native_functions[native_function.namespace].append(native_function)
    schema_registrations = ""
    aten_schema_registrations = []
    custom_namespace = None
    for namespace, funcs in ns_native_functions.items():
        schema_registrations_body = list(
            mapMaybe(RegisterSchema(schema_selector), funcs)
        )
        # NB: we have to separate aten namespace registration from other namespaces,
        # because in the template we hardcoded an operator for ATen already.
        if namespace == "aten":
            aten_schema_registrations = schema_registrations_body
        else:
            custom_namespace = namespace
            tab = "\t"
            # if the namespace is predefined, we should define a library fragment
            # instead of a new library
            torch_library_macro = (
                "TORCH_LIBRARY_FRAGMENT"
                if namespace in FRAGMENT_NAMESPACES
                else "TORCH_LIBRARY"
            )
            schema_registrations += f"""
{torch_library_macro}({custom_namespace}, m) {{
  {tab.join(schema_registrations_body)}
}};"""
    return (aten_schema_registrations, schema_registrations)
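
# Non-aten registrations come back as a single block of code, e.g.
# (illustrative, for a hypothetical `custom` namespace with one op):
#
#   TORCH_LIBRARY(custom, m) {
#     m.def("custom::my_op(Tensor self) -> Tensor");
#   };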


def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    method_native_functions = [
        fn for fn in native_functions if Variant.method in fn.variants
    ]
    non_method_native_functions = [
        fn for fn in native_functions if fn not in method_native_functions
    ]
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )
    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        del fm


def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(lambda: [])
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: Dict[
        str, List[Union[NativeFunction, NativeFunctionsGroup]]
    ] = defaultdict(lambda: [])
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )
        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )
        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0
        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name.keys())
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        dispatch_names = []

        for name, functions in functions_by_root_name.items():
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        symint=True,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )

            if len(declarations) == 0:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )


def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: Set[str],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )
    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )
    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> Dict[str, str]:
        attrs = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)
            for arg in func.func.schema_order_arguments():
                attrs.add(arg.name)
        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }
        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> Dict[str, str]:
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)


def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: List[BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
) -> None:
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""
    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        if per_operator_headers:

            def operator_headers() -> List[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    # (It has to do with the fact that structured kernels get generated kernels
                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                return sorted(set(headers))

        else:

            def operator_headers() -> List[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        backend_index = backend_indices[dispatch_key]
        ns_grouped_native_functions = defaultdict(list)
        for grouped_native_function in grouped_native_functions:
            namespace = (
                grouped_native_function.namespace
                if isinstance(grouped_native_function, NativeFunction)
                else grouped_native_function.functional.namespace
            )
            ns_grouped_native_functions[namespace].append(grouped_native_function)

        dispatch_namespace = str(dispatch_key).lower()

        # CompositeImplicitAutogradNestedTensor does not currently use the
        # generated dispatch helpers; compilation would fail when the
        # `-Werror=unused-function` flag is set, so skip generating them.
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        dispatch_definitions = get_native_function_definitions(
            fm=fm,
            grouped_native_functions=grouped_native_functions,
            dispatch_key=dispatch_key,
            backend_idx=backend_index,
            selector=selector,
            rocm=rocm,
            symint=True,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
            gen_dispatch_helpers=gen_dispatch_helpers,
        )
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            lambda: {
                "extra_cuda_headers": extra_cuda_headers
                if is_cuda_dispatch_key(dispatch_key)
                else "",
                "external_backend_headers": "",
                "dispatch_headers": dest.gen_registration_headers(
                    backend_index, per_operator_headers, rocm
                ),
                "ops_headers": operator_headers(),
                "dispatch_helpers": "",
                "dispatch_definitions": dispatch_definitions,
            },
        )

        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                assert fm is cpu_fm
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

        del fm
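
    # For a structured ufunc op such as `add`, the loop above emits
    # UfuncCPU_add.cpp plus a vectorized UfuncCPUKernel_add.cpp on the CPU key,
    # and a single UfuncCUDA_add.cu on the CUDA key (file names are
    # illustrative, derived from the functional op's base name).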

    # BackendSelect is generated specially
    def gen_backend_select() -> Dict[str, List[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()
    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )

    def key_func(
        fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )
    cpu_fm.write("Functions.cpp", lambda: {})
    core_fm.write("TensorMethods.cpp", lambda: {})
    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )

    def functionalization_env_callable(
        g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> Dict[str, List[str]]:
        def gen_op_headers(
            g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
        ) -> List[str]:
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
                ]
                if g.view_copy is not None:
                    headers += [
                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                    ]
                return headers
            elif isinstance(g, NativeFunctionsGroup):
                headers = [
                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
                ]
                if g.inplace is not None:
                    headers += [
                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                    ]
                if g.mutable is not None:
                    headers += [
                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                    ]
                return headers
            else:
                return [
                    f"#include <ATen/ops/{g.root_name}_native.h>",
                    f"#include <ATen/ops/{g.root_name}_ops.h>",
                ]

        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }

    all_groups: List[
        Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    for f in native_functions:
        if f.func.name not in structured_map and f.func.name not in view_map:
            all_groups.append(f)

    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement these {view}_copy operators instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
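    # For example (illustrative): `transpose.int` is a view op, so codegen
    # expects a `transpose_copy.int` counterpart whose schema is identical
    # except that the result no longer aliases `self`:
    #
    #   transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
    #   transpose_copy(Tensor self, int dim0, int dim1) -> Tensor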
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )


def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    cpu_fm.write(
        "Declarations.yaml",
        lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
    )


def get_torchgen_root() -> pathlib.Path:
    """
    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    return pathlib.Path(__file__).parent.resolve()
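
# A minimal out-of-tree usage sketch (the packaged/ subdirectory is an
# assumption about how torchgen ships its YAML files in the wheel):
#
#   yaml_path = get_torchgen_root() / "packaged" / "ATen" / "native" / "native_functions.yaml"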


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
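
    # Example invocation (illustrative; paths assume an in-tree PyTorch checkout):
    #
    #   python -m torchgen.gen -s aten/src/ATen -d build/aten/src/ATen --per-operator-headers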
    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)

        if DispatchKey.MPS in dispatch_keys:
            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    grouped_native_functions = get_grouped_native_functions(native_functions)

    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    static_dispatch_idx: List[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_idx = [
            backend_indices[DispatchKey.parse(key)]
            for key in options.static_dispatch_backend
        ]
        for key in options.static_dispatch_backend:
            dp_key = DispatchKey.parse(key)
            if dp_key not in functions_keys:
                functions_keys.add(dp_key)
- if "sources" in options.generate:
- gen_source_files(
- native_functions=native_functions,
- grouped_native_functions=grouped_native_functions,
- structured_native_functions=structured_native_functions,
- view_groups=view_groups,
- selector=selector,
- static_dispatch_idx=static_dispatch_idx,
- backend_indices=backend_indices,
- core_fm=core_fm,
- cpu_fm=cpu_fm,
- cpu_vec_fm=cpu_vec_fm,
- cuda_fm=cuda_fm,
- dispatch_keys=dispatch_keys,
- functions_keys=functions_keys,
- rocm=options.rocm,
- force_schema_registration=options.force_schema_registration,
- per_operator_headers=options.per_operator_headers,
- skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
- )
- if "headers" in options.generate:
- gen_headers(
- native_functions=native_functions,
- valid_tags=valid_tags,
- grouped_native_functions=grouped_native_functions,
- structured_native_functions=structured_native_functions,
- static_dispatch_idx=static_dispatch_idx,
- selector=selector,
- backend_indices=backend_indices,
- core_fm=core_fm,
- cpu_fm=cpu_fm,
- cuda_fm=cuda_fm,
- ops_fm=ops_fm,
- dispatch_keys=dispatch_keys,
- functions_keys=functions_keys,
- rocm=options.rocm,
- per_operator_headers=options.per_operator_headers,
- )
- if "declarations_yaml" in options.generate:
- gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
- if options.output_dependencies:
- depfile_path = pathlib.Path(options.output_dependencies).resolve()
- depfile_name = depfile_path.name
- depfile_stem = depfile_path.stem
- for fm, prefix in [
- (cpu_fm, ""),
- (cpu_vec_fm, "cpu_vec_"),
- (core_fm, "core_"),
- (cuda_fm, "cuda_"),
- (ops_fm, "ops_"),
- ]:
- varname = prefix + depfile_stem
- path = depfile_path.parent / (prefix + depfile_name)
- fm.write_outputs(varname, str(path))
- if __name__ == "__main__":
- main()