- import dataclasses
- import itertools
- import re
- from dataclasses import dataclass
- from enum import auto, Enum
- from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
- from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # DATA MODEL
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Some general principles for our data model.
- #
- # - Stop using C++ data types as the internal data representation
- # format. Instead, the internal data structures are centered
- # around the JIT schema representation. This avoids a big problem
- # with the old codegen where we read in all the types from
- # native_functions.yaml and then immediately had to retranslate
- # them into C++ types.
- #
- # - More semantic data representation. Instead of representing
- # everything as dicts and strings, we define dataclasses for
- # every interesting entity the code generation has to deal with.
- # These dataclasses have strong semantic invariants: for example,
- # we generally require them to roundtrip losslessly into the
- # form they were parsed from. These structures are immutable
- # and you're expected to populate information once during
- # construction.
- # Represent a source location; used for better error reporting
- @dataclass(frozen=True)
- class Location:
- file: str
- line: int
- def __str__(self) -> str:
- return "{}:{}".format(self.file, self.line)
- # Valid values of the 'variants' field in native_functions.yaml
- class Variant(Enum):
- function = auto()
- method = auto()
- # Default kernel namespace
- DEFAULT_KERNEL_NAMESPACE = "at::native"
- # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
- BACKEND_COMPONENTS = "CPU CUDA HIP XLA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
- FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"]
- # This list guards dispatches that can be used in derivatives.yaml
- # For now we omit AutogradFunctionality and AutogradOther
- AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
- "Autograd" + component for component in BACKEND_COMPONENTS
- ]
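- # For illustration, the comprehension above expands (per BACKEND_COMPONENTS) to:
- #   ["AutogradNestedTensor", "AutogradCPU", "AutogradCUDA", "AutogradHIP",
- #    "AutogradXLA", "AutogradMPS", "AutogradIPU", "AutogradXPU", "AutogradHPU",
- #    "AutogradVE", "AutogradLazy", "AutogradMeta", "AutogradPrivateUse1",
- #    "AutogradPrivateUse2", "AutogradPrivateUse3"]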
- FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
- # This doesn't have to be in sync with the header; it only needs to contain
- # entries that we actually use in the codegen or want pyi entries for
- class DispatchKey(Enum):
- Undefined = 0
- CatchAll = Undefined
- FPGA = auto()
- ORT = auto()
- Vulkan = auto()
- Metal = auto()
- MKLDNN = auto()
- OpenGL = auto()
- OpenCL = auto()
- IDEEP = auto()
- CustomRNGKeyId = auto()
- MkldnnCPU = auto()
- Sparse = auto()
- SparseCsrCPU = auto()
- SparseCsrCUDA = auto()
- Python = auto()
- FuncTorchDynamicLayerBackMode = auto()
- ZeroTensor = auto()
- BackendSelect = auto()
- Named = auto()
- AutogradOther = auto()
- AutogradFunctionality = auto()
- AutogradNestedTensor = auto()
- Tracer = auto()
- Autocast = auto()
- Batched = auto()
- VmapMode = auto()
- FuncTorchDynamicLayerFrontMode = auto()
- Functionalize = auto()
- TESTING_ONLY_GenericWrapper = auto()
- TESTING_ONLY_GenericMode = auto()
- ADInplaceOrView = auto()
- Autograd = auto()
- CompositeImplicitAutograd = auto()
- CompositeImplicitAutogradNestedTensor = auto()
- CompositeExplicitAutograd = auto()
- CompositeExplicitAutogradNonFunctional = auto()
- # BEGIN autogenerated
- CPU = auto()
- CUDA = auto()
- HIP = auto()
- XLA = auto()
- MPS = auto()
- IPU = auto()
- XPU = auto()
- HPU = auto()
- VE = auto()
- Lazy = auto()
- Meta = auto()
- PrivateUse1 = auto()
- PrivateUse2 = auto()
- PrivateUse3 = auto()
- QuantizedCPU = auto()
- QuantizedCUDA = auto()
- QuantizedHIP = auto()
- QuantizedXLA = auto()
- QuantizedMPS = auto()
- QuantizedIPU = auto()
- QuantizedXPU = auto()
- QuantizedHPU = auto()
- QuantizedVE = auto()
- QuantizedLazy = auto()
- QuantizedMeta = auto()
- QuantizedPrivateUse1 = auto()
- QuantizedPrivateUse2 = auto()
- QuantizedPrivateUse3 = auto()
- SparseCPU = auto()
- SparseCUDA = auto()
- SparseHIP = auto()
- SparseXLA = auto()
- SparseMPS = auto()
- SparseIPU = auto()
- SparseXPU = auto()
- SparseHPU = auto()
- SparseVE = auto()
- SparseLazy = auto()
- SparseMeta = auto()
- SparsePrivateUse1 = auto()
- SparsePrivateUse2 = auto()
- SparsePrivateUse3 = auto()
- NestedTensorCPU = auto()
- NestedTensorCUDA = auto()
- NestedTensorHIP = auto()
- NestedTensorXLA = auto()
- NestedTensorMPS = auto()
- NestedTensorIPU = auto()
- NestedTensorXPU = auto()
- NestedTensorHPU = auto()
- NestedTensorVE = auto()
- NestedTensorLazy = auto()
- NestedTensorMeta = auto()
- NestedTensorPrivateUse1 = auto()
- NestedTensorPrivateUse2 = auto()
- NestedTensorPrivateUse3 = auto()
- AutogradCPU = auto()
- AutogradCUDA = auto()
- AutogradHIP = auto()
- AutogradXLA = auto()
- AutogradMPS = auto()
- AutogradIPU = auto()
- AutogradXPU = auto()
- AutogradHPU = auto()
- AutogradVE = auto()
- AutogradLazy = auto()
- AutogradMeta = auto()
- AutogradPrivateUse1 = auto()
- AutogradPrivateUse2 = auto()
- AutogradPrivateUse3 = auto()
- # END autogenerated
- def __str__(self) -> str:
- return self.name
- def lower(self) -> str:
- return str(self).lower()
- @staticmethod
- def parse(value: str) -> "DispatchKey":
- for k, v in DispatchKey.__members__.items():
- if k == value:
- return v
- raise AssertionError(f"unknown dispatch key {value}")
- def codegen_per_backend_entries() -> str:
- r = []
- for fk in FUNCTIONALITY_KEYS:
- for bc in BACKEND_COMPONENTS:
- r.append(f" {fk}{bc} = auto()")
- return "\n".join(r)
- for fk in FUNCTIONALITY_KEYS:
- for bc in BACKEND_COMPONENTS:
- if not hasattr(DispatchKey, fk + bc):
- r = codegen_per_backend_entries()
- print(r)
- raise RuntimeError(
- f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
- )
- STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU}
- UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
- # Set of supported dispatch keys
- dispatch_keys = [
- DispatchKey.CPU,
- DispatchKey.SparseCPU,
- DispatchKey.SparseCsrCPU,
- DispatchKey.MkldnnCPU,
- DispatchKey.CUDA,
- DispatchKey.MPS,
- DispatchKey.SparseCUDA,
- DispatchKey.SparseCsrCUDA,
- DispatchKey.QuantizedCPU,
- DispatchKey.QuantizedCUDA,
- DispatchKey.CompositeImplicitAutograd,
- DispatchKey.CompositeImplicitAutogradNestedTensor,
- DispatchKey.CompositeExplicitAutograd,
- DispatchKey.CompositeExplicitAutogradNonFunctional,
- DispatchKey.NestedTensorCPU,
- DispatchKey.NestedTensorCUDA,
- # Meta is a magic key: it is automatically generated for structured
- # kernels
- DispatchKey.Meta,
- DispatchKey.SparseMeta,
- DispatchKey.QuantizedMeta,
- DispatchKey.NestedTensorMeta,
- DispatchKey.ZeroTensor,
- ]
- # Dispatch keys that "support all backends". These are code-generated slightly
- # differently than backend-specific keys.
- def is_generic_dispatch_key(dk: DispatchKey) -> bool:
- return dk in {
- DispatchKey.CompositeExplicitAutograd,
- DispatchKey.CompositeExplicitAutogradNonFunctional,
- DispatchKey.CompositeImplicitAutograd,
- DispatchKey.CompositeImplicitAutogradNestedTensor,
- }
- # CUDA specific dispatch keys
- def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
- return dk in {
- DispatchKey.CUDA,
- DispatchKey.QuantizedCUDA,
- DispatchKey.SparseCUDA,
- DispatchKey.SparseCsrCUDA,
- DispatchKey.NestedTensorCUDA,
- DispatchKey.AutogradCUDA,
- }
- # Structured kernel generation is only supported for certain key types;
- # otherwise use old-style
- def is_structured_dispatch_key(dk: DispatchKey) -> bool:
- return dk in STRUCTURED_DISPATCH_KEYS
- def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
- # For now, ufunc dispatch keys coincide with structured keys
- return dk in UFUNC_DISPATCH_KEYS
- # This is oddly named ScalarType and not DType for symmetry with C++
- class ScalarType(Enum):
- Byte = auto()
- Char = auto()
- Short = auto()
- Int = auto()
- Long = auto()
- Half = auto()
- Float = auto()
- Double = auto()
- ComplexHalf = auto()
- ComplexFloat = auto()
- ComplexDouble = auto()
- Bool = auto()
- BFloat16 = auto()
- def __str__(self) -> str:
- return self.name
- @staticmethod
- def maybe_parse(value: str) -> Optional["ScalarType"]:
- for k, v in ScalarType.__members__.items():
- if k == value:
- return v
- return None
- @staticmethod
- def parse(value: str) -> "ScalarType":
- mb_r = ScalarType.maybe_parse(value)
- assert mb_r is not None, f"unknown dtype {value}"
- return mb_r
- @staticmethod
- def parse_set(values: str) -> OrderedSet["ScalarType"]:
- dtypes: OrderedSet[ScalarType] = OrderedSet()
- for value in values.split(", "):
- if value in DTYPE_CLASSES:
- dtypes.update(DTYPE_CLASSES[value])
- else:
- dtypes.add(ScalarType.parse(value))
- return dtypes
- DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {}
- # NB: Integral doesn't include boolean
- DTYPE_CLASSES["Integral"] = OrderedSet(
- [
- ScalarType.Byte,
- ScalarType.Char,
- ScalarType.Int,
- ScalarType.Long,
- ScalarType.Short,
- ]
- )
- # NB: Floating doesn't include low precision types
- DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
- DTYPE_CLASSES["Complex"] = OrderedSet(
- [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
- )
- DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
- DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
- DTYPE_CLASSES["FloatingAndComplex"] = (
- DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
- )
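- # Illustrative example of ScalarType.parse_set with the classes above
- # (the input string is a made-up combination, not taken from the yaml):
- #   ScalarType.parse_set("Integral, Half") yields the OrderedSet
- #   {Byte, Char, Int, Long, Short, Half}.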
- # Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
- # NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
- # to process it. Most logic will ignore keys it doesn't understand, so your
- # new key will get silently ignored until you hook in logic to deal with it.
- class UfuncKey(Enum):
- # These are low level keys that represent exactly one particular
- # instantiation of the kernel produced by codegen
- CUDAFunctor = auto()
- CUDAFunctorOnOther = auto()
- CUDAFunctorOnSelf = auto()
- CPUScalar = auto()
- CPUVector = auto()
- # These are the ones users will usually specify, and
- # implicitly "fill in" the low level keys
- ScalarOnly = auto() # CUDA*, CPUScalar
- Generic = auto() # CUDA*, CPU*
- def __str__(self) -> str:
- return self.name
- @staticmethod
- def parse(value: str) -> "UfuncKey":
- for k, v in UfuncKey.__members__.items():
- if k == value:
- return v
- raise AssertionError(f"unknown ufunc key {value}")
- class DeviceCheckType(Enum):
- NoCheck = 0
- ExactSame = 1
- class ViewSchemaKind(Enum):
- aliasing = auto()
- aliasing_inplace = auto()
- non_aliasing = auto()
- # The basic input to the code generation is native_functions.yaml.
- # The name "native", BTW, comes from the distinction between native
- # functions and legacy TH functions. The legacy TH functions are gone,
- # but the "native" descriptor has stuck.
- #
- # NativeFunction models a single entry in native_functions.yaml. Its
- # fields roughly correspond to what you would see in the YAML itself,
- # but after canonicalization and parsing has occurred.
- #
- # You can see some of the overall design patterns for how we set up
- # dataclasses in this class, but we will defer a complete discussion
- # of this to FunctionSchema.
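- # For orientation, a hypothetical native_functions.yaml entry that this class
- # models might look like (field names are real, values are made up):
- #
- #   - func: my_op(Tensor self, Tensor other) -> Tensor
- #     variants: function, method
- #     dispatch:
- #       CPU: my_op_cpu
- #       CUDA: my_op_cuda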
- @dataclass(frozen=True)
- class NativeFunction:
- # The namespace for this operator. For example, if we have "at::add"
- # then the namespace would be "at". This enables ops to be registered
- # through the same DSL with a custom namespace. If not specified, the
- # default namespace would be "at".
- namespace: str
- # The function schema of the operator in question. This schema
- # has been parsed; see FunctionSchema for more about its structure.
- # (This type is quoted as we are forward referencing a type
- # defined later in the file. I opted for this ordering of the
- # classes for expository clarity.)
- func: "FunctionSchema"
- # Whether or not to generate mutable tensor arguments like regular
- # ones
- use_const_ref_for_mutable_tensors: bool
- # Whether or not to omit automatic generation of a DeviceGuard
- device_guard: bool
- # How to emit automatic generation of device check
- device_check: DeviceCheckType
- # What python module to put the function in
- python_module: Optional[str]
- # TODO: figure out what this does
- category_override: Optional[str]
- # If no variants are specified in native_functions.yaml, this is
- # assumed to be {'function'}.
- variants: Set[Variant]
- # Whether or not we should skip generating registrations for
- # this kernel. This is a bit of a double-edged sword, as manual
- # registrations don't participate in codegen-based selective build!
- manual_kernel_registration: bool
- # Whether or not to skip generating TensorMethod/Functions bindings
- # for this kernel. Technically, this doesn't actually skip generating
- # the binding; instead, the binding gets generated to __dispatch_{funcname}
- # so you can make use of the normal binding if you need it.
- manual_cpp_binding: bool
- # The location in the YAML file where this native function entry was
- # defined. This is for conveniently reporting error messages!
- loc: "Location"
- # A list of operators that are expected to be auto-generated for this NativeFunction.
- # Note: This list isn't actually directly used by the codegen to generate anything.
- # Instead, the codegen figures out what operators to generate purely based off of
- # function schema, and uses the autogen declarations to error check.
- # We expect every NativeFunction that gets auto-generated to be explicitly called out
- # in native_functions.yaml
- autogen: List["OperatorName"]
- # If non-empty, this kernel is subject to ufunc codegen.
- # Sorted by ufunc_key
- ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
- # Whether or not this out function is a "structured kernel". Structured
- # kernels are defined a little differently from normal kernels; in
- # particular, their shape checking logic is defined separately from
- # the kernel. Only out functions can be structured; other functions
- # delegate to the out function using the structured_delegate keyword.
- # Every structured kernel must have at least an out and a functional
- # variant.
- structured: bool
- # Whether or not this non-out function is a structured kernel, defined
- # in terms of the out kernel referenced by the string here.
- structured_delegate: Optional["OperatorName"]
- # Only valid for structured kernels. Specifies an alternative class to
- # inherit from when defining the meta class for the structured
- # operator. This will usually be TensorIteratorBase. This also
- # changes the semantics of set_output to call the parent class.
- structured_inherits: Optional[str]
- # Structured kernels can declare elements as "precomputed". These elements
- # are returned by the meta function in one struct and passed to the impl
- # function in lieu of certain kernel arguments that these precomputed
- # elements supersede. Information about the names and types of these
- # precomputed elements and how they correspond to kernel arguments is stored
- # in this member, if applicable.
- precomputed: Optional["Precompute"]
- # Argument names whose default should be excluded from the C++ interface.
- # Intended for resolving overload ambiguities between signatures.
- cpp_no_default_args: Set[str]
- # Note [Abstract ATen methods]
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # An abstract ATen method is one whose dispatch differs between
- # types. These are implemented in derived types (with a
- # standard (throwing) definition in Type). A concrete ATen
- # method is one which has the same dispatch for all types;
- # we just implement it in the base Type. This is exposed
- # in Declarations.yaml via a field named 'abstract'.
- is_abstract: bool
- # Whether or not the NativeFunction contains a backend-agnostic kernel
- has_composite_implicit_autograd_kernel: bool
- has_composite_implicit_autograd_nested_tensor_kernel: bool
- has_composite_explicit_autograd_kernel: bool
- has_composite_explicit_autograd_non_functional_kernel: bool
- # Tags are used to describe semantic information about (groups of) operators
- # that isn't easily inferable directly from the operator's schema.
- tags: Set[str]
- # NB: The benefit of defining a dataclass is that we automatically get
- # a constructor defined for all the fields we specify. No need
- # to explicitly write it out.
- # We parse both the NativeFunction and backend-specific information about it, which is stored in a corresponding BackendIndex.
- @staticmethod
- def from_yaml(
- ei: Dict[str, object],
- loc: "Location",
- valid_tags: Set[str],
- ignore_keys: Optional[Set[DispatchKey]] = None,
- ) -> Tuple[
- "NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
- ]:
- """
- Parse a NativeFunction from a dictionary as directly parsed
- from native_functions.yaml
- """
- e = ei.copy()
- funcs = e.pop("func")
- assert isinstance(funcs, str), f"not a str: {funcs}"
- # only support one level of namespace. E.g., aten::add
- namespace_helper = NamespaceHelper.from_namespaced_entity(
- namespaced_entity=funcs, max_level=1
- )
- namespace = namespace_helper.get_cpp_namespace(default="aten")
- func = FunctionSchema.parse(namespace_helper.entity_name)
- cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
- assert isinstance(cpp_no_default_args_list, list)
- cpp_no_default_args = set(cpp_no_default_args_list)
- use_const_ref_for_mutable_tensors = e.pop(
- "use_const_ref_for_mutable_tensors", False
- )
- assert isinstance(use_const_ref_for_mutable_tensors, bool)
- variants_s = e.pop("variants", "function")
- assert isinstance(variants_s, str)
- variants: Set[Variant] = set()
- for v in variants_s.split(", "):
- if v == "function":
- variants.add(Variant.function)
- elif v == "method":
- variants.add(Variant.method)
- else:
- raise AssertionError(f"illegal variant {v}")
- manual_kernel_registration = e.pop("manual_kernel_registration", False)
- assert isinstance(
- manual_kernel_registration, bool
- ), f"not a bool: {manual_kernel_registration}"
- manual_cpp_binding = e.pop("manual_cpp_binding", False)
- assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
- device_guard = e.pop("device_guard", True)
- assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
- device_check_s = e.pop("device_check", None)
- assert device_check_s is None or isinstance(
- device_check_s, str
- ), f"not a str: {device_check_s}"
- device_check: DeviceCheckType
- if device_check_s is None:
- device_check = DeviceCheckType.ExactSame
- else:
- device_check = DeviceCheckType[device_check_s]
- structured = e.pop("structured", False)
- assert isinstance(structured, bool), f"not a bool: {structured}"
- structured_delegate_s = e.pop("structured_delegate", None)
- assert structured_delegate_s is None or isinstance(
- structured_delegate_s, str
- ), f"not a str: {structured_delegate_s}"
- assert structured_delegate_s is None or "::" not in structured_delegate_s, (
- "namespace is not supported in structured delegate,"
- " using the same namespace as the native function"
- )
- structured_delegate: Optional[OperatorName] = None
- if structured_delegate_s is not None:
- structured_delegate = OperatorName.parse(structured_delegate_s)
- structured_inherits = e.pop("structured_inherits", None)
- assert structured_inherits is None or isinstance(
- structured_inherits, str
- ), f"not a str: {structured_inherits}"
- assert structured_inherits is None or "::" not in structured_inherits, (
- "namespace is not supported in structured inherits,"
- " using the same namespace as the native function"
- )
- python_module = e.pop("python_module", None)
- assert python_module is None or isinstance(
- python_module, str
- ), f"not a str: {python_module}"
- assert (
- python_module is None or Variant.method not in variants
- ), "functions in modules cannot be methods"
- category_override = e.pop("category_override", None)
- assert category_override is None or isinstance(
- category_override, str
- ), f"not a str: {category_override}"
- precomputed_dict = e.pop("precomputed", None)
- assert precomputed_dict is None or structured is True
- precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
- tags_inp = e.pop("tags", [])
- if isinstance(tags_inp, str):
- tags_inp = [tags_inp]
- assert isinstance(tags_inp, list)
- tags: Set[str] = set()
- for t in tags_inp:
- assert len(valid_tags) > 0
- # TODO: verify that the tag is valid and has an entry in tags.yaml
- if t in valid_tags:
- tags.add(t)
- else:
- raise AssertionError(f"illegal tag {t}")
- from torchgen.api import cpp
- raw_dispatch = e.pop("dispatch", None)
- assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
- dispatch: Dict[DispatchKey, BackendMetadata] = {}
- num_dispatch_keys: int = 0
- if raw_dispatch is not None:
- assert not manual_kernel_registration, (
- "cannot specify both manual_kernel_registration and dispatch; with "
- "manual registration, dispatch has no effect!"
- )
- redundant_composite_implicit_autograd = False
- for ks, v in raw_dispatch.items():
- if ks == "__line__":
- continue # not worth tracking line numbers for dispatch entries
- assert isinstance(ks, str), e
- for k in ks.split(","):
- dispatch_key = DispatchKey.parse(k.strip())
- num_dispatch_keys += 1
- if ignore_keys and dispatch_key in ignore_keys:
- continue
- assert dispatch_key in dispatch_keys, (
- f"Dispatch key {dispatch_key} of kernel {v} "
- "is not a supported dispatch key."
- )
- # We only allow at most 3 levels of namespace for kernels.
- # We will append "native" to a custom kernel namespace.
- namespace_helper = NamespaceHelper.from_namespaced_entity(
- v, max_level=3
- )
- kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
- # Why is 'structured' included? External backends (e.g.
- # XLA) opt into which ops are structured independently
- # of which in-tree ops are structured
- dispatch[dispatch_key] = BackendMetadata(
- kernel=namespace_helper.entity_name,
- structured=structured
- and is_structured_dispatch_key(dispatch_key),
- cpp_namespace=(kernel_namespace + "::native"),
- )
- if (
- dispatch_key is DispatchKey.CompositeImplicitAutograd
- and v == cpp.name(func)
- ):
- redundant_composite_implicit_autograd = True
- # We count the total number of dispatch keys (including ignored ones) so that a
- # dispatch table whose backend keys are all ignored, leaving only a necessarily-kept
- # CompositeImplicitAutograd entry, is not treated as redundant.
- assert not (
- num_dispatch_keys == 1 and redundant_composite_implicit_autograd
- ), (
- "unnecessary dispatch table for this function; just delete the dispatch "
- "key entirely"
- )
- # if a function is a structured delegate, deleting the dispatch
- # table is NOT semantics preserving
- assert (
- structured_delegate
- or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
- or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
- or num_dispatch_keys != 1
- ), (
- f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
- f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "
- "name, then delete the dispatch table"
- )
- elif not structured and structured_delegate is None:
- name = str(func.name.name)
- assert not (
- name.startswith("new_")
- or name.endswith("_like")
- # TODO: maybe it's better to test the return
- or (
- func.arguments.tensor_options
- and not func.arguments.has_tensor_arg()
- )
- ), (
- f"expected {name} to have a CompositeExplicitAutograd "
- "dispatch entry, but there was no dispatch table. Factory functions "
- "should not have implicit dispatch as they should not be decomposed "
- "for __torch_dispatch__"
- )
- dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
- cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
- )
- composites_in_dispatch = [
- d
- for d in dispatch
- if d == DispatchKey.CompositeExplicitAutograd
- or d == DispatchKey.CompositeExplicitAutogradNonFunctional
- or d == DispatchKey.CompositeImplicitAutograd
- or d == DispatchKey.CompositeImplicitAutogradNestedTensor
- ]
- assert len(composites_in_dispatch) <= 1 or (
- len(composites_in_dispatch) == 2
- and (
- DispatchKey.CompositeExplicitAutogradNonFunctional
- not in composites_in_dispatch
- )
- and (
- DispatchKey.CompositeImplicitAutogradNestedTensor
- in composites_in_dispatch
- )
- ), (
- "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
- "or CompositeImplicitAutograd on a single kernel; each "
- "strictly subsumes the other. If you wanted to provide an explicit autograd "
- "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
- )
- autogen_str = e.pop("autogen", "")
- assert isinstance(autogen_str, str)
- autogen = (
- []
- if autogen_str == ""
- else [OperatorName.parse(x) for x in autogen_str.split(", ")]
- )
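- # Illustrative 'autogen' value as it would appear in the yaml (the operator
- # names are hypothetical): "autogen: my_op.out, my_op_functional" parses into
- # two OperatorName entries.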
- raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
- ufunc_inner_loop = {}
- if isinstance(raw_ufunc_inner_loop, str):
- ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
- raw_ufunc_inner_loop, UfuncKey.Generic
- )
- elif isinstance(raw_ufunc_inner_loop, dict):
- for k, vo in raw_ufunc_inner_loop.items():
- if k == "__line__":
- continue
- assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
- assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}"
- ufunc_key = UfuncKey.parse(k)
- ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
- else:
- raise AssertionError(
- f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
- )
- # Program the BackendIndex for the implicit dispatch entry from ufunc
- if ufunc_inner_loop:
- assert structured, "ufunc must be structured"
- # Delay import ufunc here to avoid circular import issue
- # See: https://github.com/pytorch/pytorch/issues/81294
- import torchgen.api.ufunc as ufunc
- for dispatch_key in UFUNC_DISPATCH_KEYS:
- assert (
- dispatch_key not in dispatch
- ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
- dispatch[dispatch_key] = BackendMetadata(
- kernel=ufunc.schema_kernel_name(func, dispatch_key),
- structured=True,
- cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
- )
- if structured_delegate:
- # Structured functions MUST have a dispatch table
- is_abstract = True
- else:
- is_abstract = (
- dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
- and dispatch.keys()
- != {DispatchKey.CompositeImplicitAutogradNestedTensor}
- and dispatch.keys()
- != {
- DispatchKey.CompositeImplicitAutograd,
- DispatchKey.CompositeImplicitAutogradNestedTensor,
- }
- )
- has_composite_implicit_autograd_kernel = (
- DispatchKey.CompositeImplicitAutograd in dispatch.keys()
- )
- has_composite_implicit_autograd_nested_tensor_kernel = (
- DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys()
- )
- has_composite_explicit_autograd_kernel = (
- DispatchKey.CompositeExplicitAutograd in dispatch.keys()
- )
- has_composite_explicit_autograd_non_functional_kernel = (
- DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch.keys()
- )
- # We aren't going to store dispatch metadata inline in NativeFunctions;
- # instead it is separately indexed by backend (so other backends can
- # add more dispatch entries after the fact). Reindex the individual
- # metadata by OperatorName!
- backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
- # don't care if it exists or not; make it easier to use this function
- # with other yaml parsers that aren't setting __line__ in the dict
- e.pop("__line__", None)
- assert not e, f"leftover entries: {e}"
- # Asserts that we can't do in post_init, because they rely on backend-specific info
- if structured_delegate is not None:
- for key in STRUCTURED_DISPATCH_KEYS:
- assert key not in dispatch, (
- f"if structured_delegate, then must not have {key} in dispatch dictionary "
- "(it is delegated!)"
- )
- return (
- NativeFunction(
- func=func,
- use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
- variants=variants,
- structured=structured,
- structured_delegate=structured_delegate,
- structured_inherits=structured_inherits,
- precomputed=precomputed,
- autogen=autogen,
- ufunc_inner_loop=ufunc_inner_loop,
- manual_kernel_registration=manual_kernel_registration,
- manual_cpp_binding=manual_cpp_binding,
- python_module=python_module,
- category_override=category_override,
- device_guard=device_guard,
- device_check=device_check,
- loc=loc,
- cpp_no_default_args=cpp_no_default_args,
- is_abstract=is_abstract,
- has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
- has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
- has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
- has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
- tags=tags,
- namespace=namespace,
- ),
- backend_metadata,
- )
- def validate_unstructured(self) -> None:
- # TODO: probably better to accumulate these errors and report them all
- # at once
- assert not self.structured, (
- "This function is structured, but there was "
- "no valid functional variant of it."
- )
- assert self.structured_delegate, (
- "This function delegates to another structured out function, "
- "but no valid function was found (the delegate may not exist, or it has the wrong type)"
- )
- # __post_init__ functions in dataclasses can be used to do extra
- # validation after construction.
- #
- # Notice that we don't do any type validation here. In fact, we
- # rely exclusively on mypy to check if you've done types correctly!
- # Validation is for nontrivial invariants that cannot be (conveniently)
- # encoded in the type system.
- def __post_init__(self) -> None:
- if self.func.arguments.out:
- assert self.variants == {Variant.function}, (
- "Native functions with out arguments MUST "
- "be declared with only function variant; e.g., variants: function; "
- "otherwise you will tickle a Python argument binding bug "
- "(which usually manifests itself as the result variable being undefined.)"
- )
- if self.structured:
- assert self.func.kind() == SchemaKind.out, (
- "Put structured field on the out= "
- "variant of a function; did you mean structured_delegate?"
- )
- assert (
- self.device_guard
- ), "device_guard: False is not respected by structured kernels"
- if self.structured_delegate:
- assert self.func.kind() != SchemaKind.out, (
- "structured_delegate field not allowed "
- "on out= functions; did you mean structured?"
- )
- assert (
- self.device_guard
- ), "device_guard: False is not respected by structured kernels"
- # Technically, with the asserts above, this assert can never
- # fire
- assert not (
- self.structured and self.structured_delegate
- ), "Cannot have both structured and structured_delegate on function"
- defaulted_arguments = {
- a.name for a in self.func.schema_order_arguments() if a.default is not None
- }
- invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
- assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
- if self.structured_inherits is not None:
- assert (
- self.structured
- ), "structured_inherits must also imply structured: True"
- if str(self.func.name).startswith("_foreach"):
- assert self.device_check == DeviceCheckType.NoCheck, (
- "foreach kernels fall back to slow path when tensor are on different devices, "
- "device_check not allowed to be enabled"
- )
- # NB: if your function accidentally has rand/dropout/... in its name
- # but is not actually random, feel free to amend this to special-case it
- if (
- "rand" in str(self.func.name)
- or (
- "dropout" in str(self.func.name)
- # Backwards of dropout is typically deterministic
- and "backward" not in str(self.func.name)
- and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
- )
- or self.func.arguments.has_generator_arg()
- ):
- assert "nondeterministic_seeded" in self.tags, str(self.func.name)
- @property
- def has_composite_kernel(self) -> bool:
- return (
- self.has_composite_implicit_autograd_kernel
- or self.has_composite_explicit_autograd_kernel
- or self.has_composite_explicit_autograd_non_functional_kernel
- ) or (
- self.has_composite_implicit_autograd_kernel
- and self.has_composite_implicit_autograd_nested_tensor_kernel
- )
- @property
- def is_view_op(self) -> bool:
- rets = self.func.returns
- is_non_mutating_view = len(rets) > 0 and any(
- r.annotation is not None and not r.annotation.is_write for r in rets
- )
- # See Note [resize_ in Functionalization] for more details
- is_inplace_view = (
- "inplace_view" in self.tags and str(self.func.name) != "resize_"
- )
- is_wildcard_view = any(
- inp.annotation is not None and "*" in inp.annotation.alias_set_after
- for inp in self.func.schema_order_arguments()
- )
- return is_non_mutating_view or is_inplace_view or is_wildcard_view
- @property
- def view_schema_kind(self) -> ViewSchemaKind:
- if self.is_view_op and self.func.name.name.inplace:
- assert "inplace_view" in self.tags
- return ViewSchemaKind.aliasing_inplace
- if self.is_view_op:
- return ViewSchemaKind.aliasing
- else:
- return ViewSchemaKind.non_aliasing
- @property
- def root_name(self) -> str:
- return self.func.name.name.base
- @property
- def part_of_structured_group(self) -> bool:
- return self.structured or self.structured_delegate is not None
- class SchemaKind(Enum):
- functional = auto()
- inplace = auto()
- out = auto()
- mutable = auto()
- scratch = auto()
- # A structured kernel is guaranteed to have a functional and out variant, and
- # optionally an inplace variant.
- #
- # NB: we create NativeFunctionsGroup *even if* the function is not
- # actually annotated structured. Test the structured boolean to see if it
- # actually is structured or not.
- @dataclass(frozen=True)
- class NativeFunctionsGroup:
- functional: NativeFunction
- inplace: Optional[NativeFunction]
- mutable: Optional[NativeFunction]
- out: NativeFunction
- @property
- def structured(self) -> bool:
- # Whether or not the operator has a meta() function. This information is backend-agnostic.
- return self.out.structured
- def __post_init__(self) -> None:
- test_sig: FunctionSchema = self.functional.func.signature()
- for f in self.functions():
- if test_sig != f.func.signature():
- raise AssertionError(
- "NativeFunctionsGroup constructed from two NativeFunctions "
- f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
- )
- if self.structured != f.part_of_structured_group:
- raise AssertionError(
- "NativeFunctionsGroup constructed from structured and unstructured "
- f"functions: {self.out.func.name} and {f.func.name}"
- )
- assert self.functional.func.kind() == SchemaKind.functional
- assert self.out.func.kind() == SchemaKind.out
- assert self.functional.namespace == self.out.namespace
- if self.inplace is not None:
- assert self.inplace.func.kind() == SchemaKind.inplace
- assert self.inplace.namespace == self.functional.namespace
- if self.mutable is not None:
- assert self.mutable.func.kind() == SchemaKind.mutable
- assert self.mutable.namespace == self.functional.namespace
- # See Note [Overload Ambiguity With Functional Variants]
- assert self.functional.func.name.name.functional_overload
- if self.structured:
- # For now, structured composite kernels are not supported (need some
- # design work to figure out how to make the composite case work)
- assert (
- not self.out.has_composite_implicit_autograd_kernel
- and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
- )
- assert self.functional.structured_delegate == self.out.func.name, (
- f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
- f"but its actual delegate is {self.out.func.name}"
- )
- if self.inplace is not None:
- assert self.inplace.structured_delegate == self.out.func.name
- generated_fns = sorted(
- [str(f.func.name) for f in self.functions() if "generated" in f.tags]
- )
- generated_fns_str = ", ".join(str(x) for x in generated_fns)
- expected_generated_fns: Set[str] = set()
- for f in self.functions():
- expected_generated_fns.update(str(op) for op in f.autogen)
- expected_generated_fns_str = ", ".join(
- str(x) for x in sorted(expected_generated_fns)
- )
- if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
- raise RuntimeError(
- f"The codegen expects to be able to generate '{generated_fns_str}'."
- " In order to generate them however, we expect them to be called out explicitly in the yaml."
- f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
- )
- if expected_generated_fns_str != generated_fns_str:
- raise RuntimeError(
- f"The codegen expects to be able to generate '{generated_fns_str}'."
- f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
- f" Instead, it found 'autogen: {expected_generated_fns_str}'"
- )
- def signature(self) -> "FunctionSchema":
- return self.out.func.signature()
- def functions(self) -> Iterator[NativeFunction]:
- yield self.functional
- yield self.out
- if self.inplace is not None:
- yield self.inplace
- if self.mutable is not None:
- yield self.mutable
- @property
- def root_name(self) -> str:
- return self.functional.root_name
- @staticmethod
- def from_dict(
- d: Dict[SchemaKind, NativeFunction]
- ) -> Optional["NativeFunctionsGroup"]:
- assert d
- if len(d) == 1:
- return None
- d = dict(d) # non-destructive updates please
- functional = d.pop(SchemaKind.functional, None)
- inplace = d.pop(SchemaKind.inplace, None)
- mutable = d.pop(SchemaKind.mutable, None)
- out = d.pop(SchemaKind.out, None)
- assert not d
- assert functional is not None
- # There are a few operators which only have functional/inplace variants;
- # these don't count as structured for our purposes here
- if out is None:
- return None
- # assuming all variants have the same namespace
- return NativeFunctionsGroup(
- functional=functional,
- inplace=inplace,
- mutable=mutable,
- out=out,
- )
- @dataclass(frozen=True)
- class BackendMetadata:
- # The name of the backend kernel, for a given operator
- # and in-tree backend. These names come directly from the 'dispatch' field
- # in native_functions.yaml. The dispatch entry is optional; in that
- # case, that is equivalent to having written:
- #
- # dispatch:
- # CompositeImplicitAutograd: $operator_name
- kernel: str
- # Whether or not the operator has a structured kernel implemented, for this particular backend.
- # For in-tree backends, they all have the same value for structured; this is listed
- # in native_functions.yaml.
- # However, external backends like XLA can independently toggle which ops are structured.
- structured: bool
- # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
- cpp_namespace: str
- def supports_symint(self) -> bool:
- return "_symint" in self.kernel
- @dataclass(frozen=True)
- class UfuncInnerLoop:
- name: str
- supported_dtypes: OrderedSet[ScalarType]
- # key is stored here because it affects the semantics of name,
- # so it's helpful to have them together for further processing
- ufunc_key: UfuncKey
- @staticmethod
- def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
- name, supported_dtypes_str = value.split(" ", 1)
- assert supported_dtypes_str[0] == "("
- assert supported_dtypes_str[-1] == ")"
- supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
- for k in supported_dtypes_str[1:-1].split(", "):
- supported_dtypes |= ScalarType.parse_set(k)
- return UfuncInnerLoop(
- name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
- )
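- # Illustrative input, following the format handled above (the kernel name is
- # hypothetical): UfuncInnerLoop.parse("my_kernel (Integral, Half)", UfuncKey.Generic)
- # yields name="my_kernel" and supported_dtypes={Byte, Char, Int, Long, Short, Half}.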
- # BackendIndex represents a backend.
- # The BackendIndex encodes per-operator information that is potentially different
- # for each backend. The most obvious example is the name of the kernel
- # (the 'dispatch' entry in native_functions.yaml).
- # However, there can be other examples of different backends having different information.
- # External backends can choose to opt their kernels into being structured independently from in-tree backends,
- # which means that this information isn't inherently tied to a NativeFunction; it's different per backend.
- @dataclass(frozen=True)
- class BackendIndex:
- dispatch_key: DispatchKey
- # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
- # All in-tree ops use out kernels, while XLA uses functional kernels.
- use_out_as_primary: bool
- # Whether the backend requires a device guard, and device checks.
- # For in-tree backends, this is currently just CUDA/HIP
- # For out-of-tree backends, this is currently just Intel XPU
- device_guard: bool
- # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
- external: bool
- # Other backend-specific information that is on a per-operator basis
- index: Dict["OperatorName", BackendMetadata]
- @staticmethod
- def grow_index(
- parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
- child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
- ) -> None:
- for k, v in child_index.items():
- for op_name, metadata in v.items():
- assert (
- op_name not in parent_index[k]
- ), f"duplicate operator {op_name} for dispatch key {k}"
- parent_index[k][op_name] = metadata
- def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
- if self.use_out_as_primary:
- return g.out
- else:
- return g.functional
- def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
- m = self.get_kernel(g)
- return m is not None
- def get_kernel(
- self, g: Union[NativeFunction, NativeFunctionsGroup]
- ) -> Optional[BackendMetadata]:
- if isinstance(g, NativeFunction):
- f = g
- elif isinstance(g, NativeFunctionsGroup):
- f = self.primary(g)
- else:
- assert_never(g)
- if f.func.name not in self.index:
- return None
- return self.index[f.func.name]
- def native_function_class_name(self) -> Optional[str]:
- if self.external:
- return f"{str(self.dispatch_key)}NativeFunctions"
- else:
- # TODO: This discrepancy isn't required; we could also generate
- # a class for in-tree kernels. It'll just require carefully
- # updating every kernel definition + callsite of every in-tree aten kernel.
- return None
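- # e.g., a BackendIndex with external=True and dispatch_key=DispatchKey.XLA
- # yields "XLANativeFunctions"; in-tree backends yield None.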
- # The function schema is undoubtedly the most important data structure
- # in all of the codegen, as it defines the type signature for operators,
- # and most of the code generation we do is type directed (e.g., look at
- # the types, decide what to do. Think about how we code generate
- # C++ function stubs!)
- #
- # We will also see in this class the general structure for how we model
- # data in this code generation. A few notable properties to point out
- # ahead of time:
- #
- # - These dataclasses are a *lossless* representation of the strings
- # they are parsed from. In fact, we assert that given the
- # information stored in the dataclass, we can exactly reconstruct
- # the string we parsed from (and assert this inside the parse
- # definition). There are a few reasons for this:
- #
- # - If you find that it is difficult to reconstruct the string
- # given a dataclass, that is a clue that your data
- # representation is wrong.
- #
- # - It helps ensure that all relevant information is present
- # in the dataclass, so that downstream users aren't tempted
- # to reparse the original string to get some information
- # that was omitted.
- #
- # - It forces you to represent the data in-memory in the same way
- # it is recorded textually, which makes the dataclasses easier
- # to understand for someone who is familiar with the
- # textual format. (As a tradeoff, it means you have to model
- # the syntax, even when it is inconvenient. But maybe that means
- # the syntax is bad!) If you don't understand the internal
- # representation, go look at the printing code to see how
- # it maps onto the surface syntax!
- #
- # - It makes it easy to test the parsing code, as parsing code
- # that is inconsistent with the string code will fail early
- # and loudly. (As a tradeoff, it makes the parsing code a bit
- # brittle; in particular, trivial whitespace changes are
- # likely to trigger an assertion error.)
- #
- # In general, try to make the __str__ code as simple as possible
- # (even at the cost of more complex parsing logic.) Additionally,
- # try to minimize redundancy in data representation. (Precomputed
- # fields are OK though: they are defined as a simple function on
- # the canonical representation in question.)
- #
- # - These dataclasses are all frozen; once constructed their
- # values never change. This makes it easy to tell where any
- # given data came from: just look to the constructor. As a
- # tradeoff, you can't easily "decorate" a schema with extra
- # information from a post-facto analysis. We impose this
- # restriction to make these structures more understandable.
- #
- @dataclass(frozen=True)
- class FunctionSchema:
- # The name of the operator this function schema describes.
- name: "OperatorName"
- arguments: "Arguments"
- # TODO: Need to handle collisions with argument names at some point
- returns: Tuple["Return", ...]
- def schema_order_arguments(self) -> Iterator["Argument"]:
- return itertools.chain(
- self.arguments.flat_positional,
- self.arguments.flat_kwarg_only,
- self.arguments.out,
- )
- decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
- @staticmethod
- def parse(func: str) -> "FunctionSchema":
- # We should probably get a proper parser here
- decls = FunctionSchema.decl_re.findall(func)
- assert len(decls) == 1, f"Invalid function schema: {func}"
- ops, args, return_decl = decls[0]
- name = OperatorName.parse(ops)
- arguments = Arguments.parse(args)
- returns = parse_returns(return_decl)
- r = FunctionSchema(name=name, arguments=arguments, returns=returns)
- assert str(r) == func, f"{str(r)} != {func}"
- return r
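- # Illustrative round trip, relying on the lossless-printing invariant described
- # above (the schema string mirrors real ATen syntax):
- #   s = "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
- #   assert str(FunctionSchema.parse(s)) == s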
- def returns_are_aliased(self) -> bool:
- # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
- return any(
- r
- for r in self.returns
- if r.annotation is not None and r.annotation.is_write
- )
- def __post_init__(self) -> None:
- for arg, ret in zip(self.arguments.out, self.returns):
- assert arg.annotation == ret.annotation, (
- "Out arguments must have matching return Tensor; furthermore, "
- "the ith-argument needs to correspond to the ith return"
- )
- # We also enforce that if you have any mutable, positional args, then they are not returned.
- # This makes it easier to group these functions properly with their functional/out= counterparts.
- for a in self.arguments.post_self_positional_mutable:
- assert not any(
- a.annotation == r.annotation for r in self.returns
- ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
- # Invariant: we expect out arguments to appear as keyword arguments in the schema.
- # This means that all mutable returns should be aliased to a keyword argument
- # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
- # See Note [is_out_fn]
- out_and_self = list(self.arguments.out) + [
- arg for arg in self.arguments.flat_positional if arg.name == "self"
- ]
- mutable_returns = [
- ret
- for ret in self.returns
- if ret.annotation is not None and ret.annotation.is_write
- ]
- immutable_returns = [
- ret
- for ret in self.returns
- if ret.annotation is None or not ret.annotation.is_write
- ]
- # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
- # because:
- # (1) It's more annoying to handle properly
- # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
- # Instead, we expect the (a!) argument to not be returned.
- assert (
- len(mutable_returns) == 0 or len(immutable_returns) == 0
- ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
- for ret in mutable_returns:
- assert any([ret.annotation == arg.annotation for arg in out_and_self]), (
- 'All mutable returns must be aliased either to a keyword argument, or to "self". '
- "Did you forget to mark an out argument as keyword-only?"
- )
- if self.arguments.out:
- # out= ops that return their mutable inputs are only really useful for method chaining.
- # And method chaining is only really useful if the thing you're returning is a plain Tensor.
- # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
- # and all other types of out= op schemas should return void.
- # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
- if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
- assert len(self.returns) == 0, (
- "out= ops that accept tensor lists as out arguments "
- "are expected to have no return type (since you can't do method chaining on them)"
- )
- else:
- # mutable keyword arguments whose names have a _scratch_ prefix are
- # scratch tensors for memory planning and should not be returned
- assert len(
- [
- arg
- for arg in self.arguments.out
- if not arg.name.startswith("_scratch_")
- ]
- ) == len(
- self.returns
- ), "Must return as many arguments as there are out arguments, or no return at all"
- if self.name.name.inplace:
- self_a = self.arguments.self_arg
- assert (
- self_a
- and self_a.argument.annotation
- and self_a.argument.annotation.is_write
- )
- if self_a.argument.type == BaseType(BaseTy.Tensor):
- # All inplace ops with an ordinary `Tensor self` argument should return self,
- # to allow for method chaining.
- assert (
- len(self.returns) == 1
- and self.returns[0].annotation == self_a.argument.annotation
- )
- else:
- # You can't method chain on non-tensor self arguments though (like a List[Tensor])
- # so in all other cases we expect the return type to be none.
- assert len(self.returns) == 0
- if self.arguments.tensor_options is not None:
- assert self.kind() == SchemaKind.functional, (
- "Found an operator that is not functional or out varuabt, but has tensor options arguments."
- "This is not allowed- tensor options arguments are only allowed for factory functions."
- f"schema: {str(self)}"
- )
- if self.is_functional_fn():
- assert self.kind() == SchemaKind.functional, (
- "Found an operator that is not functional, but its overload contains the string 'functional'."
- "This is a special keyword in the codegen, please use a different overload name."
- f"schema: {str(self)}"
- )
- def is_functional_fn(self) -> bool:
- return "functional" in self.name.overload_name
- def is_out_fn(self) -> bool:
- # Note [is_out_fn]
- #
- # out functions are the variants which take an explicit out= argument
- # to populate into. We need to know if a schema corresponds to an
- # out function for several reasons:
- #
- # - They codegen differently in C++ API
- # - codegen to at::add_out rather than at::add
- # - out argument is moved to front of C++ argument list
- #
- # out functions are DEFINED to be any function with a keyword-only
- # argument that is mutable. In principle, this could lead to a
- # false positive if you define a function that mutates a
- # kwarg only argument, but this isn't the "true" output of this
- # function. A more robust definition that would work in this
- # case would also look at:
- #
- # - The output types. Out functions take in the arguments
- # they mutate and then return them again; this is sort
- # of "definitionally" what makes something an out function.
- # Historically, we DO check this for consistency.
- # - Correspondence with pure variant. An out function
- # should have a signature equivalent to its pure variant,
- # but just with extra kwargs for the output elements. This
- # is difficult to actually check for and historically
- # we only do this check in tools/
- return bool(self.arguments.out)
- def kind(self) -> SchemaKind:
- """
- What kind of schema is this? A functional schema is one
- that returns a newly allocated output; an inplace schema
- modifies the self argument inplace; an out schema writes
- the result into an explicitly provided out argument.
- """
- is_out = bool(self.arguments.out)
- is_scratch = bool(
- [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
- )
- is_inplace = self.name.name.inplace
- is_mutable = any(
- a.annotation is not None and a.annotation.is_write
- for a in self.arguments.post_self_positional
- )
- assert not (is_out and is_inplace)
- # out= and inplace schemas can also have post_self_positional mutable args,
- # but we give precedence to out= and inplace when deciding the schema kind.
- # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
- # to also worry about mutable post_self_positional arguments,
- # but it seems like a much bigger lift to classify them as having a new schema kind.
- # The number of ops that fit in this strange category is small enough that
- # we can probably manually write code for them instead of forcing the codegen to handle them.
- if is_inplace:
- return SchemaKind.inplace
- elif is_scratch:
- assert (
- is_out
- ), "invariant: all scratch operators are expected to be out= operators too"
- return SchemaKind.scratch
- elif is_out:
- assert (
- not is_scratch
- ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
- return SchemaKind.out
- elif is_mutable:
- return SchemaKind.mutable
- else:
- return SchemaKind.functional
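- # Rough illustration of how schemas map onto kinds (schema strings are examples only):
- #   "abs(Tensor self) -> Tensor"                            -> SchemaKind.functional
- #   "abs_(Tensor(a!) self) -> Tensor(a!)"                   -> SchemaKind.inplace
- #   "abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)" -> SchemaKind.out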
- # For every return:
- # - If the return aliases an input, we return the input name
- # - Otherwise, we return None.
- # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
- def aliased_return_names(self) -> List[Optional[str]]:
- outs: List[Optional[str]] = []
- for r in self.returns:
- aliased_args = [
- a
- for a in self.arguments.flat_all
- if a.annotation is not None and a.annotation == r.annotation
- ]
- if len(aliased_args) == 0:
- outs.append(None)
- elif len(aliased_args) == 1:
- outs.append(aliased_args[0].name)
- else:
- aliased_names = ", ".join(a.name for a in aliased_args)
- raise AssertionError(
- f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
- )
- return outs
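- # For example, for a schema like "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
- # aliased_return_names() would be ["out"], since the single return aliases the out argument.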
- def signature(
- self,
- *,
- strip_default: bool = False,
- strip_view_copy_name: bool = False,
- keep_return_names: bool = False,
- ) -> "FunctionSchema":
- """
- Certain schemas are 'related', in that they are simply
- inplace/out/functional versions of the same function. This method
- factors these schemas into the "core" functional signature which
- is equal across all versions.
- Here is the normalization that happens to the schema to convert
- it to a signature:
- - The overload name is stripped (name is retained, since
- it expresses semantic content about what the function does)
- - Inplace is set False
- - Out arguments are stripped
- - Mutable post_self_positional args are converted to returns
- - Mutability annotations are stripped (this is sound
- because you cannot overload on mutability annotation)
- - Return names are stripped since they are not overloadable and
- some variants have return names and some do not
- - TensorOptions are dropped
- because out= variants of factory functions don't include them
- (and we want to be able to pair up factory functions with their out variants)
- Finally, we want to be able to pair up related "view" and their
- corresponding "view_copy" operators. We do this by optionally
- stripping the trailing "_copy" from the base name.
- Example of a mutable op before and after:
- f.func (Mutable operator):
- _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
- f.func (Corresponding functional operator):
- _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950
- f.func.signature() output:
- _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950
- """
- def strip_ret_annotation(r: Return) -> Return:
- return Return(
- name=r.name if keep_return_names else None,
- type=r.type,
- annotation=None,
- )
- base_name = self.name.name.base
- if strip_view_copy_name and base_name.endswith("_copy"):
- base_name = base_name.replace("_copy", "")
- # find mutable inputs that are not originally returned, and convert them to returns
- returns_from_mutable_inputs = tuple(
- # When we're grouping functions we strip the return names,
- # but when we're generating the actual functional variants then we follow
- # a convention for what to name the returns
- Return(
- name=f"{a.name}_out" if keep_return_names else None,
- type=a.type,
- annotation=None,
- )
- for a in itertools.chain(
- # Order is important here (otherwise e.g. inplace with mutable args
- # and out= with mutable args won't have the same signature)
- [self.arguments.self_arg.argument]
- if self.arguments.self_arg is not None
- else [],
- self.arguments.out,
- self.arguments.post_self_positional,
- )
- if a.annotation is not None
- and a.annotation.is_write
- and not any(a.annotation == r.annotation for r in self.returns)
- )
- original_returns = tuple(map(strip_ret_annotation, self.returns))
- # Ordering is important here. We expect the "mutable input" returns to come last.
- returns = original_returns + returns_from_mutable_inputs
- args_sig = self.arguments.signature(strip_default=strip_default)
- # See Note [bernoulli.p schema]
- if str(self.name) == "bernoulli.p":
- args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
- return FunctionSchema(
- name=OperatorName(
- name=BaseOperatorName(
- base=base_name,
- inplace=False,
- dunder_method=self.name.name.dunder_method,
- ),
- overload_name="", # stripped
- ),
- arguments=args_sig,
- returns=returns,
- )
- def view_signature(self) -> "FunctionSchema":
- return self.signature(strip_view_copy_name=True)
- def with_name(self, name: "OperatorName") -> "FunctionSchema":
- return FunctionSchema(
- name=name,
- arguments=self.arguments,
- returns=self.returns,
- )
- @property
- def modifies_arguments(self) -> bool:
- return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
- def has_symint(self) -> bool:
- return self.arguments.has_symint_arg() or any(
- r.type.is_symint_like() for r in self.returns
- )
- def __str__(self) -> str:
- all_arguments_str = str(self.arguments)
- if len(self.returns) == 1:
- returns = str(self.returns[0]) # omit parentheses
- else:
- returns = "(" + ", ".join(map(str, self.returns)) + ")"
- return f"{self.name}({all_arguments_str}) -> {returns}"
- # Here is the rest of the data model, described more briefly.
- # Simplified version for what actually shows up in built-ins.
- # Look at alias_info.h for expanded syntax. If you need the structure,
- # you also need to make this structure recursive so it can be lined
- # up with the type components too. For primitives this isn't really
- # necessary
- @dataclass(frozen=True)
- class Annotation:
- # Typically only has one element. Not actually a set so
- # we can conveniently assume it is canonically ordered
- alias_set: Tuple[str, ...]
- is_write: bool
- alias_set_after: Tuple[str, ...]
- @staticmethod
- def parse(ann: str) -> "Annotation":
- # TODO: implement a proper parser if this gets more ugly
- # Regex Explanation:
- # Example: "a! -> a|b"
- # Group #1: alias before optional '|', required. Matches the first
- # character 'a' in the example
- # Group #2: optional alias set after optional '|', matches empty string
- # in the example
- # Group #3: optional "is write" flag, matches '!' in the example.
- # Group #4: optional section containing arrow, matches " -> a|b" in the
- # example.
- # Group #5: optional alias after set, supports wildcard, matches "a|b"
- # in the example.
- # Group #6: optional sub-section of alias after set, matches "|b" in the
- # example.
- m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
- assert m is not None, f"unrecognized alias annotation {ann}"
- before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
- alias_set = tuple(before_alias.split("|"))
- is_write = m.group(3) == "!"
- assert not (
- is_write and len(alias_set) > 1
- ), f"alias set larger than 1 is not mutable, got {ann} instead."
- after_set = tuple(m.group(5).split("|")) if m.group(5) else tuple()
- assert not (
- len(before_alias) > 1 and len(after_set) > 1
- ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
- r = Annotation(
- alias_set=alias_set, is_write=is_write, alias_set_after=after_set
- )
- assert str(r) == ann, f"{r} != {ann}"
- return r
- def __str__(self) -> str:
- alias_set = "|".join(self.alias_set)
- if self.is_write:
- alias_set = f"{alias_set}!"
- alias_set_after = "|".join(self.alias_set_after)
- if alias_set_after:
- alias_set = f'{alias_set}{" -> "}{alias_set_after}'
- return alias_set
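- # Illustrative Annotation parses (inputs chosen to exercise the grammar above):
- #   Annotation.parse("a")       -> alias_set=("a",), is_write=False, alias_set_after=()
- #   Annotation.parse("a!")      -> alias_set=("a",), is_write=True,  alias_set_after=()
- #   Annotation.parse("a! -> *") -> alias_set=("a",), is_write=True,  alias_set_after=("*",)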
- # The base class for the type system. This is also loosely modeled
- # off of jit_type.h, but we've simplified the hierarchy to focus
- # in on the aspects of the type system that matter for code generation
- # (for example, there's no SingleElementType subclass anymore).
- # You never actually construct a Type; usually it's going to be one
- # of the subclasses. If Python had ADTs this would be one!
- @dataclass(frozen=True)
- class Type:
- @staticmethod
- def parse(t: str) -> "Type":
- r = Type._parse(t)
- assert str(r) == t, f"{r} != {t}"
- return r
- @staticmethod
- def _parse(t: str) -> "Type":
- m = re.match(r"^(.+)\?$", t)
- if m is not None:
- return OptionalType(Type.parse(m.group(1)))
- m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
- if m is not None:
- size = int(m.group(2)) if m.group(2) is not None else None
- return ListType(elem=Type.parse(m.group(1)), size=size)
- # '__torch__.torch.classes.' is the prefix for custom class
- m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
- if m is not None:
- return CustomClassType(m.group(1))
- try:
- return BaseType(BaseTy[t])
- except KeyError as e:
- raise RuntimeError(f"unrecognized type {t}") from e
- def __str__(self) -> str:
- raise NotImplementedError
- # WARNING: These concepts are not very well-defined. For example,
- # is "int?" nullable? How about "int?[]". They are defined
- # so we can conveniently generate legacy Declarations.yaml but
- # really we should probably just remove these at some point
- def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
- raise NotImplementedError
- def is_tensor_like(self) -> bool:
- return self.is_base_ty_like(BaseTy.Tensor)
- def is_generator_like(self) -> bool:
- return self.is_base_ty_like(BaseTy.Generator)
- def is_symint_like(self) -> bool:
- return self.is_base_ty_like(BaseTy.SymInt)
- def is_nullable(self) -> bool:
- raise NotImplementedError
- def is_list_like(self) -> Optional["ListType"]:
- raise NotImplementedError
- # Base types are simple, atomic types with no further structure
- class BaseTy(Enum):
- Generator = auto()
- ScalarType = auto()
- Tensor = auto()
- int = auto()
- Dimname = auto()
- DimVector = auto()
- float = auto()
- str = auto()
- bool = auto()
- Layout = auto()
- Device = auto()
- Scalar = auto()
- MemoryFormat = auto()
- QScheme = auto()
- Storage = auto()
- Stream = auto()
- SymInt = auto()
- ConstQuantizerPtr = auto() # TODO: rename
- @dataclass(frozen=True)
- class BaseType(Type):
- name: BaseTy
- def __str__(self) -> str:
- return f"{self.name.name}"
- def is_base_ty_like(self, base_ty: BaseTy) -> bool:
- return self.name == base_ty
- def is_nullable(self) -> bool:
- return False
- def is_list_like(self) -> Optional["ListType"]:
- return None
- def is_symint_like(self) -> bool:
- return self.name == BaseTy.SymInt
- # Optional types may be specified, or may also be validly given None
- @dataclass(frozen=True)
- class OptionalType(Type):
- elem: Type
- def __str__(self) -> str:
- return f"{self.elem}?"
- def is_base_ty_like(self, base_ty: BaseTy) -> bool:
- return self.elem.is_base_ty_like(base_ty)
- def is_symint_like(self) -> bool:
- return self.elem.is_symint_like()
- def is_nullable(self) -> bool:
- return True
- def is_list_like(self) -> Optional["ListType"]:
- return self.elem.is_list_like()
- # A type representing a PyTorch custom class
- @dataclass(frozen=True)
- class CustomClassType(Type):
- class_name: str
- def __str__(self) -> str:
- """
- Return the class name with the prefix __torch__.torch.classes.
- """
- return f"__torch__.torch.classes.{self.class_name}"
- def is_base_ty_like(self, base_ty: BaseTy) -> bool:
- return False
- def is_symint_like(self) -> bool:
- return False
- def is_nullable(self) -> bool:
- """
- Assume a custom class is not nullable.
- """
- return False
- def is_list_like(self) -> Optional["ListType"]:
- return None
- # List types specify that we may have multiples of an element. We
- # also support explicit sizes on list types, but these have
- # some nontrivial semantics! (However, for C++ API purposes, explicit
- # sizes are mostly erased from the type system.)
- #
- # DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
- # int[] elaborates differently than bool[3]!
- @dataclass(frozen=True)
- class ListType(Type):
- elem: Type
- size: Optional[int]
- def __str__(self) -> str:
- size = f"{self.size}" if self.size else ""
- return f"{self.elem}[{size}]"
- def is_base_ty_like(self, base_ty: BaseTy) -> bool:
- return self.elem.is_base_ty_like(base_ty)
- def is_symint_like(self) -> bool:
- return self.elem.is_symint_like()
- def is_nullable(self) -> bool:
- return self.elem.is_nullable()
- def is_list_like(self) -> Optional["ListType"]:
- return self
- @dataclass(frozen=True)
- class Argument:
- # NB: I didn't put kwarg_only as a boolean field here, unlike
- # c10::Argument, so that printing works correctly
- name: str
- type: Type
- default: Optional[str]
- # The semantics of the annotation field are a little strange.
- #
- # Alias annotations parametrize Tensors (since Tensors are the only things
- # that can alias.) This motivates why I write Tensor(a!)? (and not, for
- # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
- # which may be optional (i.e., the alias annotation should bind first to
- # Tensor, before the optional postfix annotation).
- #
- # However, despite being a property of Tensor, we (and c10::Argument)
- # store the annotation at the top level of the Argument, rather than
- # inside the embedded Tensor type. In the C++ version of this
- # class, we then go through great lengths to mimic the type
- # structure in the annotation structure so we can correlate
- # annotations with types.
- #
- # Now, it turns out, in all applications in code generation, the
- # structure of annotated types is very simple. So we just hard
- # code it here. But if we ever do get anything more complex, this
- # model will have to change!
- annotation: Optional[Annotation]
- @staticmethod
- def parse(arg: str) -> "Argument":
- name: str
- default: Optional[str]
- type_and_annot, name_and_default = arg.rsplit(" ", 1)
- if "=" in name_and_default:
- name, default = name_and_default.split("=")
- else:
- name = name_and_default
- default = None
- # TODO: deduplicate annotation matching with Return
- match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
- annotation: Optional[Annotation]
- if match:
- # If you update this, make sure the __str__ still works too
- assert match.group(2) in [
- "",
- "?",
- "[]",
- ], "unrecognized alias analysis form with Tensor"
- type_s = "Tensor" + match.group(2)
- annotation = Annotation.parse(match.group(1))
- else:
- type_s = type_and_annot
- annotation = None
- type = Type.parse(type_s)
- r = Argument(
- name=name,
- type=type,
- default=default,
- annotation=annotation,
- )
- assert str(r) == arg, f"{str(r)} != {arg}"
- return r
- @property
- def is_write(self) -> bool:
- return self.annotation is not None and self.annotation.is_write
- def __str__(self) -> str:
- type = f"{self.type}"
- if self.annotation:
- assert type in ["Tensor", "Tensor?", "Tensor[]"]
- type = type.replace("Tensor", f"Tensor({self.annotation})")
- if self.name is None:
- return type
- else:
- mb_default = ""
- if self.default:
- mb_default = f"={self.default}"
- return f"{type} {self.name}{mb_default}"
- @dataclass(frozen=True)
- class Return:
- name: Optional[str]
- type: Type
- annotation: Optional[Annotation]
- @staticmethod
- def parse(arg: str) -> "Return":
- name: Optional[str]
- if " " in arg:
- type_and_annot, name = arg.rsplit(" ", 1)
- else:
- type_and_annot = arg
- name = None
- match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
- annotation: Optional[Annotation]
- if match:
- # If you update this, make sure the __str__ still works too
- assert match.group(2) in [
- "",
- "?",
- "[]",
- ], "unrecognized alias analysis form with Tensor"
- type_s = "Tensor" + match.group(2)
- annotation = Annotation.parse(match.group(1))
- else:
- type_s = type_and_annot
- annotation = None
- type = Type.parse(type_s)
- r = Return(
- name=name,
- type=type,
- annotation=annotation,
- )
- assert str(r) == arg, f"{str(r)} != {arg}"
- return r
- @property
- def is_write(self) -> bool:
- return self.annotation is not None and self.annotation.is_write
- def __str__(self) -> str:
- type = f"{self.type}"
- if self.annotation:
- assert type in ["Tensor", "Tensor?", "Tensor[]"]
- type = type.replace("Tensor", f"Tensor({self.annotation})")
- if self.name is None:
- return type
- else:
- return f"{type} {self.name}"
- # Represents the self argument for functions that may be methods
- @dataclass(frozen=True)
- class SelfArgument:
- argument: Argument
- # Bundle of arguments that represent a TensorOptions. This is mostly
- # relevant for the public C++ API but we bake it into the core data
- # model because other APIs often have to interact with it
- @dataclass(frozen=True)
- class TensorOptionsArguments:
- dtype: Argument
- layout: Argument
- device: Argument
- pin_memory: Argument
- def all(self) -> Sequence[Argument]:
- return [self.dtype, self.layout, self.device, self.pin_memory]
- @dataclass(frozen=True)
- class Arguments:
- # pre_self_positional is usually empty, but is notably non-empty
- # for where.self, where the condition argument comes before the
- # self argument
- pre_self_positional: Tuple[Argument, ...]
- self_arg: Optional[SelfArgument]
- post_self_positional: Tuple[Argument, ...]
- pre_tensor_options_kwarg_only: Tuple[Argument, ...]
- tensor_options: Optional[TensorOptionsArguments]
- # post_tensor_options is typically memory format, which should be
- # part of tensor options but isn't right now, and is usually
- # placed after the tensor options arguments
- post_tensor_options_kwarg_only: Tuple[Argument, ...]
- # Unlike in the previous codegen, we have factored out 'out' arguments
- # in the canonical representation, removing them from kwarg
- # arguments. This choice is justified by numerous downstream
- # transformations which treat out arguments specially; additionally,
- # you can see that canonicity is not violated!
- out: Tuple[Argument, ...] # these are also kwarg-only
- @property
- def flat_non_out(self) -> Sequence[Argument]:
- ret: List[Argument] = []
- ret.extend(self.flat_positional)
- ret.extend(self.flat_kwarg_only)
- return ret
- @property
- def flat_positional(self) -> Sequence[Argument]:
- ret: List[Argument] = []
- ret.extend(self.pre_self_positional)
- if self.self_arg is not None:
- ret.append(self.self_arg.argument)
- ret.extend(self.post_self_positional)
- return ret
- @property
- def post_self_positional_mutable(self) -> Sequence[Argument]:
- return [a for a in self.post_self_positional if a.is_write]
- # NB: doesn't contain out arguments
- @property
- def flat_kwarg_only(self) -> Sequence[Argument]:
- ret: List[Argument] = []
- ret.extend(self.pre_tensor_options_kwarg_only)
- if self.tensor_options is not None:
- ret.extend(self.tensor_options.all())
- ret.extend(self.post_tensor_options_kwarg_only)
- return ret
- @property
- def flat_all(self) -> Sequence[Argument]:
- ret: List[Argument] = []
- ret.extend(self.flat_positional)
- ret.extend(self.flat_kwarg_only)
- ret.extend(self.out)
- return ret
- @property
- def non_out(
- self,
- ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
- ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
- ret.extend(self.positional)
- ret.extend(self.kwarg_only)
- return ret
- @property
- def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
- ret: List[Union[Argument, SelfArgument]] = []
- ret.extend(self.pre_self_positional)
- if self.self_arg is not None:
- ret.append(self.self_arg)
- ret.extend(self.post_self_positional)
- return ret
- @property
- def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
- ret: List[Union[Argument, TensorOptionsArguments]] = []
- ret.extend(self.pre_tensor_options_kwarg_only)
- if self.tensor_options is not None:
- ret.append(self.tensor_options)
- ret.extend(self.post_tensor_options_kwarg_only)
- return ret
- @property
- def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
- ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
- ret.extend(self.positional)
- ret.extend(self.kwarg_only)
- ret.extend(self.out)
- return ret
- def mutable_arg_names(self) -> List[str]:
- return [
- a.name
- for a in self.flat_all
- if a.annotation is not None and a.annotation.is_write
- ]
- def has_tensor_arg(self) -> bool:
- return any(a.type.is_tensor_like() for a in self.flat_non_out)
- def has_symint_arg(self) -> bool:
- return any(a.type.is_symint_like() for a in self.flat_non_out)
- def has_generator_arg(self) -> bool:
- return any(a.type.is_generator_like() for a in self.flat_non_out)
- def signature(self, *, strip_default: bool = False) -> "Arguments":
- # dataclasses.replace could be used here, but it is less
- # type safe so for now I've opted to type everything out
- def strip_arg_annotation(a: Argument) -> Argument:
- return Argument(
- name=a.name,
- type=a.type,
- default=a.default if not strip_default else None,
- annotation=None,
- )
- return Arguments(
- pre_self_positional=tuple(
- map(strip_arg_annotation, self.pre_self_positional)
- ),
- self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
- if self.self_arg is not None
- else None,
- post_self_positional=tuple(
- map(strip_arg_annotation, self.post_self_positional)
- ),
- # Since TensorOptions are dropped, the post_tensor_options_kwargs are
- # converted to pre_tensor_options_kwargs
- pre_tensor_options_kwarg_only=tuple(
- map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
- )
- + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
- # TensorOptions are dropped in signature,
- # so we can pair factory functions with their out= variants.
- tensor_options=None,
- post_tensor_options_kwarg_only=tuple(),
- # out arguments are dropped in signature
- out=(),
- )
- def remove_self_annotation(self) -> "Arguments":
- assert self.self_arg is not None
- return dataclasses.replace(
- self,
- self_arg=SelfArgument(
- dataclasses.replace(self.self_arg.argument, annotation=None)
- ),
- )
- def with_out_args(self, outs: List[Argument]) -> "Arguments":
- assert len(self.out) == 0
- return dataclasses.replace(
- self,
- out=tuple(outs),
- )
- @staticmethod
- def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
- positional: List[Argument] = []
- kwarg_only: List[Argument] = []
- out: List[Argument] = []
- arguments_acc = positional
- # TODO: Use a real parser here; this will get bamboozled
- # by signatures that contain things like std::array<bool, 2> (note the space)
- for arg in args.split(", "):
- if not arg:
- continue
- if arg == "*":
- assert (
- arguments_acc is positional
- ), "invalid syntax: kwarg-only specifier * can only occur once"
- arguments_acc = kwarg_only
- continue
- parg = Argument.parse(arg)
- # Currently, we rely directly on the invariant that there are NO
- # kwarg-only mutating arguments. If you want to relax this,
- # we will need a more semantic way of matching that takes
- # into account return arguments. In that case, you will have
- # to manage out computation a level up, in FunctionSchema. See Note
- # [is_out_fn]
- if parg.annotation is not None and parg.annotation.is_write:
- if arguments_acc is positional:
- pass # do nothing
- elif arguments_acc is kwarg_only:
- arguments_acc = out
- else:
- assert arguments_acc is not out
- arguments_acc.append(parg)
- return positional, kwarg_only, out
- @staticmethod
- def parse(args: str) -> "Arguments":
- """
- Input: 'int x, int y, int z'
- """
- # We do this in two phases. First we parse into three
- # main categories: positional, kwarg_only, out.
- # Then, we reparse positional and kwarg_only to separate
- # out the self argument and tensor options arguments.
- positional, kwarg_only, out = Arguments._preparse(args)
- # Split self argument
- self_ix = None
- for i, a in enumerate(positional):
- if a.name == "self":
- self_ix = i
- break
- pre_self_positional: List[Argument]
- self_arg: Optional[SelfArgument]
- post_self_positional: List[Argument]
- if self_ix is not None:
- pre_self_positional = positional[:self_ix]
- self_arg = SelfArgument(positional[self_ix])
- post_self_positional = positional[self_ix + 1 :]
- else:
- pre_self_positional = []
- self_arg = None
- post_self_positional = positional
- # Group tensor options arguments
- pre_tensor_options_kwarg_only: List[Argument] = []
- tensor_options: Optional[TensorOptionsArguments] = None
- post_tensor_options_kwarg_only: List[Argument] = []
- kwarg_only_acc = pre_tensor_options_kwarg_only
- def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
- return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
- predicates = [ # order matters
- pred("dtype", Type.parse("ScalarType")),
- pred("layout", Type.parse("Layout")),
- pred("device", Type.parse("Device")),
- pred("pin_memory", Type.parse("bool")),
- ]
- i = 0
- while i < len(kwarg_only):
- # If there is enough space...
- if i <= len(kwarg_only) - len(predicates):
- # And the next len(predicates) arguments look like TensorOptions arguments
- if all(
- p(a)
- for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
- ):
- assert kwarg_only_acc is pre_tensor_options_kwarg_only
- # Group them together as one argument
- tensor_options = TensorOptionsArguments(
- dtype=kwarg_only[i],
- layout=kwarg_only[i + 1],
- device=kwarg_only[i + 2],
- pin_memory=kwarg_only[i + 3],
- )
- i += len(predicates)
- kwarg_only_acc = post_tensor_options_kwarg_only
- continue
- kwarg_only_acc.append(kwarg_only[i])
- i += 1
- return Arguments(
- pre_self_positional=tuple(pre_self_positional),
- self_arg=self_arg,
- post_self_positional=tuple(post_self_positional),
- pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
- tensor_options=tensor_options,
- post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
- out=tuple(out),
- )
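- # Illustrative split (the argument string is an example only):
- #   a = Arguments.parse("Tensor self, int dim, *, ScalarType? dtype=None")
- #   a.self_arg.argument.name                  # "self"
- #   [x.name for x in a.post_self_positional]  # ["dim"]
- #   [x.name for x in a.flat_kwarg_only]       # ["dtype"]
- #   a.tensor_options is None                  # True; a lone dtype does not form TensorOptions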
- def __str__(self) -> str:
- all_arguments: List[str] = []
- all_arguments.extend(map(str, self.flat_positional))
- if self.flat_kwarg_only or self.out:
- all_arguments.append("*")
- all_arguments.extend(map(str, self.flat_kwarg_only))
- all_arguments.extend(map(str, self.out))
- return ", ".join(all_arguments)
- def __post_init__(self) -> None:
- # TODO: These invariants are weirdly asymmetric?
- # TODO: Fancier types?
- if self.self_arg is None:
- assert not self.pre_self_positional
- if self.tensor_options is None:
- assert not self.post_tensor_options_kwarg_only
- # We don't allow any of the following to have argument annotations,
- # to keep things simple.
- mutable_pre_self_positionals = [
- a
- for a in self.pre_self_positional
- if a.annotation is not None and a.annotation.is_write
- ]
- assert (
- len(mutable_pre_self_positionals) == 0
- ), "mutable pre_self_positional arguments are not currently supported in the schema"
- # Names that can validly appear as __iXXX__ dunder methods, indicating in-place operations.
- # Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
- # NB: PyTorch hasn't actually implemented all of these
- AUGMENTED_ASSIGNMENT_NAMES = [
- "add",
- "sub",
- "mul",
- "div",
- "mod",
- "pow",
- "lshift",
- "rshift",
- "and",
- "xor",
- "or",
- ]
- # A BaseOperatorName is what we think of the operator name, without
- # the overload name. Unusually, we don't represent this as just a
- # string; instead, we directly represent a few important semantic
- # bits of information we derive from the string: namely whether
- # or not it's inplace (add_) and whether or not it's a double-underscore
- # method (__add__)
- @dataclass(frozen=True)
- class BaseOperatorName:
- base: str
- inplace: bool
- dunder_method: bool
- # Note [Overload Ambiguity With Functional Variants]
- # A handful of operators have both a "mutable" and a "functional" variant.
- # (native_batch_norm is a good example, although this isn't the case today).
- # For those operators, the mutable and functional variant take in the same set of
- # arguments, but have different alias annotations.
- # This makes it ambiguous when you try to resolve an OverloadPacket into an overload,
- # given a set of input arguments.
- #
- # So instead of making the "functional" variant in this case a real overload, e.g:
- # native_batch_norm (mutable variant)
- # native_batch_norm.functional (functional variant)
- # we make it a new base operator,
- # native_batch_norm_functional (functional variant)
- #
- # In an ideal world, we would probably invert this so the operators were:
- # native_batch_norm.mutable (mutable variant)
- # native_batch_norm (functional variant)
- #
- # Doing that is BC-breaking though, so we're stuck with the above modeling.
- functional_overload: bool = False
- @staticmethod
- def parse(op: str) -> "BaseOperatorName":
- assert op != ""
- assert not op.endswith("_out"), (
- "_out suffix is reserved and not permitted for operator names; "
- "did you mean to specify an out overload name instead?"
- )
- m = re.match(r"^__([^_]+)__$", op)
- if m is not None:
- dunder_method = True
- base = m.group(1)
- if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
- inplace = True
- base = base[1:]
- else:
- inplace = False
- # temporary, this is not intrinsically true but
- # has been historically true for dunder methods
- # we support (but, if we ever got, say, __int__, this would
- # be wrong!)
- assert base[0] != "i"
- else:
- dunder_method = False
- base = op
- if base[-1] == "_":
- inplace = True
- base = base[:-1]
- else:
- inplace = False
- # See Note [Overload Ambiguity With Functional Variants]
- functional_suffix = "_functional"
- if base.endswith(functional_suffix):
- functional_overload = True
- base = base[: -len(functional_suffix)]
- # This seems complicated and unnecessary, so banning dunder methods
- # for now on ops that have a functional + mutable variant (like native_batch_norm).
- assert not dunder_method and not inplace
- else:
- functional_overload = False
- r = BaseOperatorName(
- base=base,
- inplace=inplace,
- dunder_method=dunder_method,
- functional_overload=functional_overload,
- )
- assert str(r) == op, f"{str(r)} != {op}"
- return r
- def __str__(self) -> str:
- if self.dunder_method:
- i = "i" if self.inplace else ""
- return f"__{i}{self.base}__"
- else:
- i = (
- "_"
- if self.inplace
- else "_functional"
- if self.functional_overload
- else ""
- )
- return f"{self.base}{i}"
- # Operator name is the base operator name along with the (typically not
- # user visible) overload string.
- @dataclass(frozen=True)
- class OperatorName:
- name: BaseOperatorName
- overload_name: str
- @staticmethod
- def parse(op_name: str) -> "OperatorName":
- if "." in op_name:
- name, overload_name = op_name.split(".", 1)
- else:
- name = op_name
- overload_name = ""
- r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
- assert str(r) == op_name, f"{str(r)} != {op_name}"
- return r
- def __str__(self) -> str:
- if self.overload_name:
- return f"{self.name}.{self.overload_name}"
- else:
- return f"{self.name}"
- # NB: This must be synchronized with the naming scheme in
- # aten/src/ATen/templates/Operators.h
- # Given a function schema "aten::op.overload(...)",
- # If there is no overload name, this returns f"{op}"
- # If there is an overload name, this returns f"{op}_{overload}"
- def unambiguous_name(self) -> str:
- if self.overload_name:
- return f"{self.name}_{self.overload_name}"
- else:
- return f"{self.name}"
- def remove_inplace(self) -> "OperatorName":
- return OperatorName(
- name=BaseOperatorName(
- base=self.name.base,
- inplace=False,
- dunder_method=self.name.dunder_method,
- ),
- overload_name=self.overload_name,
- )
- def with_overload(self, overload: str) -> "OperatorName":
- return OperatorName(
- name=BaseOperatorName(
- base=self.name.base,
- inplace=False,
- dunder_method=self.name.dunder_method,
- ),
- overload_name=overload,
- )
- def gets_generated_out_inplace_wrapper(
- f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
- ) -> bool:
- return (
- f.func.kind() is not SchemaKind.functional
- and not b.has_kernel(f)
- and b.has_kernel(g.functional)
- )
- # NativeFunction objects that are views (f.is_view_op returns True)
- # are added into a `NativeFunctionsViewGroup`, which we can use to
- # easily access the generated (optional) view_copy NativeFunction.
- # It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
- # See Note [Codegen'd {view}_copy Operators]
- #
- # One property of this representation is that in order for a view-like op to be part of
- # a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
- # There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
- # but don't have a corresponding aliasing `narrow.out` op.
- # This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
- @dataclass(frozen=True)
- class NativeFunctionsViewGroup:
- view: NativeFunction
- # Note: the {view}_copy operator is optional because we currently don't generate copy variants
- # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
- # (we already get them "for free" through decomposition)
- view_copy: Optional[NativeFunction]
- # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
- view_inplace: Optional[NativeFunction]
- def __post_init__(self) -> None:
- assert self.view.is_view_op
- if self.view_copy is None:
- assert not gets_generated_view_copy(self.view), (
- f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
- " The codegen expects you to add a corresponding operator to native_functions.yaml:"
- f" {get_view_copy_name(self.view)!s}."
- " See Note [view_copy NativeFunctions] for details."
- )
- else:
- assert self.view_copy.func.name.name.base.endswith("_copy")
- assert self.view.func.signature() == self.view_copy.func.signature(
- strip_view_copy_name=True
- )
- assert "view_copy" in self.view_copy.tags, (
- f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
- " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
- " See Note [view_copy NativeFunction] for details."
- )
- if self.view_inplace is not None:
- assert self.view.func.signature() == self.view_inplace.func.signature()
- if self.view.has_composite_implicit_autograd_kernel:
- if self.view_inplace is not None:
- assert self.view_inplace.has_composite_implicit_autograd_kernel, (
- f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
- " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
- )
- if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
- if self.view_inplace is not None:
- assert (
- self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
- ), (
- f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
- " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
- )
- def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
- yield self.view
- if self.view_inplace is not None:
- yield self.view_inplace
- if self.view_copy is not None and include_copy:
- yield self.view_copy
- @property
- def root_name(self) -> str:
- return self.view.root_name
- @property
- def composite(self) -> bool:
- # We currently assert that the "group" is consistent.
- # If the view op is composite, then its view_inplace op is too.
- return self.view.has_composite_implicit_autograd_kernel
- def gets_generated_view_copy(f: NativeFunction) -> bool:
- # Only aliasing (view) operators get a copy variant.
- if not f.is_view_op:
- return False
- # We don't need to bother generating copy variants for CompositeImplicitAutograd ops,
- # because we can let them decompose into base view ops.
- if f.has_composite_implicit_autograd_kernel:
- return False
- # We also don't need to generate copy variants for inplace views.
- if "inplace_view" in f.tags:
- return False
- return True
- # Given a NativeFunction that corresponds to a view op,
- # returns the OperatorName of the corresponding "copy" variant of the op.
- def get_view_copy_name(f: NativeFunction) -> "OperatorName":
- # Right now, when asking for a view op's corresponding "view_copy" name
- # we assert for sanity that the op is allowed to have a generated view_copy variant.
- # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
- # However, narrow_copy() already exists as an op directly in native_functions.yaml.
- # I'm hardcoding narrow_copy here for now to maintain the assert,
- # but we could also just get rid of the assert.
- list_of_ops_with_explicit_view_copy_operators = ["narrow"]
- if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
- assert gets_generated_view_copy(f)
- base_name = f"{f.func.name.name.base}_copy"
- view_copy_name = OperatorName(
- name=BaseOperatorName(
- base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
- ),
- overload_name=f.func.name.overload_name,
- )
- return view_copy_name
- # Helper functions for parsing argument lists (both inputs and returns)
- def parse_returns(return_decl: str) -> Tuple[Return, ...]:
- """
- Input: '()'
- Output: ()
- """
- if return_decl == "()":
- return ()
- if return_decl[0] == "(" and return_decl[-1] == ")":
- return_decl = return_decl[1:-1]
- return tuple(Return.parse(arg) for arg in return_decl.split(", "))
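- # Illustrative parses (return strings are examples only):
- #   parse_returns("()")                         -> ()
- #   parse_returns("Tensor")                     -> a single unnamed Tensor return
- #   parse_returns("(Tensor out0, Tensor out1)") -> two named Tensor returns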
- # A Precompute instance consists of a map from kernel argument name
- # to the list of Argument instances that should replace that
- # kernel argument in the impl function.
- @dataclass(frozen=True)
- class Precompute:
- # A map from kernel argument name -> a list of precomputed
- # elements that replaces/supersedes it.
- replace: Dict[str, List[Argument]]
- # List of precomputed args added without replacement
- add: List[Argument]
- @staticmethod
- def parse(src: object) -> "Precompute":
- assert isinstance(src, list)
- # src is a list of strings of the format:
- # {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
- # [{add decl}[, {add decl}, ...]]
- # The last line is optional and contains the precomputed parameters that are
- # added without replacement.
- # The other lines are parsed to get the names of which precomputed elements
- # should replace which kernel arguments.
- add_args = []
- if " -> " not in src[-1]:
- add_list = src[-1].split(",")
- add_args = [Argument.parse(name.strip()) for name in add_list]
- src = src[:-1]
- replace = {}
- for raw_replace_item in src:
- assert isinstance(raw_replace_item, str)
- assert " -> " in raw_replace_item, (
- "precomputed parameters without replacement"
- " are allowed only in the last line"
- )
- arg, with_list_raw = raw_replace_item.split(" -> ")
- with_list = with_list_raw.split(",")
- with_list_args = [Argument.parse(name.strip()) for name in with_list]
- replace[arg] = with_list_args
- r = Precompute(replace=replace, add=add_args)
- assert r.to_list() == src, "r.to_list() != src"
- return r
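- # Hypothetical example of the accepted format (kernel param and replacement names invented):
- #   Precompute.parse(["kernel_size -> int kernel_h, int kernel_w"])
- #   yields replace={"kernel_size": [Argument.parse("int kernel_h"), Argument.parse("int kernel_w")]}
- #   and add=[]; a trailing line without " -> " would instead populate `add`.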
- def __post_init__(self) -> None:
- # The template parameters are upper-case, so an argument name that is
- # entirely upper-case would be ambiguous with a template parameter.
- for a in self.add:
- assert a.name.upper() != a.name
- for args in self.replace.values():
- for a in args:
- assert a.name.upper() != a.name
- def to_list(self) -> List[str]:
- replace_list = []
- for kernel_param, replacement_params in self.replace.items():
- replacements = ", ".join(str(param) for param in replacement_params)
- replace_list.append(f"{kernel_param} -> {replacements}")
- return replace_list