- import builtins
- import collections
- import math
- import operator
- import warnings
- from collections.abc import Iterable
- from enum import Enum
- from functools import partial, reduce, singledispatch, wraps
- from typing import Callable, List, Optional, overload, Sequence, Tuple, Union
- import torch
- import torch._prims as prims
- import torch._prims_common as utils
- from torch import sym_float, sym_int
- from torch._prims_common import (
- check,
- DeviceLikeType,
- Dim,
- DimsSequenceType,
- DimsType,
- dtype_to_type,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- FloatLike,
- FloatWithoutSymFloat,
- IntLike,
- is_weakly_lesser_type,
- Number,
- NumberType,
- REDUCTION_OUTPUT_TYPE_KIND,
- ShapeType,
- StrideType,
- TensorLike,
- TensorLikeType,
- TensorOrNumberLikeType,
- TensorSequenceType,
- )
- from torch._prims_common.wrappers import (
- _maybe_convert_to_dtype,
- _maybe_resize_out,
- _safe_copy_out,
- elementwise_type_promotion_wrapper,
- elementwise_unary_scalar_wrapper,
- out_wrapper,
- )
- # Experimental module containing prototype Python references for existing
- # PyTorch operations.
- __all__ = [
- #
- # Elementwise Unary References
- #
- "abs",
- "acos",
- "acosh",
- "asinh",
- "asin",
- "atan",
- "atanh",
- "bitwise_not",
- # "cbrt", # No corresponding torch operation
- "ceil",
- "conj_physical",
- "cos",
- "cosh",
- "digamma",
- "erf",
- "erfinv",
- "erfc",
- "exp",
- "expm1",
- "exp2",
- "fill",
- "floor",
- "frac",
- "index_add",
- "index_copy",
- "index_copy_",
- "index_select",
- "index_fill",
- "index_fill_",
- "isfinite",
- "isinf",
- "isposinf",
- "isneginf",
- "isnan",
- "isreal",
- "i0",
- "lerp",
- "lgamma",
- "log",
- "log1p",
- "log2",
- "log10",
- "log_softmax",
- "nan_to_num",
- "neg",
- "positive",
- "reciprocal",
- "round", # TODO: model kwargs
- "sigmoid",
- "sgn",
- "sign",
- "signbit",
- "sin",
- "sinc",
- "sinh",
- "softmax",
- "sqrt",
- "square",
- "tan",
- "tanh",
- "trace",
- "trunc",
- #
- # Elementwise Binary References
- #
- "add",
- "atan2",
- "bitwise_and",
- "bitwise_left_shift",
- "bitwise_or",
- "bitwise_right_shift",
- "bitwise_xor",
- "clamp_min",
- "clamp_max",
- "copysign",
- "div",
- "eq",
- "float_power",
- "floor_divide",
- "fmax",
- "fmin",
- "fmod",
- "gcd",
- "ge",
- "gt",
- "heaviside",
- "hypot",
- "igamma",
- "igammac",
- "imag",
- "isclose",
- "lcm",
- # 'ldexp',
- "le",
- "logical_and",
- "logical_not",
- "logical_or",
- "logical_xor",
- "lt",
- # 'max', # implement with reductions
- "maximum",
- # 'min', # implement with reductions
- "minimum",
- "mul",
- "ne",
- "nextafter",
- # 'polar', # abs, cos, sin
- "pow",
- "real",
- "rpow",
- "remainder",
- "rsub",
- "rtruediv",
- "rfloordiv",
- "sub",
- "true_divide",
- "trunc_divide",
- "xlogy",
- #
- # Elementwise Ternary References
- #
- "addcdiv",
- "addcmul",
- "clamp",
- #
- # Conditional references
- #
- "masked_fill",
- "where",
- #
- # Data conversion and movement references
- #
- "clone",
- "copy_to", # TODO: add OpInfo (or implement .to)
- "item", # TODO: add OpInfo
- "to",
- #
- # Reduction ops
- #
- "all",
- "amax",
- "amin",
- "any",
- "mean",
- "std",
- "std_mean",
- "sum",
- "sum_to_size",
- "prod",
- "var",
- "var_mean",
- #
- # Linear algebra ops
- #
- "addr",
- #
- # View & Shape Ops
- #
- "atleast_1d",
- "atleast_2d",
- "atleast_3d",
- "as_strided",
- "broadcast_shapes",
- "broadcast_tensors",
- "broadcast_to",
- "cat",
- "chunk",
- "column_stack",
- "conj",
- "constant_pad_nd",
- "contiguous",
- "diag_embed",
- "diag",
- "diagonal",
- "diagonal_copy",
- "diagonal_scatter",
- "dsplit",
- "dstack",
- "expand",
- "expand_as",
- "flatten",
- "flip",
- "fliplr",
- "flipud",
- "hsplit",
- "hstack",
- "meshgrid",
- "movedim",
- "narrow",
- "narrow_copy",
- "native_group_norm",
- "native_layer_norm",
- "permute",
- "ravel",
- "repeat",
- "reshape",
- "roll",
- "rot90",
- "rsqrt",
- "stack",
- "swap_axes", # alias for transpose
- "squeeze",
- "t",
- "T",
- "tensor_split",
- "transpose",
- "unfold",
- "unfold_copy",
- "unsqueeze",
- "view",
- "vsplit",
- "vstack",
- "unflatten",
- "unbind",
- "triu",
- "tril",
- "triu_indices",
- "tril_indices",
- #
- # Tensor Creation
- #
- "arange",
- "empty",
- "empty_like",
- "empty_strided",
- "eye",
- "full",
- "full_like",
- "linspace",
- "logspace",
- "ones",
- "ones_like",
- "randn",
- "scalar_tensor",
- "zeros",
- "zeros_like",
- #
- # Test-related functions
- #
- "allclose",
- "equal", # TODO: add OpInfo
- #
- # Statistical operations
- #
- "bucketize",
- ]
- Tensor = torch.Tensor
- DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
- aten = torch._ops.ops.aten
- def _broadcast_shapes(*_shapes):
- shapes = tuple(
- (x,) if isinstance(x, IntLike) else x
- for x in filter(lambda x: x is not None, _shapes)
- )
- # Short-circuits on no input
- if len(shapes) == 0:
- return None
- # Type checking
- # TODO: make common validations available as utils
- for shape in shapes:
- assert isinstance(shape, Sequence)
- # Computes common shape
- common_shape = [
- 1,
- ] * reduce(max, (len(shape) for shape in shapes))
- for arg_idx, shape in enumerate(shapes):
- for idx in range(-1, -1 - len(shape), -1):
- if common_shape[idx] == 1:
- if shape[idx] < 0:
- raise ValueError(
- "Attempting to broadcast a dimension with negative length!"
- )
- common_shape[idx] = shape[idx]
- elif shape[idx] != 1:
- if common_shape[idx] != shape[idx]:
- raise RuntimeError(
- f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
- f"should be broadcastable to {common_shape}"
- )
- return common_shape
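- # For example (illustrative sketch of the right-aligned broadcasting rule above):
- # >>> _broadcast_shapes((2, 3, 1), (1, 4))
- # [2, 3, 4]
- # >>> _broadcast_shapes((3, 1), 5)  # bare integers are treated as 1-D shapes
- # [3, 5]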
- def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
- # Computes common shape
- common_shape = _broadcast_shapes(
- *map(lambda t: t.shape if isinstance(t, TensorLike) else None, args)
- )
- def __maybe_broadcast(x, shape):
- if x is None:
- return None
- elif isinstance(x, Number):
- return x
- elif isinstance(x, TensorLike):
- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
- return x
- if not utils.same_shape(x.shape, common_shape):
- return x.expand(common_shape)
- return x
- else:
- raise RuntimeError(
- "Unexpected type when broadcasting: " + str(type(x)) + "!"
- )
- return tuple(__maybe_broadcast(x, common_shape) for x in args)
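- # For example (illustrative): tensors are expanded to the common shape, while Python
- # numbers (and, by default, CPU scalar tensors) pass through unchanged:
- # >>> x, y = _maybe_broadcast(torch.ones(3, 1), torch.ones(4))
- # >>> x.shape, y.shape
- # (torch.Size([3, 4]), torch.Size([3, 4]))
- # >>> _maybe_broadcast(2.0, torch.ones(3))[0]
- # 2.0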
- # Utilities should come BEFORE this import
- from torch._decomp import register_decomposition
- #
- # Elementwise unary references
- #
- infer_aten_op = object()
- # TODO: add type promotion support
- def _make_elementwise_unary_reference(
- type_promotion_kind,
- *,
- aten_op=infer_aten_op,
- extra_meta=None,
- ) -> Callable:
- def inner(prim: Callable):
- nonlocal aten_op
- @wraps(prim)
- @out_wrapper()
- @elementwise_unary_scalar_wrapper
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a",),
- type_promotion_kind=type_promotion_kind,
- )
- def _ref(a: TensorLikeType) -> TensorLikeType:
- if extra_meta is not None:
- extra_meta(a)
- return prim(a)
- if aten_op is infer_aten_op:
- aten_op = utils.get_aten_op(prim, prim.__name__)
- if aten_op is not None:
- register_decomposition(aten_op)(_ref)
- return _ref
- return inner
- def _make_alias(fn, name):
- """
- This function defines an alias of another function and sets its __name__ attribute.
- Note that naïvely doing `alias = fn` leaves `alias.__name__` set to the original function's name.
- """
- def _fn(*args, **kwargs):
- return fn(*args, **kwargs)
- _fn.__name__ = name
- return _fn
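- # For example (illustrative): `mvlgamma` below is defined as
- # `_make_alias(torch.special.multigammaln, "mvlgamma")`, so `mvlgamma.__name__ == "mvlgamma"`.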
- def _make_inplace(fn):
- """
- Given a function with an out variant (i.e. using `out_wrapper()`), returns its in-place variant.
- See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
- """
- # NB: the first parameter is named "a" to match the first argument of the unary references above
- @wraps(fn)
- def _fn(a, *args, **kwargs):
- return fn(a, *args, out=a, **kwargs)
- inplace_name = f"{fn.__name__}_"
- _fn.__name__ = inplace_name
- _fn = register_decomposition(getattr(aten, inplace_name))(_fn)
- # We access the __all__ attribute of the module where fn is defined
- # There may be a cleaner way of doing this...
- from inspect import getmodule
- _all = getmodule(fn).__all__ # type: ignore[union-attr]
- if inplace_name not in _all:
- _all.append(inplace_name)
- return _fn
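- # For example (a sketch): `_make_inplace(floor)` would produce a function named "floor_"
- # that computes `floor(a, out=a)`, registers it as the decomposition of `aten.floor_`,
- # and appends "floor_" to this module's __all__.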
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
- def abs(a):
- return prims.abs(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def acos(a):
- return prims.acos(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def acosh(a):
- return prims.acosh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def asin(a):
- return prims.asin(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def asinh(a):
- return prims.asinh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def atan(a):
- return prims.atan(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def atanh(a):
- return prims.atanh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def bitwise_not(a):
- return prims.bitwise_not(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def ceil(a):
- return prims.ceil(a)
- @register_decomposition(aten.conj_physical)
- @out_wrapper()
- def conj_physical(input: TensorLikeType):
- if not utils.is_complex_dtype(input.dtype):
- return input
- return prims.conj_physical(input)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def cos(a):
- return prims.cos(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def cosh(a):
- return prims.cosh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def digamma(a):
- return prims.digamma(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erf(a):
- return prims.erf(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erfinv(a):
- return prims.erf_inv(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def erfc(a):
- return prims.erfc(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def exp(a):
- return prims.exp(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def expm1(a):
- return prims.expm1(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def exp2(a):
- return prims.exp2(a)
- # Fill has its own implementation because it has a value parameter
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a,"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- assert isinstance(value, Number)
- python_type = utils.dtype_to_type(a.dtype)
- if not utils.is_weakly_lesser_type(type(value), python_type):
- msg = "value argument of type {0} cannot be safely cast to type {1}!".format(
- type(value), python_type
- )
- raise ValueError(msg)
- return prims.fill(a, value)
- def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
- r = prims.fill(a, value)
- prims.copy_to(a, r)
- return a
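- # For example (illustrative): the fill value must be safely castable to the tensor's dtype:
- # >>> fill(torch.zeros(2, dtype=torch.int64), 3).tolist()
- # [3, 3]
- # >>> fill(torch.zeros(2, dtype=torch.int64), 1.5)  # raises ValueError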
- @register_decomposition(aten.zero)
- @out_wrapper()
- def zero(input: TensorLikeType) -> TensorLikeType:
- return torch.zeros_like(input)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def floor(a):
- return prims.floor(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def frac(x: TensorLikeType) -> TensorLikeType:
- trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
- return torch.sub(x, trunc_x)
- # imag does not use _make_elementwise_unary_reference because it does not support out
- def imag(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- utils.check(
- utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
- )
- return prims.imag(a)
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- aten_op=None, # CompositeImplicitAutograd
- )
- def isfinite(a: TensorLikeType) -> TensorLikeType:
- if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
- return prims.isfinite(a)
- return ones_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isinf(a: TensorLikeType) -> TensorLikeType:
- if utils.is_complex_dtype(a.dtype):
- return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
- if utils.is_float_dtype(a.dtype):
- return torch.abs(a) == float("inf")
- return torch.zeros_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isposinf(a: TensorLikeType) -> TensorLikeType:
- utils.check(
- not utils.is_complex_dtype(a.dtype),
- lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
- )
- if utils.is_float_dtype(a.dtype):
- return a == float("inf")
- return torch.zeros_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isneginf(a: TensorLikeType) -> TensorLikeType:
- utils.check(
- not utils.is_complex_dtype(a.dtype),
- lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
- )
- if utils.is_float_dtype(a.dtype):
- return a == float("-inf")
- return torch.zeros_like(a, dtype=torch.bool)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def isnan(a: TensorLikeType) -> TensorLikeType:
- return prims.ne(a, a)
- # alias
- mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type]
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- aten_op=None, # CompositeImplicitAutograd
- )
- def isreal(a: TensorLikeType) -> TensorLikeType:
- if utils.is_complex_dtype(a.dtype):
- return torch.imag(a) == 0
- return torch.ones_like(a, dtype=torch.bool)
- # TODO: if this is special maybe it should be defined there and imported here?
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.special_i0
- )
- def i0(a):
- return prims.bessel_i0(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def lgamma(a):
- return prims.lgamma(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log(a):
- return prims.log(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log1p(a):
- return prims.log1p(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log2(a):
- return prims.log2(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def log10(a):
- return prims.log10(a)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def log_softmax(
- a: TensorLikeType,
- dim: int,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(result_dtype)
- a_ = _maybe_convert_to_dtype(a, computation_dtype)
- return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value]
- @register_decomposition(aten.logsumexp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def logsumexp(
- self: TensorLikeType, dim: DimsType, keepdim: bool = False
- ) -> TensorLikeType:
- if not isinstance(dim, Iterable):
- dim = (dim,)
- if self.numel() == 0:
- return torch.sum(torch.exp(self), dim, keepdim).log()
- maxes = torch.amax(self, dim, keepdim=True)
- maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
- maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
- result = torch.sum(torch.exp(self - maxes), dim, keepdim)
- return result.log().add(maxes_squeezed)
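- # Subtracting the per-dim max keeps the intermediate exponentials finite. For example
- # (illustrative):
- # >>> logsumexp(torch.tensor([1000.0, 1000.0]), dim=0)  # ~1000 + log(2)
- # tensor(1000.6931)
- # whereas the naive torch.log(torch.exp(x).sum()) would overflow to inf.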
- @register_decomposition(aten.nan_to_num)
- @out_wrapper()
- def nan_to_num(
- a: TensorLikeType,
- nan: Optional[NumberType] = 0.0,
- posinf: Optional[NumberType] = None,
- neginf: Optional[NumberType] = None,
- ) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- return a.clone()
- if nan is None:
- nan = 0.0
- if posinf is None:
- posinf = torch.finfo(a.dtype).max
- if neginf is None:
- neginf = torch.finfo(a.dtype).min
- result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload]
- result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload]
- result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload]
- return result
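- # For example (illustrative, float32 defaults):
- # >>> nan_to_num(torch.tensor([float("nan"), float("inf"), -float("inf")]))
- # tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38])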
- def _neg_meta(a: TensorLikeType):
- check(
- a.dtype is not torch.bool,
- lambda: (
- "Negation, the `-` operator, on a bool tensor is not supported. "
- "If you are trying to invert a mask, use the `~` or `logical_not()` "
- "operator instead."
- ),
- )
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
- )
- def neg(a):
- return prims.neg(a)
- # positive does not use _make_elementwise_unary_reference because it does not support out
- # CompositeImplicitAutograd - don't register decomp
- def positive(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if a.dtype is torch.bool:
- msg = "positive does not support bool tensors."
- raise RuntimeError(msg)
- return a
- # real does not use _make_elementwise_unary_reference because it does not support out
- def real(a: TensorLikeType) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- if utils.is_complex_dtype(a.dtype):
- return prims.real(a)
- return a
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def reciprocal(a):
- return prims.reciprocal(a)
- # TODO: round takes additional kwargs
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- aten_op=None, # TODO: this does need a decomp, but kwarg handling is needed
- )
- def round(a):
- return prims.round(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def rsqrt(a):
- return prims.rsqrt(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sigmoid(a: TensorLikeType) -> TensorLikeType:
- return true_divide(1, add(1, exp(neg(a))))
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def sgn(a):
- if utils.is_complex_dtype(a.dtype):
- a_abs = a.abs()
- return torch.where(a_abs == 0, 0, a / a_abs)
- else:
- return a.sign()
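- # For example (illustrative): for complex inputs sgn returns a / |a| (and 0 at 0):
- # >>> sgn(torch.tensor(3 + 4j))
- # tensor(0.6000+0.8000j)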
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def sign(a):
- return prims.sign(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def signbit(a):
- return prims.signbit(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sin(a):
- return prims.sin(a)
- # Autograd note: This will give the right first derivative at zero (by chance),
- # but not the right second derivative
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sinc(a):
- a = math.pi * a
- return torch.where(a == 0, 1, torch.sin(a) / a)
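- # This is the normalized sinc, sin(pi * x) / (pi * x), with sinc(0) == 1. For example
- # (illustrative):
- # >>> sinc(torch.tensor([0.0, 0.5]))
- # tensor([1.0000, 0.6366])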
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sinh(a):
- return prims.sinh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def sqrt(a):
- return prims.sqrt(a)
- @_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
- aten_op=None, # CompositeImplicitAutograd,
- )
- def square(a: TensorLikeType) -> TensorLikeType:
- return mul(a, a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def tan(a):
- return prims.tan(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
- def tanh(a):
- return prims.tanh(a)
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- def trunc(a):
- return prims.trunc(a)
- def _make_elementwise_binary_reference(
- type_promotion_kind,
- aten_op=infer_aten_op,
- name=None,
- has_out=True,
- supports_lhs_python_scalar=True,
- supports_rhs_python_scalar=True,
- supports_two_python_scalars=False,
- ) -> Callable:
- def inner(prim: Callable):
- nonlocal aten_op, name
- if name is None:
- name = prim.__name__
- @wraps(prim)
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=type_promotion_kind,
- )
- def _ref(
- a: Union[Tensor, NumberType],
- b: Union[Tensor, NumberType],
- ) -> Tensor:
- check(
- supports_lhs_python_scalar or not isinstance(a, Number),
- lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
- "operation that does not accept lhs scalars!",
- ValueError,
- )
- check(
- supports_rhs_python_scalar or not isinstance(b, Number),
- lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
- "operation that does not accept rhs scalars!",
- ValueError,
- )
- check(
- supports_two_python_scalars
- or not (isinstance(a, Number) and isinstance(b, Number)),
- lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
- ValueError,
- )
- a, b = _maybe_broadcast(a, b)
- return prim(a, b)
- if has_out:
- _ref = out_wrapper()(_ref)
- _ref.__name__ = name
- if aten_op is infer_aten_op:
- aten_op = utils.get_aten_op(prim, name)
- if aten_op is not None:
- register_decomposition(aten_op)(_ref)
- return _ref
- return inner
- # Add has its own implementation because it has an alpha argument
- @register_decomposition(aten.add)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def add(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- alpha: Optional[NumberType] = None,
- ):
- """
- Reference implementation of torch.add
- """
- a, b = _maybe_broadcast(a, b)
- if alpha is not None:
- dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
- python_type = utils.dtype_to_type(dtype)
- if python_type != bool and not utils.is_weakly_lesser_type(
- type(alpha), python_type
- ):
- msg = (
- "alpha argument of type {0} cannot be safely cast to type {1}!".format(
- type(alpha), python_type
- )
- )
- raise ValueError(msg)
- b = prims.mul(b, alpha)
- return prims.add(a, b)
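- # For example (illustrative): alpha scales the second operand before the addition:
- # >>> add(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), alpha=10)
- # tensor([31., 42.])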
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def atan2(a, b):
- return prims.atan2(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_and(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.shift_left(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_or(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.shift_right_arithmetic(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.bitwise_xor(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- )
- def copysign(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- if isinstance(b, Number) and isinstance(a, Tensor):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
- msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
- a.device, b.device
- )
- raise RuntimeError(msg)
- return where(signbit(b), neg(abs(a)), abs(a))
- # TODO: add docstring
- # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- @register_decomposition(aten.div)
- @out_wrapper()
- def div(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- rounding_mode: Optional[str] = None,
- ):
- """
- Reference implementation of torch.div
- """
- if rounding_mode is None:
- return true_divide(a, b)
- elif rounding_mode == "trunc":
- return trunc_divide(a, b)
- elif rounding_mode == "floor":
- return floor_divide(a, b)
- else:
- msg = (
- "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
- "but found {0}.".format(rounding_mode)
- )
- raise ValueError(msg)
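- # For example (illustrative):
- # >>> t = torch.tensor(-7.0)
- # >>> div(t, 2), div(t, 2, rounding_mode="trunc"), div(t, 2, rounding_mode="floor")
- # (tensor(-3.5000), tensor(-3.), tensor(-4.))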
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.eq(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
- )
- def pow(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> TensorLikeType:
- assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
- if isinstance(b, Number):
- if b == 1.0:
- return a.clone() # type: ignore[return-value,union-attr]
- elif b == 2.0:
- return a * a # type: ignore[return-value]
- elif b == 0.5:
- return torch.sqrt(a) # type: ignore[arg-type]
- elif isinstance(a, Number):
- if a == 1.0:
- return torch.fill(b, True)
- if a == 2.0 and (
- utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
- ):
- return torch.exp2(b)
- return prims.pow(a, b)
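- # For example (illustrative): the scalar fast paths above avoid calling prims.pow:
- # >>> pow(torch.tensor([3.0, 4.0]), 2)    # b == 2 -> a * a
- # tensor([ 9., 16.])
- # >>> pow(2.0, torch.tensor([0.0, 3.0]))  # a == 2 with float b -> exp2(b)
- # tensor([1., 8.])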
- # TODO: add docstring
- # Float power has its own implementation because it has unique type promotion.
- # NB: aten_op not registered because CompositeExplicitAutograd
- @out_wrapper()
- def float_power(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> Tensor:
- if isinstance(a, Number) and isinstance(b, Number):
- raise ValueError(
- "Receive two Number inputs to an elementwise binary operation!"
- )
- # Handles type promotion
- dtype = utils.get_higher_dtype(a, b)
- assert dtype is not None
- if utils.is_complex_dtype(dtype):
- dtype = torch.complex128
- else:
- dtype = torch.float64
- # Float power has the following up-front cast behavior to be
- # consistent with its C++ impl
- a = _maybe_convert_to_dtype(a, dtype)
- b = _maybe_convert_to_dtype(b, dtype)
- a, b = _maybe_broadcast(a, b)
- return pow(a, b)
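- # For example (illustrative): float_power always computes in double (or complex double)
- # precision:
- # >>> float_power(torch.tensor([2], dtype=torch.int32), 3)
- # tensor([8.], dtype=torch.float64)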
- # Note: the following examples show why floor division follows the definition
- # a // b = (a - remainder(a, b)) / b (matching CPython) rather than floor(a / b):
- # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
- # tensor(-0.250000000000000, dtype=torch.float64)
- #
- # >>> b = torch.tensor(-0.0010, dtype=torch.float64)
- # tensor(-0.001000000000000, dtype=torch.float64)
- #
- # Note: In this case, casting float to double will expand the float mantissa with zeros,
- # while creating a double generates a distinct mantissa.
- # >>> torch.tensor(-0.001).to(dtype=torch.float64)
- # tensor(-0.001000000047497, dtype=torch.float64)
- #
- # Floor Division
- # The difference is caused because torch.remainder(a, b) = -0.001.
- #
- # >>> torch.floor(torch.true_divide(a, b))
- # tensor(250., dtype=torch.float64)
- #
- # >>> torch.div(a, b, rounding_mode='floor')
- # tensor(249., dtype=torch.float64)
- #
- # Definition: a // b = (a - remainder(a, b)) / b
- # >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
- # tensor(249., dtype=torch.float64)
- #
- # For reference, see CPython's implementation:
- # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_two_python_scalars=True,
- )
- def floor_divide(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- # Wrap scalars because some references only accept tensor arguments.
- if isinstance(a, Number) and isinstance(b, Number):
- a = scalar_tensor(a)
- b = scalar_tensor(b)
- elif isinstance(b, Number) and isinstance(a, Tensor):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(a, Number) and isinstance(b, Tensor):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
- if a.device == torch.device("cpu"):
- msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
- a.device, b.device
- )
- raise RuntimeError(msg)
- else:
- b = prims.device_put(b, device=a.device)
- assert isinstance(a, Tensor) and isinstance(b, Tensor)
- dtype = a.dtype
- if utils.is_float_dtype(dtype):
- return _floor_divide_float(a, b)
- elif utils.is_integer_dtype(dtype):
- return _floor_divide_integer(a, b)
- else:
- check(False, lambda: f"{dtype} not supported for floor_divide")
- def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
- a, b = _maybe_broadcast(a, b)
- if not a.dtype.is_signed:
- return prims.div(a, b)
- # Convert truncation to flooring:
- offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
- return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
- def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
- mod = fmod(a, b)
- div = true_divide(sub(a, mod), b)
- # Ensure that the remainder has the same sign as denominator
- different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
- non_zero_remainder = ne(mod, 0)
- mask = bitwise_and(non_zero_remainder, different_signed_inputs)
- div = where(mask, sub(div, 1), div)
- # Map quotient to nearest integer value
- floor_div = floor(div)
- mask = gt(sub(div, floor_div), 0.5)
- floor_div = where(mask, add(floor_div, 1), floor_div)
- basic_div = true_divide(a, b)
- zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)
- # If quotient is zero, copy signbit from true_divide quotient
- floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))
- # If denominator is zero, then follow true_divide behavior
- return where(ne(b, 0), floor_div, basic_div)
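- # For example (illustrative): floor_divide rounds the quotient toward negative infinity:
- # >>> floor_divide(torch.tensor([7.0, -7.0]), torch.tensor(2.0))
- # tensor([ 3., -4.])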
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmax(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmin(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=True,
- )
- def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.fmod(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.gcd(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.ge(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.gt(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
- input_eq_zero = torch.eq(input, 0)
- input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
- zeros_and_ones = torch.where(input_lt_zero, 0, 1)
- output = torch.where(input_eq_zero, values, zeros_and_ones)
- return output
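- # For example (illustrative): zeros map to `values`, negatives to 0, positives to 1:
- # >>> heaviside(torch.tensor([-1.0, 0.0, 2.0]), torch.tensor(0.5))
- # tensor([0.0000, 0.5000, 1.0000])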
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.hypot(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.igamma(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.igammac(a, b)
- def _check_close_args(
- name: str,
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float,
- atol: float,
- ) -> None:
- check(
- a.dtype == b.dtype,
- lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
- name, a.dtype, b.dtype
- ),
- ValueError,
- )
- check(
- rtol >= 0,
- lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
- name, rtol
- ),
- )
- check(
- atol >= 0,
- lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
- name, atol
- ),
- )
- # CompositeImplicitAutograd - don't register decomp
- def isclose(
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float = 1e-05,
- atol: float = 1e-08,
- equal_nan: bool = False,
- ) -> TensorLikeType:
- _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)
- close = eq(a, b)
- if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
- close = logical_or(close, logical_and(isnan(a), isnan(b)))
- # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
- # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
- if atol == 0 and rtol == 0:
- return close
- # Note [closeness error computation]
- # atol and rtol are provided as doubles, so the computation
- # rtol * other will produce a float or complex tensor.
- # When the difference (self - other) is compared to it then the
- # tensor representing the difference will also be cast to float or complex.
- # However, since (self - other) in uint8 is very likely to produce a
- # negative value, this moves the cast forward so the difference is
- # always computed in a float or complex type.
- # If the values of the integer tensors cannot be exactly represented
- # by the default scalar type then this may cause an incorrect result.
- if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
- a = prims.convert_element_type(a, torch.get_default_dtype())
- b = prims.convert_element_type(b, torch.get_default_dtype())
- allowed_error = add(atol, abs(mul(b, rtol)))
- actual_error = abs(sub(a, b))
- # Computes finite closeness
- result = logical_or(
- close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
- )
- return result
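- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # with default tolerances the check is |a - b| <= atol + rtol * |b|.
- #   >>> torch.isclose(torch.tensor(1.000001), torch.tensor(1.0))
- #   tensor(True)
- #   >>> torch.isclose(torch.tensor(1.000001), torch.tensor(1.0), rtol=0.0, atol=0.0)
- #   tensor(False)  # zero tolerances degenerate to an equality check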
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def lcm(a: TensorLikeType, b: TensorLikeType):
- dtype = a.dtype
- # promoting to int32 to maintain 100% consistency with C++ and to
- # prevent overflow in case of int8 and int16
- promote_to_int = dtype in (torch.int8, torch.int16)
- if promote_to_int:
- a = prims.convert_element_type(a, torch.int32)
- b = prims.convert_element_type(b, torch.int32)
- g = torch.gcd(a, b)
- # Avoid division by zero in case gcd(0, 0) == 0
- g = torch.where(g == 0, 1, g)
- res = torch.abs(prims.div(a, g) * b)
- return res if not promote_to_int else prims.convert_element_type(res, dtype)
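- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # lcm(a, b) == |a / gcd(a, b) * b|, with gcd(0, 0) == 0 handled as 0.
- #   >>> torch.lcm(torch.tensor([4, 0]), torch.tensor([6, 0]))
- #   tensor([12,  0])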
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.le(a, b)
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- # NB: this implementation does not distribute the gradients evenly when a == b
- mask = a >= b
- max_ = torch.where(mask, a, b)
- min_ = torch.where(mask, b, a)
- inf_mask = torch.logical_and(torch.isinf(a), a == b)
- return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))
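- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # the max/min rewrite keeps exp() from overflowing for large inputs.
- #   >>> torch.logaddexp(torch.tensor(1000.0), torch.tensor(1000.0))
- #   tensor(1000.6931)  # 1000 + log(2), whereas log(exp(1000) + exp(1000)) would overflow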
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_and(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return a & b
- # TODO: add docstring
- @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
- def logical_not(a: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- return a == 0
- return ~a
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_or(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return bitwise_or(a, b)
- # TODO: add docstring
- # TODO: skip unnecessary conversion of long to float
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- def logical_xor(a: TensorLikeType, b: TensorLikeType):
- if not utils.is_boolean_dtype(a.dtype):
- a = a != 0
- if not utils.is_boolean_dtype(b.dtype):
- b = b != 0
- return a ^ b
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.lt(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.maximum(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.minimum(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- supports_two_python_scalars=True,
- )
- def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.mul(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- supports_lhs_python_scalar=False,
- )
- def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.ne(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- supports_lhs_python_scalar=False,
- supports_rhs_python_scalar=False,
- )
- def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.nextafter(a, b)
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.remainder(a, b)
- # reverse sub
- def rsub(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- alpha: Optional[NumberType] = None,
- ):
- if isinstance(a, Number):
- msg = "Received a Number for the first argument, but expected a Tensor"
- raise ValueError(msg)
- return sub(b, a, alpha=alpha)
- # TODO: add docstring
- # TODO: consider refactoring this with add impl
- # sub has its own implementation because it has an alpha argument
- @register_decomposition(aten.sub)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def sub(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- *,
- alpha: Optional[NumberType] = None,
- ):
- """
- Reference implementation of torch.sub
- """
- a, b = _maybe_broadcast(a, b)
- if alpha is not None:
- dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
- python_type = utils.dtype_to_type(dtype)
- if not utils.is_weakly_lesser_type(type(alpha), python_type):
- msg = (
- "alpha argument of type {0} cannot be safely cast to type {1}!".format(
- type(alpha), python_type
- )
- )
- raise ValueError(msg)
- b = prims.mul(b, alpha)
- return prims.sub(a, b)
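- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # alpha scales the second operand before subtraction, i.e. a - alpha * b.
- #   >>> torch.sub(torch.tensor([5.0, 5.0]), torch.tensor([1.0, 2.0]), alpha=2)
- #   tensor([3., 1.])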
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- name="true_divide",
- aten_op=None, # CompositeImplicitAutograd
- supports_two_python_scalars=True,
- )
- def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
- return prims.div(a, b)
- @register_decomposition(aten.xlogy)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
- utils.check(
- isinstance(a, TensorLike) or isinstance(b, TensorLike),
- lambda: "Expected either argument a or b to be a Tensor",
- )
- # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
- if isinstance(b, TensorLike) and isinstance(a, Number):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- elif isinstance(a, TensorLike) and isinstance(b, Number):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- # mypy: expected "Tensor"
- assert isinstance(a, TensorLike)
- assert isinstance(b, TensorLike)
- rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
- return torch.where(torch.isnan(b), float("nan"), rhs)
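- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # x == 0 yields 0 even where log(y) is undefined, but a NaN in y still propagates.
- #   >>> torch.xlogy(torch.tensor([0.0, 2.0]), torch.tensor([0.0, 3.0]))
- #   tensor([0.0000, 2.1972])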
- # TODO: add docstring
- @_make_elementwise_binary_reference(
- type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- aten_op=None, # CompositeImplicitAutograd
- supports_two_python_scalars=True,
- )
- def trunc_divide(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ):
- dtype = utils.get_dtype(a)
- if utils.is_integer_dtype(dtype):
- return prims.div(a, b)
- return trunc(prims.div(a, b))
- #
- # Elementwise Ternary References
- #
- @register_decomposition(aten.addcdiv)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "tensor1", "tensor2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )
- def addcdiv(
- self: TensorLikeType,
- tensor1: TensorLikeType,
- tensor2: TensorLikeType,
- *,
- value: NumberType = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.addcdiv
- """
- if value is not None:
- dtype = self.dtype # no scalars allowed, see add
- python_type = utils.dtype_to_type(dtype)
- check(
- utils.is_weakly_lesser_type(type(value), python_type),
- lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
- type(value), python_type
- ),
- exc_type=ValueError,
- )
- return self + value * tensor1 / tensor2
- @register_decomposition(aten.addcmul)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "tensor1", "tensor2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def addcmul(
- self: TensorLikeType,
- tensor1: TensorLikeType,
- tensor2: TensorLikeType,
- *,
- value: NumberType = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.addcmul
- """
- if value is not None:
- dtype = self.dtype # no scalars allowed, see add
- python_type = utils.dtype_to_type(dtype)
- check(
- utils.is_weakly_lesser_type(type(value), python_type),
- lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
- type(value), python_type
- ),
- exc_type=ValueError,
- )
- return self + value * tensor1 * tensor2
- @register_decomposition(aten.clamp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "min", "max"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def clamp(
- a: TensorLikeType,
- min: Optional[TensorOrNumberLikeType] = None,
- max: Optional[TensorOrNumberLikeType] = None,
- ) -> TensorLikeType:
- # NOTE: the gradient behavior of this `where`-based implementation is not consistent for `nan`
- if min is None and max is None:
- msg = "clamp called but both min and max are none!"
- raise ValueError(msg)
- if min is not None:
- a_isnan = torch.isnan(a)
- condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type]
- # We should also propagate `nan` coming from the boundaries. However, that's
- # not necessary since `ge` is already `False` when either operand is `nan`,
- # so the line below would be redundant:
- # `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
- a = torch.where(condition, a, min) # type: ignore[arg-type]
- if max is not None:
- a_isnan = torch.isnan(a)
- # same as above, no need to adjust `nan` from `max`
- condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type]
- a = torch.where(condition, a, max) # type: ignore[arg-type]
- return a
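- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # the isnan conditions above make NaN entries pass through unchanged.
- #   >>> torch.clamp(torch.tensor([-1.0, 0.5, 2.0, float("nan")]), min=0.0, max=1.0)
- #   tensor([0.0000, 0.5000, 1.0000,    nan])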
- @register_decomposition(aten.clamp_min)
- @out_wrapper()
- def clamp_min(
- self: TensorLikeType,
- min: TensorOrNumberLikeType = None,
- ) -> TensorLikeType:
- return torch.clamp(self, min=min) # type: ignore[arg-type]
- @register_decomposition(aten.clamp_max)
- @out_wrapper()
- def clamp_max(
- self: TensorLikeType,
- max: TensorOrNumberLikeType = None,
- ) -> TensorLikeType:
- return torch.clamp(self, max=max) # type: ignore[arg-type]
- #
- # Conditional references
- #
- # https://pytorch.org/docs/stable/generated/torch.where.html
- # TODO: implement alternate where
- @register_decomposition(aten.where)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a", "b"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def where(
- pred: Tensor,
- a: Optional[TensorOrNumberLikeType] = None,
- b: Optional[TensorOrNumberLikeType] = None,
- ):
- """ """
- if a is None or b is None:
- raise NotImplementedError
- utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
- check(
- pred.dtype is torch.bool,
- lambda: f"expected predicate to be bool, got {pred.dtype}",
- )
- pred, a, b = _maybe_broadcast(pred, a, b)
- return prims.where(pred, a, b)
- #
- # Data Movement References
- #
- @register_decomposition(aten.clone)
- def clone(
- a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
- ) -> TensorLikeType:
- result = prims.clone(a, memory_format=memory_format)
- return result
- def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
- if not allow_cross_device and a.device != b.device:
- msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
- b.device, a.device
- )
- raise RuntimeError(msg)
- return prims.copy_to(a, b)
- @register_decomposition(aten.item)
- def item(a: TensorLikeType) -> NumberType:
- if a.numel() != 1:
- msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
- raise ValueError(msg)
- # NOTE: explicit conversion is necessary for bool!
- # See https://github.com/pytorch/pytorch/issues/78071
- number_type = utils.dtype_to_type(a.dtype)
- return number_type(prims.item(a))
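- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # the result is a plain Python number whose type matches the tensor's dtype.
- #   >>> torch.tensor([3.5]).item(), torch.tensor(True).item()
- #   (3.5, True)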
- # fast path when `to` returns an alias to input. This mimics the same function in aten
- def _to_will_alias(
- a: TensorLikeType,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- copy: Optional[bool] = None,
- layout: Optional[torch.layout] = None,
- memory_format: Optional[torch.memory_format] = None,
- pin_memory: Optional[bool] = False,
- non_blocking: bool = False, # not using non_blocking
- ) -> bool:
- return (
- not copy
- and (device is None or a.device == device)
- and (dtype is None or a.dtype == dtype)
- and (layout is None or a.layout == layout)
- # is_pinned issue #84925
- # and (pin_memory is None or pin_memory == a.is_pinned())
- and (
- memory_format is None
- or memory_format == torch.preserve_format
- or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
- )
- )
- @singledispatch
- def _to_dispatch(*args, **kwargs):
- raise NotImplementedError
- @_to_dispatch.register
- def _to_device(
- device: torch.device,
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ):
- kwargs = {
- "device": device,
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_device_str(
- device: str,
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ):
- kwargs = {
- "device": torch.device(device),
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_dtype(
- dtype: torch.dtype,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ):
- kwargs = {
- "dtype": dtype,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- @_to_dispatch.register
- def _to_other(
- other: Tensor,
- non_blocking: bool = False,
- copy: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ):
- device = other.device
- dtype = other.dtype
- layout = other.layout
- # is_pinned issue #84925
- # pin_memory = other.is_pinned()
- kwargs = {
- "device": device,
- "dtype": dtype,
- "layout": layout,
- "non_blocking": non_blocking,
- "copy": copy,
- "memory_format": memory_format,
- }
- return kwargs
- # remove to_kwargs that are already present in `a`
- def canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
- options_to_check = ["dtype", "device", "layout", "memory_format"]
- # "device" option could be passed a str instead torch.device
- if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
- to_kwargs["device"] = torch.device(to_kwargs["device"])
- for kw in options_to_check:
- if kw in to_kwargs:
- if (
- (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
- or (
- kw == "device"
- and to_kwargs[kw].type == a.device.type
- and (
- not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
- )
- )
- or (
- getattr(a, kw, None) == to_kwargs[kw]
- ) # this also handles {"memory_format": None}
- ):
- to_kwargs.pop(kw)
- def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
- # handled dispatch via positional arguments
- if len(args) != 0:
- kwargs = _to_dispatch(*args, **kwargs)
- # TODO: is_pinned is not currently supported in refs or fake_tensor
- # https://github.com/pytorch/pytorch/issues/84925
- assert "pin_memory" not in kwargs
- canonicalize_to_arguments(a, kwargs)
- if _to_will_alias(a, **kwargs):
- return a
- copy = kwargs.pop("copy") if "copy" in kwargs else False
- non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False
- # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
- if (
- (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
- and (not non_blocking)
- and ("memory_format" not in kwargs)
- and ("device" not in kwargs)
- and ("layout" not in kwargs)
- # is_pinned issue #84925
- # and ("pin_memory" not in kwargs)
- ):
- return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))
- result = torch.empty_like(a, **kwargs)
- # TODO: non_blocking should be handled by `copy_to`
- copy_to(result, a)
- return result
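- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # a pure dtype change takes the convert_element_type short-circuit above, while
- # requesting a different device falls through to empty_like + copy_to.
- #   >>> torch.ones(2, dtype=torch.float32).to(torch.float64).dtype
- #   torch.float64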
- #
- # Reduction references
- #
- def _reduction(
- a: TensorLikeType,
- prim: Callable,
- *,
- has_identity: bool = True,
- accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only
- dims: Optional[DimsType] = None,
- keepdims: bool = False,
- dtype: Optional[torch.dtype] = None, # should be specified for ops that support it
- out: Optional[Tensor] = None,
- output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
- ) -> TensorLikeType: # it is usually SAME, but I want
- # ref writers to actually think about what to put here
- assert isinstance(a, TensorLike)
- if a.ndim > 64:
- raise RuntimeError(
- "Received a tensor with {0} dimensions, but only tensors with up to 64 dims are supported!".format(
- a.ndim
- )
- )
- if out is not None:
- assert isinstance(out, TensorLike)
- if dtype is not None:
- # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
- if dtype != out.dtype:
- raise RuntimeError(
- "dtype argument and out dtype must match in reduction"
- )
- if not accepts_dim_tuple:
- assert dims is None or isinstance(dims, Dim)
- if isinstance(dims, Dim):
- dims = (dims,) # type: ignore[assignment]
- dims = utils.reduction_dims(a.shape, dims)
- if not has_identity:
- valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
- if not valid_shape:
- raise RuntimeError(
- "reducing over zero-size dimension for reduction operation without identity"
- )
- computation_dtype, result_dtype = utils.reduction_dtypes(
- a, output_dtype_kind, dtype
- )
- a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[assignment]
- result = prim(a, dims)
- if keepdims:
- output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
- broadcast_dims = [i for i in range(a.ndim) if i not in dims]
- result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
- if out is not None:
- assert result_dtype is not None
- if dtype is not None and result_dtype != out.dtype:
- raise RuntimeError(
- "Expected the dtype of reduction result and out to match"
- )
- out = _maybe_resize_out(out, result.shape)
- return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
- if result.dtype != result_dtype and result_dtype is not None:
- result = prims.convert_element_type(result, result_dtype)
- return result
- def _make_copy_from_view(fn):
- """
- Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
- """
- name = fn.__name__
- fn = out_wrapper()(fn)
- def _fn(*args, out=None, **kwargs):
- result = fn(*args, out=out, **kwargs)
- if out is None:
- return result.clone(memory_format=torch.contiguous_format)
- return result
- copy_name = f"{name}_copy"
- _fn.__name__ = copy_name
- _fn = register_decomposition(getattr(aten, copy_name))(_fn)
- return _fn
- # Saves Python all
- py_all = all
- @register_decomposition(aten.all)
- @out_wrapper()
- def all(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- ) -> TensorLikeType:
- # Normalizes a single dim into a tuple of dims
- if isinstance(dim, Dim):
- dim = (dim,) # type: ignore[assignment]
- a_ = _maybe_convert_to_dtype(a, torch.bool)
- # avoid comparison with symbolic number of elements to make this op symint friendly
- result = eq(sum(logical_not(a_), dim=dim, keepdim=keepdim), 0)
- # Preserves uint8 -- probably a legacy mask thing
- if a.dtype is torch.uint8:
- return prims.convert_element_type(result, torch.uint8)
- return result
- # Saves Python any
- py_any = any
- @register_decomposition(aten.any)
- @out_wrapper()
- def any(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- ) -> TensorLikeType:
- a_ = _maybe_convert_to_dtype(a, torch.bool)
- result = ne(sum(a_, dim=dim, keepdim=keepdim), False) # type: ignore[arg-type]
- # Preserves uint8 -- probably a legacy mask thing
- if a.dtype is torch.uint8:
- return prims.convert_element_type(result, torch.uint8)
- return result
- @register_decomposition(aten.sum)
- def sum(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- keepdim: bool = False,
- *,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- if dtype is None:
- if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- dtype = torch.int64
- else:
- dtype = a.dtype
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.sum,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=out,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- def sum_to_size(
- a: Tensor,
- *shape,
- ) -> Tensor:
- shape = utils.extract_shape_from_varargs(shape, validate=False)
- utils.check(
- utils.is_expandable_to(shape, a.shape),
- lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
- )
- # In ATen, scalar tensors are sent through sum and the result is returned
- # type-promoted
- if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
- return prims.view_of(a)
- leading_dims = a.ndim - len(shape)
- reduce_dims = tuple(range(leading_dims)) + tuple(
- i
- for i in range(leading_dims, len(shape))
- if shape[i - leading_dims] == 1 and a.shape[i] != 1
- )
- return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
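- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # dims of size 1 in the target shape (and any leading extra dims) are summed out.
- #   >>> torch.ones(2, 3).sum_to_size(1, 3)
- #   tensor([[2., 2., 2.]])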
- @register_decomposition(aten.prod)
- def prod(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- keepdim: bool = False,
- *,
- dtype=None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- if dtype is None:
- if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
- dtype = torch.int64
- else:
- dtype = a.dtype
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.prod,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=out,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- @register_decomposition(aten.amin)
- def amin(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.amin,
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=out,
- has_identity=False,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- @register_decomposition(aten.amax)
- def amax(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- return _reduction(
- a,
- prims.amax,
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=out,
- has_identity=False,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
- )
- def _dim_var_dispatch(dim=None, unbiased=None):
- # There's the following overload of torch.var:
- # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
- # We need to explicitly convert bool dims to unbiased arg
- if unbiased is None and isinstance(dim, bool):
- unbiased = dim
- dim = None
- return dim, unbiased
- @register_decomposition(aten.var)
- @out_wrapper()
- def var(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[int] = None,
- ) -> TensorLikeType:
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- result = _reduction(
- a,
- partial(prims.var, correction=correction),
- dims=dim,
- keepdims=keepdim,
- dtype=None,
- out=None,
- has_identity=True,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
- )
- return result
- @register_decomposition(aten.std)
- @out_wrapper()
- def std(
- a: TensorLikeType,
- dim: Union[Optional[int], Optional[List[int]]] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[int] = None,
- ) -> TensorLikeType:
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- opmath_dtype, dtype = utils.reduction_dtypes(
- a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- )
- a = _maybe_convert_to_dtype(a, opmath_dtype)
- a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
- a_std = torch.sqrt(a_var)
- assert dtype is not None
- return _maybe_convert_to_dtype(a_std, dtype)
- @register_decomposition(aten.mean)
- def mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- dtype=None,
- out=None,
- ) -> TensorLikeType:
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- orig_dtype = dtype
- if dtype is None:
- dtype = a.dtype
- # can't use out wrapper because of this argument
- check(
- out is None or out.dtype == dtype,
- lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
- )
- result = _reduction(
- a,
- prims.sum,
- dims=dim,
- keepdims=keepdim,
- dtype=dtype,
- out=None,
- output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
- )
- check(
- utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
- lambda: (
- f"mean(): could not infer output dtype. "
- f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
- f"a floating point or complex dtype. Got: {dtype}"
- ),
- )
- if isinstance(dim, Dim):
- dim = (dim,) # type: ignore[assignment]
- dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type]
- nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
- result = true_divide(result, nelem)
- result_dtype = a.dtype if dtype is None else dtype
- result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[assignment]
- if out is not None:
- assert isinstance(out, TensorLike)
- out = _maybe_resize_out(out, result.shape)
- return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type]
- return result
- @register_decomposition(aten.std_mean.correction)
- def std_mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- *,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- correction: Optional[int] = None,
- ):
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- correction = utils.set_correction(unbiased, correction)
- opmath_dtype, dtype = utils.reduction_dtypes(
- a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- )
- original_dtype = a.dtype
- a = _maybe_convert_to_dtype(a, opmath_dtype)
- a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
- a_std = torch.sqrt(a_var)
- assert dtype is not None
- return (
- _maybe_convert_to_dtype(a_std, dtype),
- _maybe_convert_to_dtype(a_mean, original_dtype),
- )
- @register_decomposition(aten.var_mean)
- def var_mean(
- a: TensorLikeType,
- dim: Optional[DimsType] = None,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: Optional[int] = None,
- ):
- dim, unbiased = _dim_var_dispatch(dim, unbiased)
- v = var(a, dim, unbiased, keepdim, correction=correction)
- m = mean(a, dim, keepdim)
- return v, m
- @register_decomposition(aten.addr)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("self", "vec1", "vec2"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def addr(
- self: TensorLikeType,
- vec1: TensorLikeType,
- vec2: TensorLikeType,
- *,
- beta: NumberType = 1,
- alpha: NumberType = 1,
- ) -> TensorLikeType:
- check(
- vec1.ndim == 1,
- lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
- )
- check(
- vec2.ndim == 1,
- lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
- )
- self = self.expand(vec1.shape[0], vec2.shape[0])
- if utils.is_boolean_dtype(self.dtype):
- # Integers are accepted for booleans
- check(
- is_weakly_lesser_type(type(beta), int),
- lambda: f"expected bool/int beta but got {type(beta)}",
- )
- check(
- is_weakly_lesser_type(type(alpha), int),
- lambda: f"expected bool/int alpha but got {type(beta)}",
- )
- if not beta:
- return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
- else:
- return torch.logical_or(
- self,
- torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
- )
- else:
- check(
- is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
- lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
- )
- check(
- is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
- lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
- )
- if beta == 0:
- # This means NaNs from self are dropped if beta is zero
- return alpha * torch.outer(vec1, vec2)
- else:
- return beta * self + alpha * torch.outer(vec1, vec2)
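- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # the result is beta * self + alpha * outer(vec1, vec2).
- #   >>> torch.addr(torch.ones(2, 3), torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0, 3.0]))
- #   tensor([[2., 3., 4.],
- #           [3., 5., 7.]])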
- # CompositeImplicitAutograd - don't register decomp
- def atleast_1d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_1d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
- return res if len(res) > 1 else res[0]
- # Helper function with assert to avoid MyPy error
- # of incompatible type passed to unsqueeze
- def _unsqueeze_atleast(
- at_least_fn: Callable, dim: int, arg: TensorLikeType
- ) -> TensorLikeType:
- arg_ = at_least_fn(arg)
- assert isinstance(arg_, TensorLike)
- return unsqueeze(arg_, dim)
- # CompositeImplicitAutograd - don't register decomp
- def atleast_2d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_2d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
- res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
- return res if len(res) > 1 else res[0]
- # CompositeImplicitAutograd - don't register decomp
- def atleast_3d(
- arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
- ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
- """Reference implementation of :func:`torch.atleast_3d`."""
- if not args and isinstance(arg, collections.abc.Sequence):
- args_ = arg
- else:
- assert not isinstance(arg, collections.abc.Sequence)
- args_ = (arg,) + args
- unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
- res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
- return res if len(res) > 1 else res[0]
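- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # the atleast_* helpers only unsqueeze tensors whose rank is below the target.
- #   >>> torch.atleast_2d(torch.tensor(1.0)).shape, torch.atleast_3d(torch.ones(2, 3)).shape
- #   (torch.Size([1, 1]), torch.Size([2, 3, 1]))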
- def as_strided(
- a: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- storage_offset: Optional[int] = None,
- ) -> TensorLikeType:
- storage_offset_int = (
- storage_offset if storage_offset is not None else a.storage_offset()
- )
- return prims.as_strided(a, size, stride, storage_offset_int)
- @register_decomposition(aten.as_strided_scatter)
- def as_strided_scatter(
- input: TensorLikeType,
- src: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- storage_offset: Optional[int] = None,
- ) -> TensorLikeType:
- storage_offset_int = 0 if storage_offset is None else storage_offset
- return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
- def broadcast_shapes(*shapes) -> ShapeType:
- return torch.Size(_broadcast_shapes(*shapes))
- @aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
- def broadcast_tensors(*tensors) -> List[TensorLikeType]:
- if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
- tensors = tensors[0]
- return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
- # CompositeImplicitAutograd - don't register decomp
- def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
- start = len(size) - len(a.shape)
- dims = tuple(range(start, len(a.shape) + start))
- return prims.broadcast_in_dim(a, size, dims)
- @register_decomposition(aten.cat)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("tensors",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
- def cat_compute_output_memory_format(inputs):
- format = None
- for t in inputs:
- f = utils.suggest_memory_format(t)
- if f == torch.contiguous_format:
- return f
- if format is not None and format != f:
- return torch.contiguous_format
- format = f
- assert format is not None
- return format
- if len(tensors) == 0:
- msg = "cat expects at least one tensor, but received zero!"
- raise ValueError(msg)
- for tensor in tensors:
- assert isinstance(tensor, TensorLike)
- utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
- for t in tensors:
- # match logic in legacy_cat_wrap_dim
- if t.ndim == 1 and t.size(0) == 0:
- continue
- dim = utils.canonicalize_dim(t.ndim, dim)
- utils.validate_idx(t.ndim, dim)
- break
- memory_format = cat_compute_output_memory_format(tensors)
- # Filters tensors with one dimension of length zero
- filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0))
- if len(filtered) == 0:
- t = tensors[0]
- # TODO: fix this to work with meta tensors
- try:
- requires_grad = any(x.requires_grad for x in tensors)
- except Exception:
- requires_grad = False
- return empty(
- (0,),
- dtype=t.dtype,
- device=t.device,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- return prims.cat(filtered, dim).clone(memory_format=memory_format)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
- aligned_tensors = tuple(
- x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
- )
- return cat(aligned_tensors, 1)
- def conj(input: TensorLikeType) -> TensorLikeType:
- if not utils.is_complex_dtype(input.dtype):
- return input
- if input.is_sparse:
- return torch.conj_physical(input)
- return prims.conj(input)
- # This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
- @register_decomposition(aten.constant_pad_nd)
- def constant_pad_nd(
- input: TensorLikeType, pad: List[int], value: NumberType = 0
- ) -> TensorLikeType:
- check(
- len(pad) % 2 == 0,
- lambda: f"Length of pad must be even but instead it equals {len(pad)}",
- )
- input_sizes = input.shape
- l_inp = len(input_sizes)
- l_pad = len(pad) // 2
- l_diff = l_inp - l_pad
- check(
- l_inp >= l_pad,
- lambda: "Length of pad should be no more than twice the number of "
- f"dimensions of the input. Pad length is {len(pad)} while the input has "
- f"{l_inp} dimensions.",
- )
- c_input = input
- for i in range(l_diff, l_inp):
- pad_idx = 2 * (l_inp - i - 1)
- if pad[pad_idx] < 0:
- c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])
- if pad[pad_idx + 1] < 0:
- c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
- # if none of the pads are positive we can just return the result
- if builtins.all(p <= 0 for p in pad):
- return c_input.clone()
- new_shape = list(input_sizes[:l_diff])
- for i in range(l_pad):
- pad_idx = len(pad) - ((i + 1) * 2)
- new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
- check(
- new_dim > 0,
- lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
- f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
- f"which is invalid. Check dimension {l_diff + i} of your input.",
- )
- new_shape.append(new_dim)
- memory_format = utils.suggest_memory_format(input)
- output = torch.empty(
- new_shape,
- dtype=input.dtype,
- device=input.device,
- requires_grad=input.requires_grad,
- memory_format=memory_format,
- )
- if value == 0 and input.dtype == torch.bool:
- value = False
- # torch.fill isn't typed to allow complex values
- output = torch.fill(output, value) # type: ignore[arg-type]
- c_output = output
- for i in range(l_diff, l_inp):
- pad_idx = 2 * (l_inp - i - 1)
- if pad[pad_idx] > 0:
- c_output = c_output.narrow(
- i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
- )
- if pad[pad_idx + 1] > 0:
- c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])
- prims.copy_to(c_output, c_input)
- return output
- def contiguous(
- a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
- ) -> Tensor:
- check(
- memory_format != torch.preserve_format,
- lambda: "preserve memory format is unsupported by the contiguous operator",
- )
- if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
- return a
- return torch.clone(a, memory_format=memory_format)
- @out_wrapper()
- def dstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
- aligned_tensors = atleast_3d(*tensors)
- return cat(aligned_tensors, 2)
- @register_decomposition(aten.expand)
- def expand(a: Tensor, *shape) -> Tensor:
- # NOTE: cannot use utils.extract_shape_from_varargs here
- # because that also validates the shape, but the shape
- # given to expand may be "invalid"
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = tuple(shape[0])
- check(
- len(shape) >= len(a.shape),
- lambda: "expand: the requested shape has too few dimensions!",
- )
- offset = len(shape) - len(a.shape)
- shape_ = list(shape)
- for idx, x in enumerate(a.shape):
- offset_idx = idx + offset
- requested_length = shape[offset_idx]
- check(
- requested_length == x or x == 1 or requested_length == -1,
- lambda: f"expand: attempting to expand a dimension of length {x}!",
- )
- shape_[offset_idx] = requested_length if requested_length != -1 else x
- # At this point shape must be valid
- utils.validate_shape(shape_)
- return prims.broadcast_in_dim(
- a, shape_, tuple(range(offset, len(a.shape) + offset))
- )
- # CompositeImplicitAutograd - don't register decomp
- def expand_as(a: Tensor, b: Tensor) -> Tensor:
- return a.expand(b.shape)
- def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
- if chunks <= 0:
- msg = "Expected at least one chunk, but got {0}!".format(chunks)
- raise ValueError(msg)
- dim = utils.canonicalize_dim(a.ndim, dim)
- length = a.shape[dim]
- chunk_size = math.ceil(length / chunks)
- full_chunks = math.floor(length / chunk_size)
- tail_chunk_size = length % chunk_size
- result = []
- for i in range(full_chunks):
- result.append(narrow(a, dim, i * chunk_size, chunk_size))
- if tail_chunk_size != 0:
- result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))
- return tuple(result)
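- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # chunk sizes come from ceil(length / chunks), so the last chunk may be smaller.
- #   >>> torch.arange(5).chunk(3)
- #   (tensor([0, 1]), tensor([2, 3]), tensor([4]))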
- # Note: flatten, unlike prim.collapse and prim.collapse_view has an inclusive end_dim
- # Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
- # a 0D tensor is flattened, in which case it's returned in 1D)
- # CompositeImplicitAutograd - don't register decomp
- def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
- start_dim = utils.canonicalize_dim(a.ndim, start_dim)
- end_dim = utils.canonicalize_dim(a.ndim, end_dim)
- # Short-circuits on no-op
- if start_dim == end_dim and a.ndim != 0:
- return a
- # Tries to take a view
- # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
- new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim + 1)
- if new_shape is not None:
- return prims.collapse_view(a, start_dim, end_dim + 1)
- # Makes a copy if it can't make a view
- return prims.collapse(a, start_dim, end_dim + 1)
- @register_decomposition(aten.flip)
- def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
- if not isinstance(dims, tuple) and not isinstance(dims, list):
- raise ValueError("dims has to be a sequence of ints")
- dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment]
- utils.validate_no_repeating_dims(dims)
- return prims.rev(a, dims)
- # CompositeImplicitAutograd - don't register decomp
- def fliplr(a: TensorLikeType) -> TensorLikeType:
- if a.ndim < 2:
- raise RuntimeError("Input must be >= 2-d.")
- return flip(a, (1,))
- # CompositeImplicitAutograd - don't register decomp
- def flipud(a: TensorLikeType) -> TensorLikeType:
- if a.ndim < 1:
- raise RuntimeError("Input must be >= 1-d.")
- return flip(a, (0,))
- # CompositeImplicitAutograd - don't register decomp
- def narrow(
- a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
- ) -> TensorLikeType:
- # Supports Tensor overload that was added for XLA:
- # https://github.com/pytorch/pytorch/issues/31558
- if isinstance(start, TensorLike):
- check(
- start.dim() == 0 and utils.is_integer_dtype(start.dtype),
- lambda: "start must be an 0-dim integral Tensor.",
- )
- start = start.item() # type: ignore[assignment]
- check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
- check(length >= 0, lambda: "narrow(): length must be non-negative.")
- dim = utils.canonicalize_dim(a.ndim, dim)
- dim_length = a.size(dim)
- # Start being the end is usually invalid since it's out of bounds. So it's
- # not allowed by canonicalize_dim. But for narrow it's valid as long as
- # the length is 0, which is handled by the check below.
- if start != dim_length:
- # Negative start means indexing from the end of dim.
- # Note: a dimension isn't being canonicalized here, this reuses
- # canonicalize_dim because the semantics are similar.
- start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type]
- check(
- start <= dim_length - length, # type: ignore[arg-type]
- lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
- )
- return prims.slice_in_dim(a, start, start + length, axis=dim)
- # TODO: This must return a sparse tensor if the input is sparse, but refs have
- # no sparse support. See narrow_copy_sparse in core.
- narrow_copy = _make_copy_from_view(narrow)
- def _normalize(
- a: Tensor, norm_dims: DimsType, eps: float
- ) -> Tuple[Tensor, Tensor, Tensor]:
- """Computes mean and 1/std of a tensor along norm_dims.
- Used as a helper function for normalization layers.
- Args:
- a (Tensor): input tensor
- norm_dims (DimsType): dimensions to normalize over
- eps (float): epsilon for numerical stability
- Returns:
- out (Tensor): normalized tensor.
- mean (Tensor): mean of the tensor along norm_dims.
- rstd (Tensor): 1/std of the tensor along norm_dims.
- """
- norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
- computation_dtype = utils.get_computation_dtype(a.dtype)
- a_acc = _maybe_convert_to_dtype(a, computation_dtype)
- assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean
- biased_var, mean = torch.var_mean(
- a_acc, dim=norm_dims, unbiased=False, keepdim=True
- )
- rstd = torch.rsqrt(biased_var + eps)
- out = (a - mean) * rstd
- return out, mean, rstd
- # add all specified dimensions
- def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
- for dim in sorted(dimensions):
- x = torch.unsqueeze(x, dim)
- return x
- @register_decomposition(aten.native_group_norm.default)
- def native_group_norm(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- batch_size: int,
- num_channels: int,
- flattened_inner_size: int,
- num_groups: int,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- utils.check(
- input.ndim >= 2,
- lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
- )
- utils.check(
- num_channels % num_groups == 0,
- lambda: "Expected number of channels in input to be divisible by num_groups, "
- + f"but got input of shape {input.shape} and num_groups = {num_groups}",
- )
- # num_channels / num_groups and flattened inner dimension are the reduction axes
- reduction_dims = [2, 3]
- input_reshaped = torch.reshape(
- input,
- [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
- )
- out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
- out = out.view(input.shape)
- broadcast_dims = [0] + list(range(2, input.ndim))
- unsqueeze_bias = None
- if bias is not None:
- unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
- unsqueeze_weight = None
- if weight is not None:
- unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
- if unsqueeze_weight is not None:
- out = out * unsqueeze_weight
- if unsqueeze_bias is not None:
- out = out + unsqueeze_bias
- out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
- mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
- rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
- # remove broadcast dimensions from mean and rstd
- mean = torch.squeeze(mean, reduction_dims)
- rstd = torch.squeeze(rstd, reduction_dims)
- return (out, mean, rstd)
- @register_decomposition(aten.native_layer_norm)
- def native_layer_norm(
- input: Tensor,
- normalized_shape: ShapeType,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- normalized_ndim = len(normalized_shape)
- utils.check(
- normalized_ndim >= 1,
- lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
- + "containing at least one element, but got normalized_shape = "
- + str(normalized_shape),
- )
- # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
- # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
- # therefore we use tuple(normalized_shape)
- utils.check(
- weight is None or weight.shape == tuple(normalized_shape),
- lambda: "Expected weight to be of same shape as normalized_shape, but got "
- + "weight of shape "
- + str(weight.shape) # type: ignore[union-attr]
- + " and normalized_shape = "
- + str(normalized_shape),
- )
- utils.check(
- bias is None or bias.shape == tuple(normalized_shape),
- lambda: "Expected bias to be of same shape as normalized_shape, but got "
- + "bias of shape "
- + str(bias.shape) # type: ignore[union-attr]
- + " and normalized_shape = "
- + str(normalized_shape),
- )
- utils.check(
- input.ndim >= normalized_ndim
- and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
- lambda: "Given normalized_shape="
- + str(normalized_shape)
- + ", expected input with shape "
- + str(normalized_shape)
- + ", but got input of size "
- + str(input.shape),
- )
- input = input.contiguous()
- if weight is not None:
- weight = weight.contiguous()
- if bias is not None:
- bias = bias.contiguous()
- axis = input.ndim - normalized_ndim
- reduction_dims = list(range(axis, input.ndim))
- out, mean, rstd = _normalize(input, reduction_dims, eps)
- if weight is None and bias is not None:
- out = out + bias
- elif weight is not None and bias is None:
- out = out * weight
- elif weight is not None and bias is not None:
- out = out * weight + bias
- out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment]
- if input.device.type == "cpu":
- mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment]
- rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment]
- return (out, mean, rstd)
- # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
- # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
- @register_decomposition(aten.permute)
- def permute(a: TensorLikeType, *dims) -> TensorLikeType:
- _permutation = utils.canonicalize_dims(
- a.ndim, utils.extract_dims_from_varargs(dims)
- )
- return prims.transpose(a, _permutation)
- # Get the new shape and stride after applying unfold to an input tensor
- def _get_unfold_shape_stride(
- a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
- ):
- a_ndim = len(a_shape)
- dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
- max_size = 1 if a_ndim == 0 else a_shape[dim]
- last_stride = 1 if a_ndim == 0 else a_stride[dim]
- utils.check(
- size <= max_size,
- lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
- )
- utils.check(
- step > 0,
- lambda: f"Step is {step} but must be > 0",
- )
- shape = list(a_shape)
- strides = list(a_stride)
- shape.append(size)
- strides.append(last_stride)
- if dim < a_ndim:
- shape[dim] = (shape[dim] - size) // step + 1
- strides[dim] *= step
- return shape, strides
- @register_decomposition(aten.repeat)
- def repeat(a: Tensor, *repeat_shape) -> Tensor:
- repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
- utils.check(
- len(repeat_shape) >= len(a.shape),
- lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
- )
- if len(repeat_shape) == 0:
- return torch.clone(a)
- num_new_dimensions = len(repeat_shape) - a.ndim
- padded_shape = [1] * num_new_dimensions
- for dim_size in a.shape:
- padded_shape.append(dim_size)
- target_shape = tuple(
- padded_size * repeat_size
- for padded_size, repeat_size in zip(padded_shape, repeat_shape)
- )
- # return an empty tensor if one of the repeat_shape dimensions is zero
- if 0 in repeat_shape:
- return torch.empty(
- target_shape,
- dtype=a.dtype,
- device=a.device,
- requires_grad=a.requires_grad,
- memory_format=utils.suggest_memory_format(a),
- )
- urtensor_shape = target_shape
- urtensor_stride = utils.make_contiguous_strides_for(target_shape)
- for dim, dim_size in enumerate(padded_shape):
- # repeat each dimension by using unfold_copy operation
- urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
- urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
- )
- # derive permute order by sorting urtensor strides
- enumerated_stride = list(enumerate(urtensor_stride))
- enumerated_stride.sort(key=lambda item: item[1], reverse=True)
- permute_order, sorted_stride = zip(*enumerated_stride)
- # add new and expand dimensions according to urtensor
- repeat_xtensor = a.expand(urtensor_shape)
- # clone tensor to concretize expanded dimensions
- cloned_result = torch.clone(repeat_xtensor)
- # transpose axis so strides are in sorted order
- permuted_result = cloned_result.permute(permute_order)
- # reshape to get contiguous tensor with correct target shape
- return permuted_result.reshape(target_shape)
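- # Illustrative example (not part of the reference; assumes eager torch semantics):
- # each dimension of the (possibly padded) input shape is tiled by the matching repeat factor.
- #   >>> torch.tensor([1, 2]).repeat(2, 3)
- #   tensor([[1, 2, 1, 2, 1, 2],
- #           [1, 2, 1, 2, 1, 2]])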
- def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
- # Creates a valid shape
- shape = utils.extract_shape_from_varargs(shape, validate=False)
- # Reshape may be given a shape with a -1 length
- # This indicates that the dimension's length should be inferred
- shape = utils.infer_size(shape, a.numel())
- # Short-circuits if shape is the same
- if tuple(a.shape) == tuple(shape):
- return prims.view_of(a)
- # Special-cases tensors with no elements
- if a.numel() == 0:
- return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
- # Special-cases reshaping zero dim tensors
- if a.ndim == 0:
- _a = a
- for length in shape:
- assert length == 1
- _a = unsqueeze(_a, -1)
- return _a
- # Special-cases reshaping to zero dim tensors
- if len(shape) == 0:
- _a = a
- for length in a.shape:
- assert length == 1
- _a = squeeze(_a, -1)
- return _a
- # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
- # NOTE [Reshape Algorithm]
- # This algorithm works by attempting to greedily construct the desired dimensions in
- # the output shape, left to right. It does this by, conceptually, accumulating
- # dimensions of the original tensor, also left to right, until the dimension
- # can be constructed using prims.split_dim.
- # The algorithm also has special handling for tail squeezes/unsqueezes, like
- # a reshape from (5, 5) to (5, 5, 1) or vice versa.
- #
- # This algorithm does not flatten the original tensor and then split dims as appropriate
- # because that would create copies more often than this algorithm. flatten is the only
- # operation below which can create a view or a copy, and while it prefers creating
- # views it may sometimes create a copy if the tensor's strides do not permit a view.
- # As a result, this algorithm tries to minimize flattening.
- #
- # Note that a better version of this algorithm may exist. Regions which could be
- # flattened without creating a copy can be identified in advance, and that might
- # allow fewer flatten calls or faster short-circuiting to make a copy.
- idx = 0
- a_ = a
- for length in shape:
- # Handles tail unsqueezes
- if idx >= a_.ndim:
- assert length == 1
- last_dim = a_.ndim - 1
- # NOTE: using split_dim instead of unsqueeze may seem silly here,
- # but it's necessary to get the strides correct
- a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
- idx = idx + 1
- continue
- # Skips dimensions that are already the correct length
- if length == a_.shape[idx]:
- idx = idx + 1
- continue
- # Gathers enough original dimensions such that this new dimension can be created
- # Note that this accumulation will terminate because we've verified a and the shape
- # specify the same number of elements above
- accum = a_.shape[idx]
- end = idx
- while accum % length != 0:
- end = end + 1
- accum = accum * a_.shape[end]
- if end != idx:
- # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
- # This flattening is why reshape sometimes creates a copy -- because flattening
- # may return a view of a copy
- # Checks if collapse can be a view and short-circuits to copying reshape if it can't
- new_shape, new_strides = prims._collapse_view_helper(a_, idx, end + 1)
- if new_shape is None:
- if allow_copy:
- return prims.reshape(a, shape)
- msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
- a.shape, a.stride(), shape
- )
- raise ValueError(msg)
- a_ = flatten(a_, idx, end)
- # Splits the (possibly flattened) dimension to create the desired dim length
- if accum != length:
- a_ = prims.split_dim(a_, idx, length)
- idx = idx + 1
- # Squeezes tail
- while idx < a_.ndim:
- assert a_.shape[idx] == 1
- a_ = squeeze(a_, idx)
- return a_
- # CompositeImplicitAutograd - don't register decomp
- # NOTE: shape is a vararg because Tensor.reshape can be called as
- # Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)); the function call
- # torch.reshape doesn't support unpacked shapes
- def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
- return _reshape_view_helper(a, *shape, allow_copy=True)
- # CompositeImplicitAutograd - don't register decomp
- def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
- return self.reshape(other.size())
- @register_decomposition(aten.roll)
- def roll(
- a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
- ) -> TensorLikeType:
- """Reference implementation of :func:`torch.roll`."""
- dims = utils.canonicalize_dims(a.ndim, dims)
- # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
- if not isinstance(shifts, Iterable):
- shifts = (shifts,)
- if not isinstance(dims, Iterable):
- dims = (dims,)
- # Avoid modulo by zero
- if a.numel() == 0:
- # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
- return clone(a)
- len_shifts = len(shifts)
- len_dims = len(dims)
- if len_shifts != 1 or len_dims != 1:
- if len_shifts == 0:
- raise RuntimeError("`shifts` required")
- # Takes care of the case when dims is not specified (default)
- # By default, the tensor is flattened before shifting, after which the original shape is restored
- if len_dims == 0 and len_shifts == 1:
- return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
- if len_shifts != len_dims:
- raise RuntimeError(
- f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
- )
- assert len_dims > 1
- tail_shifts = shifts[1:]
- tail_dims = dims[1:]
- first_dim_rolled = torch.roll(a, shifts[0], dims[0])
- return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
- # This path is taken when only one dimension is rolled
- # For example to get `first_dim_rolled` above
- dim = dims[0]
- size = a.shape[dim]
- start = (size - shifts[0]) % size
- t0 = torch.narrow(a, dim, start, size - start)
- t1 = torch.narrow(a, dim, 0, start)
- return torch.cat((t0, t1), dim)
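- # Illustrative sketch (not part of the original source): the narrow/cat strategy used for
- # the single-dimension case above, spelled out with plain torch ops. `_roll_one_dim_demo`
- # is a hypothetical name and relies on the module-level `torch` import.
- def _roll_one_dim_demo(a, shift, dim=0):
-     size = a.shape[dim]
-     start = (size - shift) % size
-     # Elements from `start` onwards wrap to the front; the first `start` elements follow.
-     t0 = torch.narrow(a, dim, start, size - start)
-     t1 = torch.narrow(a, dim, 0, start)
-     return torch.cat((t0, t1), dim)
- # e.g. _roll_one_dim_demo(torch.arange(5), 2) -> tensor([3, 4, 0, 1, 2]), matching torch.roll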
- @register_decomposition(aten.rot90)
- def rot90(
- a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
- ) -> TensorLikeType:
- """Reference implementation of :func:`torch.rot90`."""
- if len(dims) != 2:
- raise RuntimeError(
- f"expected total rotation dims == 2, but got dims = {len(dims)}"
- )
- if a.ndim < 2:
- raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
- # Do this after the initial checks to be compatible with the behavior in
- # core.
- dims = utils.canonicalize_dims(a.ndim, dims)
- if dims[0] == dims[1]:
- raise RuntimeError(
- f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
- )
- k = k % 4 # Rotation direction is from the second towards the first axis for k < 0
- if k == 1:
- return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
- elif k == 2:
- return torch.flip(a, dims)
- elif k == 3:
- return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
- else:
- return clone(a, memory_format=torch.contiguous_format)
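- # Illustrative sketch (not part of the original source): the k == 1 branch above (flip the
- # second rotation dim, then swap the two dims) on a small matrix. Hypothetical demo relying
- # on the module-level `torch` import.
- def _rot90_demo():
-     a = torch.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
-     return torch.rot90(a, 1, (0, 1))   # [[1, 3], [0, 2]]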
- def _check_stack_inputs(tensors: TensorSequenceType) -> None:
- entry_shape = tensors[0].shape
- for i in range(1, len(tensors)):
- assert tensors[i].shape == entry_shape, (
- f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0"
- f"and {tensors[i].shape} at entry {i}"
- )
- @register_decomposition(aten.stack)
- @out_wrapper()
- def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
- assert len(tensors) > 0, "stack expects a non-empty TensorList"
- wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
- # Refs need sparse support to check other condition
- if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse:
- _check_stack_inputs(tensors)
- result_sizes = list(tensors[0].shape)
- result_sizes.insert(wrapped_dim, len(tensors))
- out = torch.cat(tensors, wrapped_dim)
- return out.view(result_sizes)
- # If dim == tensors[0].ndim, view cannot efficiently handle it
- return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def softmax(
- a: TensorLikeType,
- dim: int,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- result_dtype = dtype or a.dtype
- computation_dtype = utils.get_computation_dtype(result_dtype)
- a_ = _maybe_convert_to_dtype(a, computation_dtype)
- if a.numel() == 0:
- a_exp = exp(a_)
- else:
- a_max = amax(a_, dim, keepdim=True)
- a_exp = exp(a_ - a_max)
- return _maybe_convert_to_dtype(
- true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
- ) # type: ignore[return-value]
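- # Illustrative sketch (not part of the original source): why the amax subtraction above
- # matters. Shifting by the row max keeps every exponent non-positive so exp() cannot
- # overflow, and the shift cancels in the ratio. Hypothetical demo relying on the
- # module-level `torch` import.
- def _softmax_stability_demo():
-     x = torch.tensor([1000.0, 1001.0])
-     naive = torch.exp(x) / torch.exp(x).sum()                       # inf / inf -> nan
-     stable = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()  # ~[0.2689, 0.7311]
-     return naive, stable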
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def hstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
- aligned_tensors = atleast_1d(*tensors)
- if aligned_tensors[0].ndim == 1:
- return cat(aligned_tensors, 0)
- return cat(aligned_tensors, 1)
- # CompositeImplicitAutograd - don't register decomp
- @out_wrapper()
- def vstack(tensors: TensorSequenceType) -> TensorLikeType:
- check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
- aligned_tensors = atleast_2d(*tensors)
- return cat(aligned_tensors, 0)
- # CompositeImplicitAutograd - don't register decomp
- def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
- dim = utils.canonicalize_dim(a.ndim, dim)
- utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
- return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
- @register_decomposition(aten.unbind)
- def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
- dim = utils.canonicalize_dim(t.ndim, dim)
- check(
- len(t.shape) > 0,
- lambda: "Dimension specified as 0 but tensor has no dimensions",
- IndexError,
- )
- return tuple(
- torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
- )
- @out_wrapper()
- def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- return x.clone(memory_format=torch.contiguous_format).index_copy_(
- dim, index, tensor
- )
- def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- # Treat scalars as elements of \R^1
- y = x.unsqueeze(0) if x.ndim == 0 else x
- idx = (slice(None),) * dim + (index,)
- y[idx] = tensor
- return x
- @register_decomposition(aten.index_fill)
- def index_fill(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
- ):
- return _index_fill(x, dim, index, value, inplace=False)
- @register_decomposition(aten.index_fill_)
- def index_fill_(
- x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
- ):
- return _index_fill(x, dim, index, value, inplace=True)
- def _index_fill(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- value: Union[NumberType, TensorLike],
- *,
- inplace: bool,
- ):
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if isinstance(value, TensorLike):
- utils.check(
- value.ndim == 0,
- lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr]
- f"Got a tensor with {value.ndim} dimensions.",
- ) # type: ignore[arg-type]
- else:
- value = torch.scalar_tensor(
- value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type]
- )
- # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them
- zero_dim = x.ndim == 0
- y = x.unsqueeze(0) if zero_dim else x
- # index_copy does not broadcast on value so we have to do it manually
- shape = list(y.shape)
- shape[dim] = index.numel()
- value = value.expand(shape)
- index_copy = Tensor.index_copy_ if inplace else torch.index_copy
- out = index_copy(y, dim, index, value) # type: ignore[operator]
- if inplace:
- return x
- else:
- if zero_dim:
- # The clone is necessary so that it returns a fresh tensor rather than a view
- out = out.squeeze(0).clone()
- # index_fill preserves the strides. index_copy always returns contiguous tensors
- if out.stride() != x.stride():
- new_out = torch.empty_like(x)
- new_out.copy_(out)
- out = new_out
- return out
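- # Illustrative sketch (not part of the original source): the manual broadcast of `value`
- # described above. index_copy expects a source whose size along `dim` equals the number of
- # indices, so a 0-d fill value is expanded to that shape first. Hypothetical demo relying
- # on the module-level `torch` import.
- def _index_fill_broadcast_demo():
-     x = torch.zeros(3, 4)
-     index = torch.tensor([0, 2])
-     value = torch.scalar_tensor(7.0).expand(2, 4)  # shape of the rows being filled
-     return torch.index_copy(x, 0, index, value)    # rows 0 and 2 become all 7s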
- @out_wrapper()
- def index_add(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- tensor: TensorLike,
- *,
- alpha: NumberType = 1,
- ):
- # index_add always returns a new contiguous tensor
- return x.clone(memory_format=torch.contiguous_format).index_add_(
- dim, index, tensor, alpha=alpha # type: ignore[arg-type]
- )
- @register_decomposition(aten.index_select)
- @out_wrapper()
- def index_select(x: TensorLike, dim: int, index: TensorLike):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if index.ndim == 0:
- index = index.unsqueeze(0)
- if x.ndim == 0:
- # Treat scalars as elements of \R^1
- # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
- return torch.empty_like(x).index_copy(0, index, x.expand_as(index))
- idx = (slice(None),) * dim + (index,)
- return x[idx]
- @register_decomposition(aten.squeeze)
- def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
- if dim is None:
- dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
- return prims.squeeze(a, dims) if dims else prims.view_of(a)
- ndim = a.ndim
- dim = utils.canonicalize_dims(ndim, dim)
- dims = (dim,) if isinstance(dim, Dim) else dim
- # Short-circuits if the tensor has no dimensions
- if ndim == 0:
- assert len(dims) == 0 or dims == (0,)
- return prims.view_of(a)
- # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
- dims = tuple(d for d in dims if a.shape[d] == 1)
- return prims.squeeze(a, dims) if dims else prims.view_of(a)
- # Note: does not work with TensorMetas because of data-dependent control-flow
- # CompositeImplicitAutograd - don't register decomp
- def tensor_split(
- a: TensorLikeType,
- indices_or_sections: Union[Tensor, DimsType],
- dim: int = 0,
- ) -> Tuple[TensorLikeType, ...]:
- _dim = utils.canonicalize_dim(a.ndim, dim)
- if a.ndim == 0:
- msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
- raise ValueError(msg)
- # If indices_or_sections is a tensor, it must be a CPU Long tensor
- if isinstance(indices_or_sections, TensorLike):
- if not indices_or_sections.device.type == "cpu":
- msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
- indices_or_sections.device
- )
- raise ValueError(msg)
- if indices_or_sections.dtype != torch.long:
- msg = "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
- " but received one with dtype {0}".format(indices_or_sections.dtype)
- raise ValueError(msg)
- # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
- if isinstance(indices_or_sections, IntLike) or (
- isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
- ):
- sections: int = (
- indices_or_sections # type: ignore[assignment]
- if isinstance(indices_or_sections, Number)
- else indices_or_sections.item()
- )
- if sections <= 0:
- msg = "tensor_split: number of sections must be greater than 0, but was {0}".format(
- sections
- )
- raise ValueError(msg)
- splits = []
- dim_size = a.shape[_dim]
- min_split_size = math.floor(dim_size / sections)
- num_splits_one_extra = dim_size % sections
- start_idx = 0
- for split_idx in range(sections):
- split_size = (
- min_split_size + 1
- if (split_idx < num_splits_one_extra)
- else min_split_size
- )
- s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
- splits.append(s)
- start_idx = start_idx + split_size
- return tuple(splits)
- # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
- else:
- indices = indices_or_sections
- if isinstance(indices_or_sections, TensorLike):
- if indices_or_sections.ndim != 1:
- msg = "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
- "but received a tensor with {0} dimensions".format(
- indices_or_sections.ndim
- )
- raise ValueError(msg)
- indices = indices_or_sections.tolist()
- splits = []
- start_idx = 0
- for x in indices:
- splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
- start_idx = x
- splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
- return tuple(splits)
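- # Illustrative sketch (not part of the original source): the two cases handled above. With
- # an int, the first `dim_size % sections` chunks get one extra element; with a sequence of
- # indices, splits happen at those boundaries. Hypothetical demo relying on `torch`.
- def _tensor_split_demo():
-     a = torch.arange(7)
-     by_sections = [t.tolist() for t in torch.tensor_split(a, 3)]      # [[0, 1, 2], [3, 4], [5, 6]]
-     by_indices = [t.tolist() for t in torch.tensor_split(a, [2, 5])]  # [[0, 1], [2, 3, 4], [5, 6]]
-     return by_sections, by_indices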
- # CompositeImplicitAutograd - don't register decomp
- def hsplit(
- a: TensorLikeType, indices_or_sections: DimsType
- ) -> Tuple[TensorLikeType, ...]:
- check(
- a.ndim >= 1,
- lambda: (
- "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
- + str(a.ndim)
- + " dimensions!"
- ),
- )
- dim = 0 if a.ndim == 1 else 1
- if isinstance(indices_or_sections, IntLike):
- split_size = indices_or_sections
- check(
- (split_size != 0 and a.shape[dim] % split_size == 0),
- lambda: (
- "torch.hsplit attempted to split along dimension "
- + str(dim)
- + ", but the size of the dimension "
- + str(a.shape[dim])
- + " is not divisible by the split_size "
- + str(split_size)
- + "!"
- ),
- )
- return tensor_split(a, split_size, dim)
- check(
- isinstance(indices_or_sections, (list, tuple)),
- lambda: (
- "hsplit(): received an invalid combination of arguments. "
- "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
- f"but got type {type(indices_or_sections)}"
- ),
- exc_type=TypeError,
- )
- split_sizes = indices_or_sections
- return tensor_split(a, split_sizes, dim)
- # CompositeImplicitAutograd - don't register decomp
- def vsplit(
- a: TensorLikeType, indices_or_sections: DimsType
- ) -> Tuple[TensorLikeType, ...]:
- check(
- a.ndim >= 2,
- lambda: (
- "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
- + str(a.ndim)
- + " dimensions!"
- ),
- )
- if isinstance(indices_or_sections, IntLike):
- split_size = indices_or_sections
- check(
- (split_size != 0 and a.shape[0] % split_size == 0),
- lambda: (
- f"torch.vsplit attempted to split along dimension 0"
- f", but the size of the dimension "
- f"{a.shape[0]}"
- f" is not divisible by the split_size "
- f"{split_size}"
- f"!"
- ),
- )
- return tensor_split(a, split_size, 0)
- check(
- isinstance(indices_or_sections, (list, tuple)),
- lambda: (
- "vsplit(): received an invalid combination of arguments. "
- "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
- f"but got type {type(indices_or_sections)}"
- ),
- exc_type=TypeError,
- )
- split_sizes = indices_or_sections
- return tensor_split(a, split_sizes, 0)
- @register_decomposition(aten.diag.out)
- @out_wrapper()
- def diag(
- self: TensorLikeType,
- offset: int = 0,
- ) -> TensorLikeType:
- ndim = self.dim()
- utils.check(
- ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
- )
- if ndim == 1:
- return torch.diag_embed(self, offset)
- else:
- return torch.diagonal_copy(self, offset)
- @register_decomposition(aten.diagonal_scatter)
- @out_wrapper()
- def diagonal_scatter(
- input: TensorLikeType,
- src: TensorLikeType,
- offset: int = 0,
- dim1: int = 0,
- dim2: int = 1,
- ) -> TensorLikeType:
- out = utils.clone_preserve_strides(input)
- diag = out.diagonal(offset, dim1, dim2)
- check(
- diag.shape == src.shape,
- lambda: "expected src to have a size equal to the diagonal of the input."
- f"Got {src.shape} for a diagonal of shape {diag.shape}",
- )
- copy_to(diag, src)
- return out
- @register_decomposition(aten.diagonal)
- def diagonal(
- self: TensorLikeType,
- offset: int = 0,
- dim1: int = 0,
- dim2: int = 1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.diagonal
- """
- num_dims = self.dim()
- dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
- dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
- check(
- dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
- )
- storage_offset = self.storage_offset()
- if offset >= 0:
- diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
- else:
- diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)
- if diag_size > 0:
- if offset >= 0:
- storage_offset += offset * self.stride()[dim2]
- else:
- storage_offset -= offset * self.stride()[dim1]
- sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
- sizes.append(diag_size)
- strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
- strides.append(self.stride()[dim1] + self.stride()[dim2])
- result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)
- return result
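- # Illustrative sketch (not part of the original source): the stride arithmetic above on a
- # contiguous 3x4 matrix (strides (4, 1)). The diagonal is read with stride 4 + 1 = 5 and a
- # positive offset shifts the storage offset by offset * stride[dim2]. Hypothetical demo
- # relying on the module-level `torch` import.
- def _diagonal_stride_demo():
-     a = torch.arange(12).reshape(3, 4)
-     d = a.diagonal(offset=1)               # tensor([1, 6, 11])
-     return d.stride(), d.storage_offset()  # ((5,), 1)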
- diagonal_copy = _make_copy_from_view(diagonal)
- @register_decomposition(aten.diag_embed)
- @out_wrapper()
- def diag_embed(
- t: TensorLikeType,
- offset: int = 0,
- dim1: int = -2,
- dim2: int = -1,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.diag_embed
- """
- # as per the docs, exchanging dims is equivalent to changing the sign of
- # offset
- if dim1 > dim2:
- dim1, dim2 = dim2, dim1
- offset = -offset
- # convert from negative dims
- rank = t.ndim + 1
- dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
- dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
- check(
- dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
- )
- # as per the docs, the size of last dim is placed at dim1 and dim2
- last_dim = t.size(-1)
- if offset != 0:
- # add padding to match the new size
- t_shape = list(t.shape)
- t_shape[-1] = builtins.abs(offset)
- z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
- pair = (z, t) if offset > 0 else (t, z)
- t = torch.cat(pair, dim=-1)
- # make sure the diagonal always has the same size
- last_dim += builtins.abs(offset)
- # preserve original data, but place 1 at dim1 and move last dim to dim2
- t = t.unsqueeze(dim1).movedim(-1, dim2)
- # generate ranges shifting indices based on offset
- a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
- b_range = torch.arange(
- offset, last_dim + offset, device=t.device, dtype=torch.int64
- )
- # broadcast
- cond = a_range == b_range.unsqueeze(-1)
- cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
- cond = cond.reshape(cond_shape)
- # aten.diag_embed always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(cond, t).contiguous()
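- # Illustrative sketch (not part of the original source): the offset handling above on a
- # small input. A positive offset pads with zeros so the values land above the main diagonal.
- # Hypothetical demo relying on the module-level `torch` import.
- def _diag_embed_demo():
-     return torch.diag_embed(torch.tensor([1, 2, 3]), offset=1)
-     # tensor([[0, 1, 0, 0],
-     #         [0, 0, 2, 0],
-     #         [0, 0, 0, 3],
-     #         [0, 0, 0, 0]])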
- # CompositeImplicitAutograd - don't register decomp
- def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
- if a.ndim < 3:
- raise RuntimeError(
- f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
- )
- if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
- raise RuntimeError(
- "torch.dsplit attempted to split along dimension 2, "
- + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
- )
- return tensor_split(a, sections, 2)
- @register_decomposition(aten.t.default)
- def t(a: TensorLikeType):
- # TODO: Add sparse support
- # if a.is_sparse:
- # sparse_dim = a.sparse_dim()
- # dense_dim = a.dense_dim()
- # if not (sparse_dim <= 2 and dense_dim == 0):
- # raise RuntimeError(
- # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
- # f"{dense_dim} dense dimensions"
- # )
- if a.ndim > 2:
- raise RuntimeError(
- f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
- )
- return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)
- # CompositeImplicitAutograd - don't register decomp
- def T(a: TensorLikeType) -> TensorLikeType:
- # n != 2 && n != 0 is deprecated in regular PyTorch.
- check(
- a.ndim in (0, 2),
- lambda: (
- "The use of `x.T` on tensors of dimension other than 0 or 2 "
- "to reverse their shape is not supported."
- ),
- )
- return a.t()
- @register_decomposition(aten.alias)
- def alias(a: TensorLikeType) -> TensorLikeType:
- return prims.view_of(a)
- @register_decomposition(aten.transpose)
- def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
- _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc]
- if a.ndim <= 1 or dim0 == dim1:
- return aten.alias.default(a)
- _permutation = list(range(0, a.ndim))
- _permutation[_dim0] = _dim1
- _permutation[_dim1] = _dim0
- return torch.permute(a, _permutation)
- # Aliases for transpose
- swap_axes = transpose
- @register_decomposition(aten.unfold)
- def unfold(
- self: TensorLikeType, dimension: int, size: int, step: int
- ) -> TensorLikeType:
- shape, strides = _get_unfold_shape_stride(
- self.shape, self.stride(), dimension, size, step
- )
- return self.as_strided(shape, strides)
- @register_decomposition(aten.unfold_copy)
- @out_wrapper()
- def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
- return self.unfold(dimension, size, step).clone(
- memory_format=torch.contiguous_format
- )
- @register_decomposition(aten.cumsum)
- def cumsum(
- a: TensorLikeType,
- dim: int,
- *,
- keepdim: bool = False,
- dtype: Optional[torch.dtype] = None,
- out: Optional[Tensor] = None,
- ) -> TensorLikeType:
- # We implement all the kwargs of a reduction. ATen just handles dtype
- # nb. This decomposition may not be as efficient as a backend-specific implementation
- ndim = a.ndim
- dim = utils.canonicalize_dim(ndim, dim)
- if ndim == 0:
- return sum(a.unsqueeze(0), dim=0, keepdim=keepdim, dtype=dtype, out=out)
- a = a.unsqueeze(dim + 1)
- rg = torch.arange(a.shape[dim], device=a.device)
- mask = rg.unsqueeze(1) <= rg
- for _ in range(ndim - dim - 1):
- mask = mask.unsqueeze(-1)
- masked_a = utils.mask_tensor(mask, a)
- return sum(masked_a, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
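- # Illustrative sketch (not part of the original source): the masked-sum formulation above
- # for a 1-D input: cumsum(a)[i] == sum over j <= i of a[j], i.e. a sum against a triangular
- # boolean mask. Hypothetical demo relying on the module-level `torch` import.
- def _cumsum_via_mask_demo(a):
-     n = a.shape[0]
-     rg = torch.arange(n, device=a.device)
-     mask = rg.unsqueeze(1) <= rg           # mask[j, i] is True when j <= i
-     return (mask * a.unsqueeze(1)).sum(0)  # matches torch.cumsum(a, 0)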
- # Note: although squeeze is documented as having the out= kwarg it doesn't
- @register_decomposition(aten.unsqueeze)
- def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
- # Note that unsqueeze canonicalizes with rank + 1 because it allows
- # a new innermost dimension to be specified
- ndim = a.ndim + 1
- dim = utils.canonicalize_dim(ndim, dim)
- return prims.expand_dims(a, (dim,), ndim=ndim)
- # NOTE: shape is a vararg because Tensor.view can be called as either
- # Tensor.view(a, b, c) or Tensor.view((a, b, c)); the function call
- # torch.view doesn't support unpacked shapes
- # TODO: Turn this into a decomposition (currently fails on reshape meta tests)
- @register_decomposition(aten.view)
- def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
- return _reshape_view_helper(a, *shape, allow_copy=False)
- # CompositeImplicitAutograd - don't register decomp
- def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
- return self.view(other.size())
- # CompositeImplicitAutograd - don't register decomp
- def ravel(a: TensorLikeType) -> TensorLikeType:
- return reshape(a, (-1,))
- @register_decomposition(aten.empty.memory_format)
- @out_wrapper()
- def empty(
- *shape,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- requires_grad: bool = False,
- pin_memory: bool = False,
- memory_format: torch.memory_format = torch.contiguous_format,
- ) -> TensorLikeType:
- check(
- memory_format != torch.preserve_format,
- lambda: "torch.empty: the Preserve memory format is not supported",
- )
- shape = utils.extract_shape_from_varargs(shape)
- if memory_format == torch.contiguous_format:
- strides = utils.make_contiguous_strides_for(shape)
- elif memory_format == torch.channels_last_3d:
- strides = utils.make_channels_last_3d_strides_for(shape)
- else: # memory_format == torch.channels_last
- check(
- memory_format == torch.channels_last,
- lambda: f"torch.empty: received an unknown memory format {memory_format}!",
- )
- strides = utils.make_channels_last_2d_strides_for(shape)
- return torch.empty_strided(
- shape,
- strides,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_empty)
- def new_empty(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.empty(
- size,
- dtype=dtype,
- device=device,
- pin_memory=pin_memory,
- layout=layout,
- )
- @register_decomposition(aten.new_empty_strided)
- def new_empty_strided(
- a: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- """
- Reference implementation of torch.Tensor.new_empty_strided
- """
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.empty_strided(
- size,
- stride,
- dtype=dtype,
- device=device,
- pin_memory=pin_memory,
- layout=layout,
- )
- @register_decomposition(aten.zeros.default)
- @out_wrapper()
- def zeros(
- *size,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- size = utils.extract_shape_from_varargs(size)
- if dtype is None:
- dtype = torch.get_default_dtype()
- return torch.full(
- size,
- False if dtype == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_zeros)
- def new_zeros(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- False if (dtype or a.dtype) == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.ones.default)
- @out_wrapper()
- def ones(
- *size,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- size = utils.extract_shape_from_varargs(size)
- if dtype is None:
- dtype = torch.get_default_dtype()
- return torch.full(
- size,
- True if dtype == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_ones)
- def new_ones(
- a: TensorLikeType,
- size: ShapeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- True if (dtype or a.dtype) == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.new_full)
- def new_full(
- a: TensorLikeType,
- size: ShapeType,
- fill_value: Union[int, float, bool],
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- return torch.full(
- size,
- fill_value,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- )
- @register_decomposition(aten.empty_like)
- def empty_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- layout: Optional[torch.layout] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- dtype = a.dtype if dtype is None else dtype
- layout = a.layout if layout is None else layout
- device = a.device if device is None else device
- strides: Tuple[int, ...]
- if memory_format != torch.preserve_format:
- return torch.empty(
- a.shape,
- dtype=dtype,
- layout=layout,
- device=device,
- requires_grad=requires_grad,
- pin_memory=pin_memory,
- memory_format=memory_format,
- )
- # memory_format == torch.preserve_format
- strides = utils.compute_elementwise_output_strides(a)
- return torch.empty_strided(
- a.shape,
- strides,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.arange)
- @out_wrapper()
- def arange(
- start: NumberType = 0,
- end: Optional[NumberType] = None,
- step: NumberType = 1,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- device = torch.device(utils.device_or_default(device))
- assert not isinstance(start, complex)
- assert not isinstance(end, complex)
- assert not isinstance(step, complex)
- # Case: torch.arange(5)
- if end is None:
- end = start
- start = 0
- utils.check(step != 0, lambda: "step must be nonzero")
- utils.check(
- (step > 0 and end >= start) or (step < 0 and end <= start),
- lambda: "upper bound and lower bound inconsistent with step sign",
- )
- def is_finite(x):
- return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
- utils.check(
- is_finite(start) and is_finite(end),
- lambda: f"unsupported range: {start} -> {end}",
- )
- utils.check(
- is_finite(step),
- lambda: f"step must be finite but got {step}",
- )
- if dtype is None:
- args = (start, end, step)
- integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
- dtype = torch.int64 if integer_args else torch.get_default_dtype()
- is_integer = utils.is_integer_dtype(dtype)
- if is_integer:
- xstart = sym_int(start)
- xend = sym_int(end)
- xstep = sym_int(step)
- # For int64 we truncate arguments to int before calculating length, but
- # other integral dtypes we don't. Weird... but needed to match ATen shapes.
- if dtype == torch.int64:
- length = math.ceil((xend - xstart) / xstep)
- else:
- length = math.ceil((end - start) / step)
- if is_integer:
- return prims.iota(
- length,
- start=xstart,
- step=xstep,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- computation_dtype = utils.get_acc_type(dtype, device)
- index = prims.iota(
- length,
- start=0,
- step=1,
- dtype=torch.int64,
- device=device,
- requires_grad=False,
- )
- index = _maybe_convert_to_dtype(index, computation_dtype)
- result = start + step * index
- result = _maybe_convert_to_dtype(result, dtype)
- if requires_grad:
- result.requires_grad_(True)
- return result
- @register_decomposition(aten.lerp)
- @out_wrapper()
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("start", "end", "weight"),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
- inputs = [start, end]
- if isinstance(weight, Number):
- weight = start.new_full((), weight) # type: ignore[arg-type]
- else:
- inputs.append(weight)
- assert isinstance(weight, Tensor) # mypy
- # We implement it this way for numerical stability. We assume (in the stability optimisation)
- # that 0 <= weight <= 1. We take the abs to deal with complex numbers
- # We want to perform operations near zero, which is where floating points are most precise
- # thus, we perform the following optimisation:
- # If weight.abs() >= 0.5:
- # return (1 - weight) * (start - end) + end
- mask = weight.abs() >= 0.5
- coeff = torch.where(mask, weight - 1, weight)
- base = torch.where(mask, end, start)
- output = coeff * (end - start) + base
- # make sure the decomposition output's stride is same as non-decomposition path.
- stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
- if output.stride() != stride:
- return prims.copy_strided(output, stride)
- return output
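- # Illustrative sketch (not part of the original source): the two branches above agree
- # algebraically, start + weight * (end - start) == end + (weight - 1) * (end - start);
- # switching at |weight| == 0.5 keeps the multiplied coefficient small, which reduces
- # rounding error when weight is close to 1. Hypothetical demo relying on `torch`.
- def _lerp_branches_demo(start, end, weight):
-     direct = start + weight * (end - start)
-     rewritten = end + (weight - 1) * (end - start)
-     return direct, rewritten  # equal up to floating-point rounding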
- @register_decomposition(aten.linspace)
- @out_wrapper()
- def linspace(
- start: NumberType,
- end: NumberType,
- steps: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- layout: torch.layout = torch.strided,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
- default_complex_dtype = utils.corresponding_complex_dtype(
- torch.get_default_dtype()
- )
- if dtype is None:
- dtype = default_complex_dtype
- else:
- check(
- utils.is_complex_dtype(dtype),
- lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
- )
- else:
- dtype = dtype or torch.get_default_dtype()
- assert isinstance(dtype, torch.dtype)
- # steps does not participate in the computation of the dtype
- check(
- isinstance(steps, IntLike),
- lambda: "steps must be int, not float",
- exc_type=TypeError,
- )
- assert isinstance(steps, IntLike) # for mypy
- check(steps >= 0, lambda: "number of steps must be non-negative")
- factory_kwargs = {
- "layout": layout,
- "device": device,
- "pin_memory": pin_memory,
- "requires_grad": requires_grad,
- }
- if steps == 0:
- return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
- if steps == 1:
- return torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
- if start == end:
- return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type]
- # Perform the arange in int because some backends like ATen or Triton do not support all the dtypes
- rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type]
- # Small types need to be computed in higher precision as this is, at heart, an associative scan
- dtype_red = (
- torch.int64
- if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
- else dtype
- )
- computation_dtype, _ = utils.reduction_dtypes(
- rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
- )
- cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)
- # We implement torch.lerp without performing rg / (steps - 1) explicitly
- # With this we get out[0] == start, out[-1] == end
- step = (end - start) / (steps - 1)
- out = torch.where(
- rg < steps / 2,
- start + step * cast_rg(rg), # type: ignore[arg-type,operator]
- end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator]
- )
- return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value]
- @register_decomposition(aten.logspace)
- @out_wrapper()
- def logspace(
- start: NumberType,
- end: NumberType,
- steps: NumberType,
- base: NumberType = 10,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- layout: torch.layout = torch.strided,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- if dtype is None:
- dtype = torch.get_default_dtype()
- # NB: NumPy doesn't have this cast
- if prims.utils.is_integer_dtype(dtype):
- if isinstance(start, FloatLike):
- start = sym_int(start)
- if isinstance(end, FloatLike):
- end = sym_int(end)
- assert not isinstance(base, complex) # for mypy
- if base < 0:
- raise NotImplementedError
- ret = torch.linspace(
- start,
- end,
- steps,
- dtype=torch.float64,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)
- @overload
- def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
- pass
- @overload
- def meshgrid(*tensors: TensorLikeType, indexing: str):
- pass
- @register_decomposition(aten.meshgrid)
- def meshgrid(
- *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
- indexing: str,
- ) -> List[TensorLikeType]:
- # This ref simultaneously handles two overloads (see stubs above)
- # The `indexing` argument is currently optional for torch.meshgrid, but we
- # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
- if isinstance(tensors[0], list) or isinstance(tensors[0], tuple):
- assert len(tensors) == 1
- tensors = tuple(tensors[0])
- check(
- py_all(isinstance(a, TensorLike) for a in tensors),
- lambda: "meshgrid expects its inputs to be tensors",
- )
- check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
- for i in range(len(tensors) - 1):
- check(
- tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr]
- lambda: "meshgrid expects all tensors to have the same dtype",
- )
- check(
- tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr]
- lambda: "meshgrid expects all tensors to have the same device",
- )
- swap_first_and_second_tensors = False
- if indexing == "xy":
- swap_first_and_second_tensors = len(tensors) >= 2
- if swap_first_and_second_tensors:
- tensors = (tensors[1], tensors[0], *tensors[2:])
- else:
- check(
- indexing == "ij",
- lambda: (
- 'torch.meshgrid: indexing must be one of "xy" or "ij", '
- f"but received: {indexing}"
- ),
- )
- result_shape: List[int] = []
- for t in tensors:
- assert isinstance(t, TensorLike) # mypy
- check(
- t.ndim == 0 or t.ndim == 1,
- lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
- )
- result_shape.append(t.numel())
- grids: List[TensorLikeType] = []
- for i, t in enumerate(tensors):
- assert isinstance(t, TensorLike) # mypy
- if t.ndim == 0:
- t = t.view((1,))
- grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))
- if swap_first_and_second_tensors:
- # Swap outputs if we originally swapped at the beginning
- grids[0], grids[1] = grids[1], grids[0]
- return grids
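- # Illustrative sketch (not part of the original source): the "xy" handling above just swaps
- # the first two inputs (and swaps the first two outputs back), so grids[0] varies along
- # columns and grids[1] along rows, matching NumPy's Cartesian convention. Hypothetical demo
- # relying on the module-level `torch` import.
- def _meshgrid_indexing_demo():
-     x = torch.tensor([1, 2, 3])
-     y = torch.tensor([4, 5])
-     gx_ij, gy_ij = torch.meshgrid(x, y, indexing="ij")  # shapes (3, 2), (3, 2)
-     gx_xy, gy_xy = torch.meshgrid(x, y, indexing="xy")  # shapes (2, 3), (2, 3)
-     return gx_ij.shape, gx_xy.shape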
- # CompositeImplicitAutograd - don't register decomp
- def movedim(
- input: TensorLikeType,
- source: Union[int, DimsSequenceType],
- destination: Union[int, DimsSequenceType],
- ) -> TensorLikeType:
- """
- Reference implementation of torch.movedim
- """
- if type(source) is int:
- source = (source,)
- if type(destination) is int:
- destination = (destination,)
- # Converts to list to produce a compatible error message with core PyTorch,
- # which prints sequences in square brackets.
- utils.check(
- len(source) == len(destination), # type: ignore[arg-type]
- lambda: (
- "movedim: Invalid source or destination dims: source " # type: ignore[arg-type]
- f"({list(source)} dims) should contain the same number " # type: ignore[arg-type]
- f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type]
- ),
- )
- rank = input.ndim
- ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type]
- ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type]
- sss = set(ss)
- dss = set(ds)
- # See above on why this converts to list in error messages.
- utils.check(
- len(ss) == len(sss),
- lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type]
- )
- utils.check(
- len(ds) == len(dss),
- lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type]
- )
- m = dict(zip(ds, ss))
- dims = []
- si = 0 # source index
- for di in range(rank):
- # check if the destination index is in the mapping
- s = m.get(di)
- if s is not None:
- # insert source index if found
- dims.append(s)
- else:
- # insert source index sequentially, skipping indices from the mapping
- while si in sss:
- si += 1
- dims.append(si)
- si += 1
- result = torch.permute(input, tuple(dims))
- return result
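- # Illustrative sketch (not part of the original source): the permutation built above for a
- # concrete call. Moving source dim 0 to destination dim 2 on a rank-3 tensor corresponds to
- # the permutation (1, 2, 0). Hypothetical demo relying on the module-level `torch` import.
- def _movedim_demo():
-     a = torch.empty(2, 3, 4)
-     return torch.movedim(a, 0, 2).shape  # torch.Size([3, 4, 2])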
- # NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
- @register_decomposition(aten.empty_strided)
- def empty_strided(
- shape: Union[ShapeType, Tuple[ShapeType]],
- strides: StrideType,
- *,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- layout: torch.layout = torch.strided,
- requires_grad: bool = False,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- # Layout == strided, pin_memory is False
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- shape = utils.extract_shape_from_varargs(shape)
- dtype = torch.get_default_dtype() if dtype is None else dtype
- device = torch.device("cpu") if device is None else device
- return prims.empty_strided(
- shape,
- strides,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- @register_decomposition(aten.eye)
- @out_wrapper()
- def eye(
- n: int,
- m: Optional[int] = None,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False, # TODO: unused
- ) -> TensorLikeType:
- """
- Reference implementation of torch.eye
- """
- if m is None:
- m = n
- check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
- check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
- range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
- range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
- cond = range_n.unsqueeze(-1) == range_m
- if dtype is torch.bool:
- return cond
- else:
- one = torch.ones(
- (1,),
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=False,
- )
- return torch.where(cond, one, 0)
- # TODO: Use requires_grad. All refs taking the requires_grad kwarg must
- # return a leaf tensor.
- # result.requires_grad_(requires_grad)
- @register_decomposition(aten.full)
- @out_wrapper()
- def full(
- shape: ShapeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
- device = device if device is not None else torch.device("cpu")
- e = empty(
- shape,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- return torch.fill(e, fill_value) # type: ignore[arg-type]
- def full_like(
- a: TensorLikeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- e = torch.empty_like(
- a,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- return fill(e, fill_value)
- @register_decomposition(aten.zeros_like)
- def zeros_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- return torch.full_like(
- a,
- False if (dtype or a.dtype) == torch.bool else 0,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- @register_decomposition(aten.ones_like)
- def ones_like(
- a: TensorLikeType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- memory_format: torch.memory_format = torch.preserve_format,
- ) -> TensorLikeType:
- return torch.full_like(
- a,
- True if (dtype or a.dtype) == torch.bool else 1,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- memory_format=memory_format,
- )
- @register_decomposition(aten.randn.default)
- @out_wrapper()
- def randn(
- *shape,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- layout: Optional[torch.layout] = None,
- requires_grad: bool = False,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- utils.check_pin_memory(pin_memory)
- shape_ = utils.extract_shape_from_varargs(shape)
- dtype = utils.dtype_or_default(dtype)
- device = utils.device_or_default(device)
- return prims.normal(
- shape_,
- mean=0.0,
- std=1.0,
- dtype=dtype,
- device=device,
- requires_grad=requires_grad,
- )
- def scalar_tensor(
- a: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: torch.layout = torch.strided,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- ) -> TensorLikeType:
- utils.check_layout(layout)
- utils.check_pin_memory(pin_memory)
- dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
- device = device if device is not None else torch.device("cpu")
- return prims.scalar_tensor(a, dtype=dtype, device=device)
- #
- # Randomness References
- #
- def _uniform_helper(
- shape: ShapeType,
- low: Union[bool, int, float] = 0.0,
- high: Union[bool, int, float] = 1.0,
- *,
- dtype: torch.dtype,
- device: DeviceLikeType,
- ) -> TensorLikeType:
- utils.validate_shape(shape)
- assert isinstance(low, Number)
- assert isinstance(high, Number)
- low = sym_float(low)
- high = sym_float(high)
- assert isinstance(dtype, torch.dtype)
- device = utils.canonicalize_device(device)
- return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)
- @register_decomposition(aten.masked_fill)
- def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
- python_type = utils.dtype_to_type(a.dtype)
- if isinstance(value, Number):
- value_type = type(value)
- else:
- # NOTE: Could not use value = item(value) as it resulted in
- # RuntimeError: Cannot cast FakeTensor(cpu) to number
- value_ndim = value.ndim
- check(
- value_ndim == 0,
- lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
- )
- # `masked_fill` allows cpu scalar to be moved to cuda but not otherwise.
- is_cpu_scalar = a.device.type == "cuda" and value.device.type == "cpu"
- check(
- is_cpu_scalar or value.device == a.device,
- lambda: "Expected `value` to be on same device as `a`",
- )
- value_type = utils.dtype_to_type(value.dtype)
- if value_type is complex:
- # only downcasting from complex to lower type is not allowed.
- # We allow casting `value` to lower type for other case
- # Eg. float -> int.
- # Ref: https://github.com/pytorch/pytorch/issues/79195
- check(
- utils.is_weakly_lesser_type(value_type, python_type),
- lambda: f"could not convert to type {python_type} without overflow",
- )
- # Since `where` allows type-promotion,
- # cast value to correct type before passing to `where`
- value = _maybe_convert_to_dtype(value, a.dtype)
- r = torch.where(mask, value, a) # type: ignore[arg-type]
- # aten.masked_fill always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return r.contiguous()
- @register_decomposition(aten.masked_fill_)
- def masked_fill_(
- a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
- ) -> TensorLikeType:
- b = torch.masked_fill(a, mask, value) # type: ignore[arg-type]
- a.copy_(b)
- return a
- # CompositeImplicitAutograd - don't register decomp
- def allclose(
- a: TensorLikeType,
- b: TensorLikeType,
- rtol: float = 1e-05,
- atol: float = 1e-08,
- equal_nan: bool = False,
- ) -> bool:
- """
- Reference implementation of torch.allclose
- """
- _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
- return bool(
- torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
- )
- # TODO: add OpInfo for torch.equal and refs.equal
- def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
- utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
- utils.check_same_dtype(a, b)
- # Shape check
- if a.ndim != b.ndim:
- return False
- for x, y in zip(a.shape, b.shape):
- if x != y:
- return False
- # Short-circuits if there are no elements to validate
- if a.numel() == 0:
- return True
- return item(all(eq(a, b))) # type: ignore[return-value]
- @register_decomposition(aten.norm)
- @out_wrapper(exact_dtype=True)
- def norm(
- input: TensorLikeType,
- p: Optional[Union[float, str]] = "fro",
- dim: Optional[DimsType] = None,
- keepdim: bool = False,
- *,
- dtype: Optional[torch.dtype] = None,
- ) -> TensorLikeType:
- # In these cases we compute the "Frobenius norm"
- if (
- p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
- ) or p is None:
- p = 2
- if isinstance(dim, Dim):
- dim = [dim]
- if isinstance(p, str):
- # Here we either call the nuclear norm, or we call matrix_norm with some arguments
- # that will throw an error
- if dim is None:
- dim = tuple(range(input.ndim))
- return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
- else:
- return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)
- @register_decomposition(aten.trace)
- def trace(self: TensorLikeType) -> TensorLikeType:
- utils.check(
- self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
- )
- return torch.sum(torch.diag(self, 0))
- def _make_r_binary_op(base_op):
- def rop(
- a: Union[TensorLikeType, NumberType],
- b: Union[TensorLikeType, NumberType],
- ) -> TensorLikeType:
- return base_op(b, a)
- return rop
- rtruediv = _make_r_binary_op(true_divide)
- rfloordiv = _make_r_binary_op(floor_divide)
- rpow = _make_r_binary_op(pow)
- @register_decomposition(aten.triu)
- @out_wrapper()
- def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- utils.check(
- a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
- )
- h, w = a.shape[-2:]
- mask = (
- torch.arange(w, device=a.device).unsqueeze(-2)
- - torch.arange(h, device=a.device).unsqueeze(-1)
- ) >= diagonal
- # aten.triu always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(mask, a).contiguous()
- @register_decomposition(aten.tril)
- @out_wrapper()
- def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
- utils.check(
- a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
- )
- h, w = a.shape[-2:]
- mask = (
- torch.arange(w, device=a.device).unsqueeze(-2)
- - torch.arange(h, device=a.device).unsqueeze(-1)
- ) <= diagonal
- # aten.tril always returns a new contiguous tensor
- # contiguous() is needed to correctly model the output stride
- return utils.mask_tensor(mask, a).contiguous()
- # This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
- # The components of the matrix that belong to the lower triangle with offset
- # form a pentagon that can be broken down into a top trapezoid and a bottom
- # rectangle. For the implementation of tril_indices, we need the sizes of
- # both of these, as well as the length of the top side of the trapezoid.
- def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
- if row == 0 or col == 0:
- return 0, 0, 0
- m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
- m_last_row = max(0, min(col, row + offset))
- n_row_all = max(0, min(row, row + offset))
- n_row_trapezoid = m_last_row - m_first_row + 1
- # Number of elements in top trapezoid
- trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
- # Number of elements in bottom rectangle
- diff_row = n_row_all - n_row_trapezoid
- rectangle_size = max(0, diff_row * col)
- return trapezoid_size, rectangle_size, m_first_row
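- # Illustrative sketch (not part of the original source): the decomposition above worked out
- # for a 4x3 matrix with offset 0. The lower triangle has rows of length 1, 2, 3 (the top
- # trapezoid, 6 elements, first row of length 1) plus one full row of 3 (the bottom
- # rectangle), so the helper returns (6, 3, 1) and 6 + 3 == 9 elements in total.
- def _tril_sizes_demo():
-     return _get_tril_sizes(4, 3, 0)  # (6, 3, 1)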
- def _trilu_checks(
- name: str,
- row: int,
- col: int,
- dtype: torch.dtype,
- layout: torch.layout,
- pin_memory: bool,
- ):
- check(row >= 0, lambda: f"row must be non-negative, got {row}")
- check(col >= 0, lambda: f"col must be non-negative, got {col}")
- check(
- dtype in (torch.int32, torch.int64),
- lambda: f"\"{name}\" not implemented for '{dtype}'",
- )
- # This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
- @register_decomposition(aten.tril_indices)
- def tril_indices(
- row: int,
- col: int,
- offset: int = 0,
- *,
- dtype: torch.dtype = torch.long,
- layout: torch.layout = torch.strided,
- device: DeviceLikeType = "cpu",
- pin_memory: bool = False,
- ) -> TensorLikeType:
- _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)
- trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
- row_offset = max(0, -offset)
- arange_kw = partial(
- torch.arange, layout=layout, device=device, pin_memory=pin_memory
- )
- # first we do the indices for top trapezoid
- xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
- b = m_first_row - 0.5
- row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
- col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
- row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
- col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
- # then bottom rectangle
- xs2 = arange_kw(0, rectangle_size, dtype=dtype)
- row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
- col_inds2 = xs2 % col
- return torch.stack(
- (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
- )
- # Similar to _get_tril_sizes above, but here there is a top rectangle and
- # a bottom trapezoid instead. Note that you can't reduce this to
- # _get_tril_sizes(col, row, -offset) because that would correspond to
- # decomposing into a left trapezoid and right rectangle.
- def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
- if row == 0 or col == 0:
- return 0, 0, 0
- m_first_row = max(0, col - offset) if offset > 0 else col
- # Number of elements in top rectangle
- rectangle_size = max(0, min(row, -offset) * col)
- # Number of elements in bottom trapezoid
- trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
- triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
- trapezoid_size = triu_size - rectangle_size
- return trapezoid_size, rectangle_size, m_first_row
- @register_decomposition(aten.triu_indices)
- def triu_indices(
- row: int,
- col: int,
- offset: int = 0,
- *,
- dtype: torch.dtype = torch.long,
- layout: torch.layout = torch.strided,
- device: DeviceLikeType = "cpu",
- pin_memory: bool = False,
- ) -> TensorLikeType:
- _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)
- trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
- col_offset = max(0, offset)
- arange_kw = partial(
- torch.arange, layout=layout, device=device, pin_memory=pin_memory
- )
- # indices for top rectangle
- xs2 = arange_kw(0, rectangle_size, dtype=dtype)
- row_inds2 = xs2 // col
- col_inds2 = xs2 % col
- # bottom trapezoid
- xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
- b = -0.5 - m_first_row
- row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
- col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
- row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
- col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
- if col:
- row_inds1 = row_inds1 + (rectangle_size // col)
- col_inds1 = col_inds1 + col_offset
- return torch.stack(
- (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
- )
- @register_decomposition(aten.bucketize)
- @out_wrapper(exact_dtype=True)
- def bucketize(
- a: TensorLikeType,
- boundaries: TensorLikeType,
- *,
- out_int32: bool = False,
- right: bool = False,
- ):
- utils.check(
- boundaries.dim() == 1,
- lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
- )
- out_dtype = torch.int32 if out_int32 else torch.int64
- n_boundaries = boundaries.shape[-1]
- if n_boundaries == 0:
- return torch.zeros_like(a)
- # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
- # each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
- # but each step of the search is done "in parallel" over all elements of `a`
- # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end
- start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
- end = start + n_boundaries
- # Max depth of the binary search
- # Since we can't break out of the loop at different points for different elements of a,
- # we just do the max amount of iterations that binary search requires and add condition
- # tensor (cond_update below) to stop updating once the search terminates
- # For the first iteration through the loop we can skip some checks, so we have a separate implementation
- mid = start + (end - start) // 2
- mid_val = boundaries[mid]
- if right:
- cond_mid = mid_val > a
- else:
- cond_mid = mid_val >= a
- start = torch.where(cond_mid, start, mid + 1)
- if n_boundaries > 1:
- cond_update = torch.ones_like(a, dtype=torch.bool)
- niters = int(math.log2(n_boundaries))
- for _ in range(niters):
- end = torch.where(cond_mid & cond_update, mid, end)
- cond_update = start < end
- # start might end up pointing to 1 past the end, we guard against that
- mid = torch.where(cond_update, start + (end - start) // 2, 0)
- mid_val = boundaries[mid]
- # If right is true, the buckets are closed on the *left*
- # (i.e., we are doing the equivalent of std::upper_bound in C++)
- # Otherwise they are closed on the right (std::lower_bound)
- if right:
- cond_mid = mid_val > a
- else:
- cond_mid = mid_val >= a
- start = torch.where((~cond_mid) & cond_update, mid + 1, start)
- return start.to(dtype=out_dtype)
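
# Illustrative sketch, not part of the original file: the vectorized binary
# search above should agree with the eager torch.bucketize kernel, both for
# buckets closed on the right (default) and closed on the left (right=True).
# The helper name and sample values below are hypothetical.
def _example_check_bucketize() -> None:
    boundaries = torch.tensor([1.0, 3.0, 5.0, 7.0])
    a = torch.tensor([0.5, 1.0, 4.2, 7.0, 9.9])
    for right in (False, True):
        torch.testing.assert_close(
            bucketize(a, boundaries, right=right),
            torch.bucketize(a, boundaries, right=right),
        )
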
@register_decomposition(aten.cauchy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def cauchy(self, median=0, sigma=1, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Cauchy distribution is a continuous probability distribution. \
            dtype must be a floating point but you specified {self.dtype}",
    )
    utils.check(
        sigma > 0.0,
        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
    )
    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))
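
# Illustrative sketch, not part of the original file: cauchy() above is
# inverse-transform sampling, i.e. it applies the Cauchy inverse CDF
# F^-1(u) = median + sigma * tan(pi * (u - 0.5)) to uniform draws. The
# hypothetical check below round-trips a grid of u values through the CDF
# F(x) = 0.5 + atan((x - median) / sigma) / pi.
def _example_check_cauchy_inverse_cdf(median: float = 0.0, sigma: float = 1.0) -> None:
    u = torch.linspace(0.01, 0.99, steps=25, dtype=torch.float64)
    x = median + sigma * torch.tan(math.pi * (u - 0.5))  # same formula as cauchy()
    cdf = 0.5 + torch.atan((x - median) / sigma) / math.pi
    torch.testing.assert_close(cdf, u)
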
@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Exponential distribution is a continuous probability distribution. \
            dtype must be a floating point but you specified {self.dtype}",
    )
    utils.check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )
    return -1 / rate * torch.log1p(-torch.rand_like(self))
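
# Illustrative sketch, not part of the original file: exponential() above uses
# the same inverse-transform pattern, x = -log(1 - u) / rate. The hypothetical
# check below verifies the round trip through the CDF F(x) = 1 - exp(-rate * x).
def _example_check_exponential_inverse_cdf(rate: float = 2.0) -> None:
    u = torch.linspace(0.01, 0.99, steps=25, dtype=torch.float64)
    x = -1 / rate * torch.log1p(-u)  # same formula as exponential()
    torch.testing.assert_close(1 - torch.exp(-rate * x), u)
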
@register_decomposition(aten.geometric)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def geometric(self, p, generator=None):
    assert generator is None
    # TODO: fix inductor rand_like for integer, bool dtypes
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"geometric not implemented for {self.dtype}",
    )
    utils.check(
        0 < p and p < 1,
        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
    )
    return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1
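
# Illustrative sketch, not part of the original file: geometric() above maps a
# uniform draw u to the smallest k >= 1 such that 1 - (1 - p)**k > u. The
# hypothetical check below confirms that any u < p lands in the first bucket,
# i.e. produces k == 1.
def _example_check_geometric_first_bucket(p: float = 0.3) -> None:
    u = torch.tensor(p / 2, dtype=torch.float64)
    k = torch.floor(torch.log1p(-u) / math.log1p(-p)) + 1  # same formula as geometric()
    assert k.item() == 1.0
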
@register_decomposition(aten.log_normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def log_normal(self, mean=1, std=2, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"log_normal not implemented for {self.dtype}",
    )
    utils.check(
        0 < std,
        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
    )
    return torch.exp(std * torch.randn_like(self) + mean)
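
# Illustrative sketch, not part of the original file: log_normal() above
# exponentiates an affine transform of standard-normal noise, so a draw of
# z == 0 (the normal's median) maps to the log-normal median exp(mean). The
# helper below is hypothetical and only demonstrates that relationship.
def _example_check_log_normal_median(mean: float = 1.0, std: float = 2.0) -> None:
    z = torch.tensor(0.0, dtype=torch.float64)
    x = torch.exp(std * z + mean)  # same formula as log_normal()
    torch.testing.assert_close(x, torch.tensor(math.exp(mean), dtype=torch.float64))
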
# inplace
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)
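
# Illustrative sketch, not part of the original file: _make_inplace is defined
# earlier in this module and may do more bookkeeping, but the core pattern it
# wraps is "compute the out-of-place result, then copy it back into the first
# argument", roughly like the hypothetical version below.
def _example_make_inplace_sketch(fn):
    def inplace(a, *args, **kwargs):
        # write the out-of-place result back into `a` and return `a`
        return a.copy_(fn(a, *args, **kwargs))

    return inplace
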
# Views
# We can't model these as above, as the pattern of doing `op(a, out=a)` does not
# work for a view function given that it does not reshape the input (it just
# copies the result into it)
# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)

import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special