import builtins
import collections
import math
import operator
import warnings

from collections.abc import Iterable
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import Callable, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch import sym_float, sym_int
from torch._prims_common import (
    check,
    DeviceLikeType,
    Dim,
    DimsSequenceType,
    DimsType,
    dtype_to_type,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    FloatLike,
    FloatWithoutSymFloat,
    IntLike,
    is_weakly_lesser_type,
    Number,
    NumberType,
    REDUCTION_OUTPUT_TYPE_KIND,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    TensorOrNumberLikeType,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)

# Experimental module containing prototype Python references for existing
# PyTorch operations.

__all__ = [
    #
    # Elementwise Unary References
    #
    "abs",
    "acos",
    "acosh",
    "asinh",
    "asin",
    "atan",
    "atanh",
    "bitwise_not",
    # "cbrt",  # No corresponding torch operation
    "ceil",
    "conj_physical",
    "cos",
    "cosh",
    "digamma",
    "erf",
    "erfinv",
    "erfc",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "frac",
    "index_add",
    "index_copy",
    "index_copy_",
    "index_select",
    "index_fill",
    "index_fill_",
    "isfinite",
    "isinf",
    "isposinf",
    "isneginf",
    "isnan",
    "isreal",
    "i0",
    "lerp",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "log_softmax",
    "nan_to_num",
    "neg",
    "positive",
    "reciprocal",
    "round",  # TODO: model kwargs
    "sigmoid",
    "sgn",
    "sign",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "softmax",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trace",
    "trunc",
    #
    # Elementwise Binary References
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_left_shift",
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    "clamp_min",
    "clamp_max",
    "copysign",
    "div",
    "eq",
    "float_power",
    "floor_divide",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "heaviside",
    "hypot",
    "igamma",
    "igammac",
    "imag",
    "isclose",
    "lcm",
    # "ldexp",
    "le",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "lt",
    # "max",  # implement with reductions
    "maximum",
    # "min",  # implement with reductions
    "minimum",
    "mul",
    "ne",
    "nextafter",
    # "polar",  # abs, cos, sin
    "pow",
    "real",
    "rpow",
    "remainder",
    "rsub",
    "rtruediv",
    "rfloordiv",
    "sub",
    "true_divide",
    "trunc_divide",
    "xlogy",
    #
    # Elementwise Ternary References
    #
    "addcdiv",
    "addcmul",
    "clamp",
    #
    # Conditional references
    #
    "masked_fill",
    "where",
    #
    # Data conversion and movement references
    #
    "clone",
    "copy_to",  # TODO: add OpInfo (or implement .to)
    "item",  # TODO: add OpInfo
    "to",
    #
    # Reduction ops
    #
    "all",
    "amax",
    "amin",
    "any",
    "mean",
    "std",
    "std_mean",
    "sum",
    "sum_to_size",
    "prod",
    "var",
    "var_mean",
    #
    # Linear algebra ops
    #
    "addr",
    #
    # View & Shape Ops
    #
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "as_strided",
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "cat",
    "chunk",
    "column_stack",
    "conj",
    "constant_pad_nd",
    "contiguous",
    "diag_embed",
    "diag",
    "diagonal",
    "diagonal_copy",
    "diagonal_scatter",
    "dsplit",
    "dstack",
    "expand",
    "expand_as",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "meshgrid",
    "movedim",
    "narrow",
    "narrow_copy",
    "native_group_norm",
    "native_layer_norm",
    "permute",
    "ravel",
    "repeat",
    "reshape",
    "roll",
    "rot90",
    "rsqrt",
    "stack",
    "swap_axes",  # alias for transpose
    "squeeze",
    "t",
    "T",
    "tensor_split",
    "transpose",
    "unfold",
    "unfold_copy",
    "unsqueeze",
    "view",
    "vsplit",
    "vstack",
    "unflatten",
    "unbind",
    "triu",
    "tril",
    "triu_indices",
    "tril_indices",
    #
    # Tensor Creation
    #
    "arange",
    "empty",
    "empty_like",
    "empty_strided",
    "eye",
    "full",
    "full_like",
    "linspace",
    "logspace",
    "ones",
    "ones_like",
    "randn",
    "scalar_tensor",
    "zeros",
    "zeros_like",
    #
    # Test-related functions
    #
    "allclose",
    "equal",  # TODO: add OpInfo
    #
    # Statistical operations
    #
    "bucketize",
]

Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
aten = torch._ops.ops.aten


def _broadcast_shapes(*_shapes):
    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x
        for x in filter(lambda x: x is not None, _shapes)
    )

    # Short-circuits on no input
    if len(shapes) == 0:
        return None

    # Type checking
    # TODO: make common validations available as utils
    for shape in shapes:
        assert isinstance(shape, Sequence)

    # Computes common shape
    common_shape = [
        1,
    ] * reduce(max, (len(shape) for shape in shapes))
    for arg_idx, shape in enumerate(shapes):
        for idx in range(-1, -1 - len(shape), -1):
            if common_shape[idx] == 1:
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common_shape[idx] = shape[idx]
            elif shape[idx] != 1:
                if common_shape[idx] != shape[idx]:
                    raise RuntimeError(
                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
                        f"should be broadcastable to {common_shape}"
                    )

    return common_shape
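

# A quick illustration of the common-shape computation above (hypothetical
# values, shown doctest-style):
# >>> _broadcast_shapes((3, 1), (1, 4))
# [3, 4]
# >>> _broadcast_shapes(5, (3, 1))  # ints are treated as 1-D shapes
# [3, 5]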


def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Computes common shape
    common_shape = _broadcast_shapes(
        *map(lambda t: t.shape if isinstance(t, TensorLike) else None, args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x
            if not utils.same_shape(x.shape, common_shape):
                return x.expand(common_shape)
            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)
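

# For example (hypothetical shapes), a (3, 1) tensor broadcast against a
# (1, 4) tensor yields two (3, 4) views, while Python scalars and CPU scalar
# tensors pass through unchanged:
# >>> x, y = torch.ones(3, 1), torch.ones(1, 4)
# >>> [t.shape for t in _maybe_broadcast(x, y)]
# [torch.Size([3, 4]), torch.Size([3, 4])]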


# Utilities should come BEFORE this import
from torch._decomp import register_decomposition

#
# Elementwise unary references
#

infer_aten_op = object()


# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    extra_meta=None,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper()
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if extra_meta is not None:
                extra_meta(a)
            return prim(a)

        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


def _make_alias(fn, name):
    """
    This function defines an alias of another function and sets its `__name__` attribute.
    Note that when naïvely doing `alias = fn`, we have that `alias.__name__ == "fn"`.
    """

    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)

    _fn.__name__ = name
    return _fn
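

# For instance (hypothetical alias), the wrapper keeps its own name, which
# matters for error messages and aten-op inference:
# >>> sin_alias = _make_alias(sin, "sin_alias")
# >>> sin_alias.__name__
# 'sin_alias'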


def _make_inplace(fn):
    """
    Given a function with an out variant (i.e. decorated with `out_wrapper()`), it returns its in-place variant.
    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)

    inplace_name = f"{fn.__name__}_"
    _fn.__name__ = inplace_name
    _fn = register_decomposition(getattr(aten, inplace_name))(_fn)

    # We access the __all__ attribute of the module where fn is defined.
    # There may be a cleaner way of doing this...
    from inspect import getmodule

    _all = getmodule(fn).__all__  # type: ignore[union-attr]
    if inplace_name not in _all:
        _all.append(inplace_name)
    return _fn
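

# Sketch of intended usage (assuming the reference has an out= variant and
# aten exposes a matching in-place overload, e.g. aten.exp_):
# >>> exp_ = _make_inplace(exp)
# >>> exp_.__name__
# 'exp_'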


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
def abs(a):
    return prims.abs(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    return prims.acos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    return prims.acosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    return prims.asin(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asinh(a):
    return prims.asinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    return prims.atan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atanh(a):
    return prims.atanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    return prims.bitwise_not(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def ceil(a):
    return prims.ceil(a)


@register_decomposition(aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    if not utils.is_complex_dtype(input.dtype):
        return input
    return prims.conj_physical(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    return prims.cos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    return prims.cosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    return prims.digamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    return prims.erf(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    return prims.erf_inv(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    return prims.erfc(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    return prims.exp(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    return prims.expm1(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    return prims.exp2(a)


# Fill has its own implementation because it has a value parameter
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = "value argument of type {0} cannot be safely cast to type {1}!".format(
            type(value), python_type
        )
        raise ValueError(msg)

    return prims.fill(a, value)


def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    r = prims.fill(a, value)
    prims.copy_to(a, r)
    return a


@register_decomposition(aten.zero)
@out_wrapper()
def zero(input: TensorLikeType) -> TensorLikeType:
    return torch.zeros_like(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def floor(a):
    return prims.floor(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def frac(x: TensorLikeType) -> TensorLikeType:
    trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
    return torch.sub(x, trunc_x)


# imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.check(
        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
    )
    return prims.imag(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        return prims.isfinite(a)

    return ones_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
    if utils.is_float_dtype(a.dtype):
        return torch.abs(a) == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    utils.check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    utils.check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("-inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, a)


# alias
mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.imag(a) == 0
    return torch.ones_like(a, dtype=torch.bool)


# TODO: if this is special maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.special_i0
)
def i0(a):
    return prims.bessel_i0(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    return prims.lgamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    return prims.log(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    return prims.log1p(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    return prims.log2(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    return prims.log10(a)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]


@register_decomposition(aten.logsumexp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logsumexp(
    self: TensorLikeType, dim: DimsType, keepdim: bool = False
) -> TensorLikeType:
    if not isinstance(dim, Iterable):
        dim = (dim,)
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
    maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
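

# The subtraction of the (finite) row maximum above is the standard
# log-sum-exp stabilization: the argument of exp(self - maxes) is <= 0, so it
# cannot overflow, and the maximum is added back in log space. An
# illustration with hypothetical values:
# >>> x = torch.tensor([1000.0, 1000.0])
# >>> torch.exp(x).sum().log()  # naive evaluation overflows
# tensor(inf)
# >>> logsumexp(x, dim=0)  # stabilized: 1000 + log(2)
# tensor(1000.6931)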


@register_decomposition(aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return a.clone()

    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = torch.finfo(a.dtype).max

    if neginf is None:
        neginf = torch.finfo(a.dtype).min

    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
    return result
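

# Example (hypothetical values): NaN maps to 0.0 by default, and infinities
# map to the dtype's finite extremes unless overridden:
# >>> t = torch.tensor([float("nan"), float("inf"), float("-inf")])
# >>> nan_to_num(t, posinf=1.0)
# tensor([ 0.0000e+00,  1.0000e+00, -3.4028e+38])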


def _neg_meta(a: TensorLikeType):
    check(
        a.dtype is not torch.bool,
        lambda: (
            "Negation, the `-` operator, on a bool tensor is not supported. "
            "If you are trying to invert a mask, use the `~` or `logical_not()` "
            "operator instead."
        ),
    )


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    return prims.neg(a)


# positive does not use _make_elementwise_unary_reference because it does not support out
# CompositeImplicitAutograd - don't register decomp
def positive(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if a.dtype is torch.bool:
        msg = "positive does not support bool tensors."
        raise RuntimeError(msg)
    return a


# real does not use _make_elementwise_unary_reference because it does not support out
def real(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if utils.is_complex_dtype(a.dtype):
        return prims.real(a)
    return a


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    return prims.reciprocal(a)


# TODO: round takes additional kwargs
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # TODO: this does need a decomp, but kwarg handling is needed
)
def round(a):
    return prims.round(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
    return prims.rsqrt(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    return true_divide(1, add(1, exp(neg(a))))


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sgn(a):
    if utils.is_complex_dtype(a.dtype):
        a_abs = a.abs()
        return torch.where(a_abs == 0, 0, a / a_abs)
    else:
        return a.sign()


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sign(a):
    return prims.sign(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def signbit(a):
    return prims.signbit(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    return prims.sin(a)


# Autograd note: This will give the right first derivative at zero (by chance),
# but not the right second derivative
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinc(a):
    a = math.pi * a
    return torch.where(a == 0, 1, torch.sin(a) / a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    return prims.sinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    return prims.sqrt(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd,
)
def square(a: TensorLikeType) -> TensorLikeType:
    return mul(a, a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    return prims.tan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
    return prims.trunc(a)


def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        @wraps(prim)
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            check(
                supports_lhs_python_scalar or not isinstance(a, Number),
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
                ValueError,
            )
            check(
                supports_rhs_python_scalar or not isinstance(b, Number),
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
                ValueError,
            )
            check(
                supports_two_python_scalars
                or not (isinstance(a, Number) and isinstance(b, Number)),
                lambda: f"{name}: Received two Number inputs to an elementwise binary operation!",
                ValueError,
            )
            a, b = _maybe_broadcast(a, b)
            return prim(a, b)

        if has_out:
            _ref = out_wrapper()(_ref)

        _ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if aten_op is not None:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner
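

# A minimal sketch of how this factory is used (my_op is hypothetical): the
# decorated function only supplies the prim call, while type promotion,
# broadcasting, scalar validation, out= support, and decomposition
# registration all come from the wrapper:
#
#   @_make_elementwise_binary_reference(
#       type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   )
#   def my_op(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
#       return prims.my_op(a, b)  # hypothetical prim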


# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if python_type != bool and not utils.is_weakly_lesser_type(
            type(alpha), python_type
        ):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.add(a, b)
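

# Example (hypothetical values): with alpha, b is scaled before the addition,
# i.e. add(a, b, alpha=s) computes a + s * b:
# >>> add(torch.tensor([1.0]), torch.tensor([2.0]), alpha=10)
# tensor([21.])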


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def atan2(a, b):
    return prims.atan2(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_and(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_left(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_or(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_right_arithmetic(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_xor(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
)
def copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    if isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
            a.device, b.device
        )
        raise RuntimeError(msg)
    return where(signbit(b), neg(abs(a)), abs(a))


# TODO: add docstring
# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)


@register_decomposition(aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div
    """
    if rounding_mode is None:
        return true_divide(a, b)
    elif rounding_mode == "trunc":
        return trunc_divide(a, b)
    elif rounding_mode == "floor":
        return floor_divide(a, b)
    else:
        msg = (
            "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
            "but found {0}.".format(rounding_mode)
        )
        raise ValueError(msg)
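

# Example (hypothetical values) contrasting the three rounding modes:
# >>> a, b = torch.tensor([-7.0]), torch.tensor([2.0])
# >>> div(a, b)  # true division
# tensor([-3.5000])
# >>> div(a, b, rounding_mode="trunc")  # round toward zero
# tensor([-3.])
# >>> div(a, b, rounding_mode="floor")  # round toward negative infinity
# tensor([-4.])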


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.eq(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
def pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        if b == 1.0:
            return a.clone()  # type: ignore[return-value,union-attr]
        elif b == 2.0:
            return a * a  # type: ignore[return-value]
        elif b == 0.5:
            return torch.sqrt(a)  # type: ignore[arg-type]
    elif isinstance(a, Number):
        if a == 1.0:
            return torch.fill(b, True)
        if a == 2.0 and (
            utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
        ):
            return torch.exp2(b)

    return prims.pow(a, b)
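

# The scalar fast paths above avoid a full pow evaluation; for example
# (hypothetical values), an exponent of 2.0 dispatches to a * a:
# >>> pow(torch.tensor([3.0]), 2.0)
# tensor([9.])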


# TODO: add docstring
# Float power has its own implementation because it has unique type promotion.
# NB: aten_op not registered because CompositeExplicitAutograd
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Received two Number inputs to an elementwise binary operation!"
        )

    # Handles type promotion
    dtype = utils.get_higher_dtype(a, b)
    assert dtype is not None
    if utils.is_complex_dtype(dtype):
        dtype = torch.complex128
    else:
        dtype = torch.float64

    # Float power has the following contiguous cast behavior to be
    # consistent with its C++ impl
    a = _maybe_convert_to_dtype(a, dtype)
    b = _maybe_convert_to_dtype(b, dtype)

    a, b = _maybe_broadcast(a, b)
    return pow(a, b)


# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
# tensor(-0.250000000000000, dtype=torch.float64)
#
# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
# tensor(-0.001000000000000, dtype=torch.float64)
#
# Note: In this case, casting float to double will expand the float mantissa with zeros,
# while creating a double generates a distinct mantissa.
# >>> torch.tensor(-0.001).to(dtype=torch.float64)
# tensor(-0.001000000047497, dtype=torch.float64)
#
# Floor Division
# The difference is caused because torch.remainder(a, b) = -0.001.
#
# >>> torch.floor(torch.true_divide(a, b))
# tensor(250., dtype=torch.float64)
#
# >>> torch.div(a, b, rounding_mode='floor')
# tensor(249., dtype=torch.float64)
#
# Definition: a // b = (a - remainder(a, b)) / b
# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
# tensor(249., dtype=torch.float64)
#
# For reference, see CPython's implementation:
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
)
def floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    # Wrap scalars because some references only accept tensor arguments.
    if isinstance(a, Number) and isinstance(b, Number):
        a = scalar_tensor(a)
        b = scalar_tensor(b)
    elif isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Number) and isinstance(b, Tensor):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        if a.device == torch.device("cpu"):
            msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
                a.device, b.device
            )
            raise RuntimeError(msg)
        else:
            b = prims.device_put(b, device=a.device)

    assert isinstance(a, Tensor) and isinstance(b, Tensor)
    dtype = a.dtype
    if utils.is_float_dtype(dtype):
        return _floor_divide_float(a, b)
    elif utils.is_integer_dtype(dtype):
        return _floor_divide_integer(a, b)
    else:
        check(False, lambda: f"{dtype} not supported for floor_divide")


def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
    a, b = _maybe_broadcast(a, b)

    if not a.dtype.is_signed:
        return prims.div(a, b)

    # Convert truncation to flooring:
    offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
    return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
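

# Worked example (hypothetical values): prims.div truncates toward zero, so
# for -7 and 2 it yields -3; the operands' signs differ and fmod(-7, 2) = -1
# is nonzero, so the offset of 1 corrects the result to the floored -4:
# >>> _floor_divide_integer(torch.tensor([-7]), torch.tensor([2]))
# tensor([-4])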


def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
    mod = fmod(a, b)
    div = true_divide(sub(a, mod), b)

    # Ensure that the remainder has the same sign as denominator
    different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
    non_zero_remainder = ne(mod, 0)
    mask = bitwise_and(non_zero_remainder, different_signed_inputs)
    div = where(mask, sub(div, 1), div)

    # Map quotient to nearest integer value
    floor_div = floor(div)
    mask = gt(sub(div, floor_div), 0.5)
    floor_div = where(mask, add(floor_div, 1), floor_div)

    basic_div = true_divide(a, b)
    zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)

    # If quotient is zero, copy signbit from true_divide quotient
    floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))

    # If denominator is zero, then follow true_divide behavior
    return where(ne(b, 0), floor_div, basic_div)
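

# Worked example reusing the values from the comment block above: the
# remainder-based definition gives the CPython-compatible answer rather than
# the naive floor(true_divide(a, b)):
# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
# >>> _floor_divide_float(a, b)
# tensor(249., dtype=torch.float64)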

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmax(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmin(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=True,
)
def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmod(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gcd(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ge(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gt(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
    input_eq_zero = torch.eq(input, 0)
    input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
    zeros_and_ones = torch.where(input_lt_zero, 0, 1)
    output = torch.where(input_eq_zero, values, zeros_and_ones)
    return output
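

# Illustrative example of heaviside (added note; mirrors the torch.heaviside
# documentation example): negative input maps to 0, zero input takes `values`,
# positive input maps to 1:
#
# >>> torch.heaviside(torch.tensor([-1.5, 0, 2.0]), torch.tensor([0.5]))
# tensor([0.0000, 0.5000, 1.0000])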

@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.hypot(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igamma(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igammac(a, b)


def _check_close_args(
    name: str,
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float,
    atol: float,
) -> None:
    check(
        a.dtype == b.dtype,
        lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format(
            name, a.dtype, b.dtype
        ),
        ValueError,
    )
    check(
        rtol >= 0,
        lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format(
            name, rtol
        ),
    )
    check(
        atol >= 0,
        lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format(
            name, atol
        ),
    )


# CompositeImplicitAutograd - don't register decomp
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the note below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
    )

    return result
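

# Illustrative check of the finite-closeness inequality used above,
# |a - b| <= atol + rtol * |b| (added note; assumes eager semantics):
#
# >>> torch.isclose(torch.tensor(1.0), torch.tensor(1.05), rtol=0.1)
# tensor(True)   # |1.0 - 1.05| = 0.05 <= 1e-08 + 0.1 * 1.05
# >>> torch.isclose(torch.tensor(1.0), torch.tensor(1.05))
# tensor(False)  # default rtol=1e-05 is far too tight here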

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def lcm(a: TensorLikeType, b: TensorLikeType):
    dtype = a.dtype
    # promoting to int32 to maintain 100% consistency with C++ and to
    # prevent overflow in case of int8 and int16
    promote_to_int = dtype in (torch.int8, torch.int16)
    if promote_to_int:
        a = prims.convert_element_type(a, torch.int32)
        b = prims.convert_element_type(b, torch.int32)

    g = torch.gcd(a, b)
    # Avoid division by zero in case gcd(0, 0) == 0
    g = torch.where(g == 0, 1, g)
    res = torch.abs(prims.div(a, g) * b)
    return res if not promote_to_int else prims.convert_element_type(res, dtype)
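

# Illustrative example of the lcm identity |a / gcd(a, b) * b| and the
# int8 -> int32 promotion path (added note; assumes eager semantics):
#
# >>> torch.lcm(torch.tensor([4, 0], dtype=torch.int8), torch.tensor([6, 0], dtype=torch.int8))
# tensor([12, 0], dtype=torch.int8)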

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.le(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = a >= b
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(torch.isinf(a), a == b)
    return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))
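

# Illustrative example of why the max/min rewrite above is numerically stable
# (added note; assumes eager semantics): a naive log(exp(a) + exp(b))
# underflows to -inf here, while log1p of the small ratio survives:
#
# >>> torch.logaddexp(torch.tensor(-1000.0), torch.tensor(-1000.0))
# tensor(-999.3069)  # == -1000 + log(2)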

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_and(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a & b


# TODO: add docstring
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def logical_not(a: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        return a == 0
    return ~a


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_or(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return bitwise_or(a, b)


# TODO: add docstring
# TODO: skip unnecessary conversion of long to float
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_xor(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a ^ b

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.lt(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.maximum(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.minimum(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
)
def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.mul(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.nextafter(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.remainder(a, b)


# reverse sub
def rsub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    if isinstance(a, Number):
        msg = "Received a Number for the first argument, but expected a Tensor"
        raise ValueError(msg)

    return sub(b, a, alpha=alpha)


# TODO: add docstring
# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
@register_decomposition(aten.sub)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.sub
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.sub(a, b)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    name="true_divide",
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.div(a, b)


@register_decomposition(aten.xlogy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    utils.check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, TensorLike) and isinstance(b, Number):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)
    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
    return torch.where(torch.isnan(b), float("nan"), rhs)
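

# Illustrative example of the two special cases handled above (added note;
# assumes eager semantics): a == 0 yields 0 even where log(b) is undefined,
# but a nan in b always wins:
#
# >>> torch.xlogy(torch.tensor([0.0, 2.0]), torch.tensor([0.0, 3.0]))
# tensor([0.0000, 2.1972])  # 0 * log(0) -> 0; 2 * log(3) ~= 2.1972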

# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def trunc_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    dtype = utils.get_dtype(a)
    if utils.is_integer_dtype(dtype):
        return prims.div(a, b)

    return trunc(prims.div(a, b))


#
# Elementwise Ternary References
#


@register_decomposition(aten.addcdiv)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def addcdiv(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcdiv
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        check(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                type(value), python_type
            ),
            exc_type=ValueError,
        )

    return self + value * tensor1 / tensor2


@register_decomposition(aten.addcmul)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addcmul(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcmul
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        check(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: "value argument of type {0} cannot be safely cast to type {1}!".format(
                type(value), python_type
            ),
            exc_type=ValueError,
        )

    return self + value * tensor1 * tensor2


@register_decomposition(aten.clamp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    # NOTE: grad behavior of the `where`-based implementation is not consistent on `nan`
    if min is None and max is None:
        msg = "clamp called but both min and max are none!"
        raise ValueError(msg)
    if min is not None:
        a_isnan = torch.isnan(a)
        condition = torch.bitwise_or(torch.ge(a, min), a_isnan)  # type: ignore[arg-type]
        # we should also propagate `nan` coming from the boundaries. However, that's
        # not necessary since `ge` is already `False` when either operand is `nan`,
        # so the line below would be redundant:
        #   `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
        a = torch.where(condition, a, min)  # type: ignore[arg-type]
    if max is not None:
        a_isnan = torch.isnan(a)
        # same as above, no need to adjust `nan` from `max`
        condition = torch.bitwise_or(torch.le(a, max), a_isnan)  # type: ignore[arg-type]
        a = torch.where(condition, a, max)  # type: ignore[arg-type]

    return a
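

# Illustrative example of the nan propagation noted above (added note; assumes
# eager semantics): a nan in `a` makes both conditions true, so the input
# value is kept:
#
# >>> torch.clamp(torch.tensor([float("nan"), -2.0, 5.0]), min=0.0, max=1.0)
# tensor([nan, 0., 1.])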

@register_decomposition(aten.clamp_min)
@out_wrapper()
def clamp_min(
    self: TensorLikeType,
    min: TensorOrNumberLikeType = None,
) -> TensorLikeType:
    return torch.clamp(self, min=min)  # type: ignore[arg-type]


@register_decomposition(aten.clamp_max)
@out_wrapper()
def clamp_max(
    self: TensorLikeType,
    max: TensorOrNumberLikeType = None,
) -> TensorLikeType:
    return torch.clamp(self, max=max)  # type: ignore[arg-type]


#
# Conditional references
#


# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(aten.where)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def where(
    pred: Tensor,
    a: Optional[TensorOrNumberLikeType] = None,
    b: Optional[TensorOrNumberLikeType] = None,
):
    """ """
    if a is None or b is None:
        raise NotImplementedError

    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    check(
        pred.dtype is torch.bool,
        lambda: f"expected predicate to be bool, got {pred.dtype}",
    )

    pred, a, b = _maybe_broadcast(pred, a, b)
    return prims.where(pred, a, b)


#
# Data Movement References
#


@register_decomposition(aten.clone)
def clone(
    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    result = prims.clone(a, memory_format=memory_format)
    return result


def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
    if not allow_cross_device and a.device != b.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            b.device, a.device
        )
        raise RuntimeError(msg)

    return prims.copy_to(a, b)


@register_decomposition(aten.item)
def item(a: TensorLikeType) -> NumberType:
    if a.numel() != 1:
        msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
        raise ValueError(msg)

    # NOTE: explicit conversion is necessary for bool!
    # See https://github.com/pytorch/pytorch/issues/78071
    number_type = utils.dtype_to_type(a.dtype)
    return number_type(prims.item(a))


# fast path when `to` returns an alias to the input. This mimics the same function in aten
def _to_will_alias(
    a: TensorLikeType,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    copy: Optional[bool] = None,
    layout: Optional[torch.layout] = None,
    memory_format: Optional[torch.memory_format] = None,
    pin_memory: Optional[bool] = False,
    non_blocking: bool = False,  # not using non_blocking
) -> bool:
    return (
        not copy
        and (device is None or a.device == device)
        and (dtype is None or a.dtype == dtype)
        and (layout is None or a.layout == layout)
        # is_pinned issue #84925
        # and (pin_memory is None or pin_memory == a.is_pinned())
        and (
            memory_format is None
            or memory_format == torch.preserve_format
            or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
        )
    )


@singledispatch
def _to_dispatch(*args, **kwargs):
    raise NotImplementedError


@_to_dispatch.register
def _to_device(
    device: torch.device,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    kwargs = {
        "device": device,
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_device_str(
    device: str,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    kwargs = {
        "device": torch.device(device),
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_dtype(
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    kwargs = {
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_other(
    other: Tensor,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    device = other.device
    dtype = other.dtype
    layout = other.layout
    # is_pinned issue #84925
    # pin_memory = other.is_pinned()
    kwargs = {
        "device": device,
        "dtype": dtype,
        "layout": layout,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


# remove to_kwargs that are already present in `a`
def canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
    options_to_check = ["dtype", "device", "layout", "memory_format"]
    # the "device" option could be passed a str instead of a torch.device
    if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
        to_kwargs["device"] = torch.device(to_kwargs["device"])

    for kw in options_to_check:
        if kw in to_kwargs:
            if (
                (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
                or (
                    kw == "device"
                    and to_kwargs[kw].type == a.device.type
                    and (
                        not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
                    )
                )
                or (
                    getattr(a, kw, None) == to_kwargs[kw]
                )  # this also handles {"memory_format": None}
            ):
                to_kwargs.pop(kw)


def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
    # handle dispatch via positional arguments
    if len(args) != 0:
        kwargs = _to_dispatch(*args, **kwargs)

    # TODO: is_pinned is not currently supported in refs or fake_tensor
    # https://github.com/pytorch/pytorch/issues/84925
    assert "pin_memory" not in kwargs

    canonicalize_to_arguments(a, kwargs)

    if _to_will_alias(a, **kwargs):
        return a

    copy = kwargs.pop("copy") if "copy" in kwargs else False
    non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False

    # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
    if (
        (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
        and (not non_blocking)
        and ("memory_format" not in kwargs)
        and ("device" not in kwargs)
        and ("layout" not in kwargs)
        # is_pinned issue #84925
        # and ("pin_memory" not in kwargs)
    ):
        return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))

    result = torch.empty_like(a, **kwargs)
    # TODO: non_blocking should be handled by `copy_to`
    copy_to(result, a)
    return result
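

# Illustrative sketch of the fast paths above (added note; assumes standard
# Tensor.to semantics):
#
# >>> a = torch.ones(2, dtype=torch.float32)
# >>> a.to(torch.float32) is a   # _to_will_alias short-circuit: no-op aliases
# True
# >>> a.to(torch.float64).dtype  # pure dtype change -> convert_element_type
# torch.float64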

#
# Reduction references
#


def _reduction(
    a: TensorLikeType,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
) -> TensorLikeType:  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
    assert isinstance(a, TensorLike)
    if a.ndim > 64:
        raise RuntimeError(
            "Received a tensor with {0} dimensions, but only tensors with up to 64 dims are supported!".format(
                a.ndim
            )
        )

    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, Dim)
    if isinstance(dims, Dim):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    computation_dtype, result_dtype = utils.reduction_dtypes(
        a, output_dtype_kind, dtype
    )
    a = _maybe_convert_to_dtype(a, computation_dtype)  # type: ignore[assignment]
    result = prim(a, dims)
    if keepdims:
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)

    if out is not None:
        assert result_dtype is not None
        if dtype is not None and result_dtype != out.dtype:
            raise RuntimeError(
                "Expected the dtype of reduction result and out to match"
            )
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]

    if result.dtype != result_dtype and result_dtype is not None:
        result = prims.convert_element_type(result, result_dtype)

    return result


def _make_copy_from_view(fn):
    """
    Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
    """
    name = fn.__name__
    fn = out_wrapper()(fn)

    def _fn(*args, out=None, **kwargs):
        result = fn(*args, out=out, **kwargs)
        if out is None:
            return result.clone(memory_format=torch.contiguous_format)
        return result

    copy_name = f"{name}_copy"
    _fn.__name__ = copy_name
    _fn = register_decomposition(getattr(aten, copy_name))(_fn)
    return _fn


# Saves Python all
py_all = all


@register_decomposition(aten.all)
@out_wrapper()
def all(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    # Computes nelem
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]

    a_ = _maybe_convert_to_dtype(a, torch.bool)
    # avoid comparison with symbolic number of elements to make this op symint friendly
    result = eq(sum(logical_not(a_), dim=dim, keepdim=keepdim), 0)

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is torch.uint8:
        return prims.convert_element_type(result, torch.uint8)

    return result


# Saves Python any
py_any = any


@register_decomposition(aten.any)
@out_wrapper()
def any(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    a_ = _maybe_convert_to_dtype(a, torch.bool)
    result = ne(sum(a_, dim=dim, keepdim=keepdim), False)  # type: ignore[arg-type]

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is torch.uint8:
        return prims.convert_element_type(result, torch.uint8)

    return result


@register_decomposition(aten.sum)
def sum(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    if dtype is None:
        if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
            dtype = torch.int64
        else:
            dtype = a.dtype
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


def sum_to_size(
    a: Tensor,
    *shape,
) -> Tensor:
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    utils.check(
        utils.is_expandable_to(shape, a.shape),
        lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
    )
    # In ATen scalar tensors are sent through sum and the result is returned as
    # type promoted
    if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
        return prims.view_of(a)
    leading_dims = a.ndim - len(shape)
    reduce_dims = tuple(range(leading_dims)) + tuple(
        i
        for i in range(leading_dims, len(shape))
        if shape[i - leading_dims] == 1 and a.shape[i] != 1
    )
    return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
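

# Illustrative example of sum_to_size (added note; assumes eager semantics):
# reducing a (2, 3) tensor to a shape it broadcasts from sums the expanded
# dimensions:
#
# >>> a = torch.arange(6.0).reshape(2, 3)
# >>> a.sum_to_size((1, 3))
# tensor([[3., 5., 7.]])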

@register_decomposition(aten.prod)
def prod(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    if dtype is None:
        if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
            dtype = torch.int64
        else:
            dtype = a.dtype
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.prod,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


@register_decomposition(aten.amin)
def amin(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.amin,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


@register_decomposition(aten.amax)
def amax(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.amax,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


def _dim_var_dispatch(dim=None, unbiased=None):
    # The following overload of torch.var exists:
    #   var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
    # so we need to explicitly convert a bool dim to the unbiased arg
    if unbiased is None and isinstance(dim, bool):
        unbiased = dim
        dim = None
    return dim, unbiased


@register_decomposition(aten.var)
@out_wrapper()
def var(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    result = _reduction(
        a,
        partial(prims.var, correction=correction),
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
    return result


@register_decomposition(aten.std)
@out_wrapper()
def std(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)

    opmath_dtype, dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )

    a = _maybe_convert_to_dtype(a, opmath_dtype)
    a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
    a_std = torch.sqrt(a_var)
    assert dtype is not None
    return _maybe_convert_to_dtype(a_std, dtype)


@register_decomposition(aten.mean)
def mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out=None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    orig_dtype = dtype
    if dtype is None:
        dtype = a.dtype
    # can't use the out wrapper because of this argument
    check(
        out is None or out.dtype == dtype,
        lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
    )
    result = _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=None,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
    )
    check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: (
            f"mean(): could not infer output dtype. "
            f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
            f"a floating point or complex dtype. Got: {dtype}"
        ),
    )
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
    result = true_divide(result, nelem)
    result_dtype = a.dtype if dtype is None else dtype
    result = _maybe_convert_to_dtype(result, result_dtype)  # type: ignore[assignment]
    if out is not None:
        assert isinstance(out, TensorLike)
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
    return result


@register_decomposition(aten.std_mean.correction)
def std_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    *,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    correction: Optional[int] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    opmath_dtype, dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    original_dtype = a.dtype
    a = _maybe_convert_to_dtype(a, opmath_dtype)
    a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
    a_std = torch.sqrt(a_var)
    assert dtype is not None
    return (
        _maybe_convert_to_dtype(a_std, dtype),
        _maybe_convert_to_dtype(a_mean, original_dtype),
    )


@register_decomposition(aten.var_mean)
def var_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m


@register_decomposition(aten.addr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "vec1", "vec2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addr(
    self: TensorLikeType,
    vec1: TensorLikeType,
    vec2: TensorLikeType,
    *,
    beta: NumberType = 1,
    alpha: NumberType = 1,
) -> TensorLikeType:
    check(
        vec1.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
    )
    check(
        vec2.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
    )
    self = self.expand(vec1.shape[0], vec2.shape[0])
    if utils.is_boolean_dtype(self.dtype):
        # Integers are accepted for booleans
        check(
            is_weakly_lesser_type(type(beta), int),
            lambda: f"expected bool/int beta but got {type(beta)}",
        )
        check(
            is_weakly_lesser_type(type(alpha), int),
            lambda: f"expected bool/int alpha but got {type(alpha)}",
        )
        if not beta:
            return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
        else:
            return torch.logical_or(
                self,
                torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
            )
    else:
        check(
            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
        )
        check(
            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
        )
        if beta == 0:
            # This means NaNs from self are dropped if beta is zero
            return alpha * torch.outer(vec1, vec2)
        else:
            return beta * self + alpha * torch.outer(vec1, vec2)
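

# Illustrative example of the non-boolean branch above, which computes
# beta * self + alpha * outer(vec1, vec2) (added note; assumes eager
# semantics):
#
# >>> torch.addr(torch.zeros(2, 2), torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0]), alpha=2.0)
# tensor([[2., 6.],
#         [4., 12.]])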

# CompositeImplicitAutograd - don't register decomp
def atleast_1d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_1d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
    return res if len(res) > 1 else res[0]


# Helper function with assert to avoid MyPy error
# of incompatible type passed to unsqueeze
def _unsqueeze_atleast(
    at_least_fn: Callable, dim: int, arg: TensorLikeType
) -> TensorLikeType:
    arg_ = at_least_fn(arg)
    assert isinstance(arg_, TensorLike)
    return unsqueeze(arg_, dim)


# CompositeImplicitAutograd - don't register decomp
def atleast_2d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_2d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
    res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
    return res if len(res) > 1 else res[0]


# CompositeImplicitAutograd - don't register decomp
def atleast_3d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_3d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
    res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
    return res if len(res) > 1 else res[0]


def as_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    storage_offset_int = (
        storage_offset if storage_offset is not None else a.storage_offset()
    )
    return prims.as_strided(a, size, stride, storage_offset_int)


@register_decomposition(aten.as_strided_scatter)
def as_strided_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    storage_offset_int = 0 if storage_offset is None else storage_offset
    return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)


def broadcast_shapes(*shapes) -> ShapeType:
    return torch.Size(_broadcast_shapes(*shapes))


@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
    if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
        tensors = tensors[0]
    return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))


# CompositeImplicitAutograd - don't register decomp
def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
    start = len(size) - len(a.shape)
    dims = tuple(range(start, len(a.shape) + start))
    return prims.broadcast_in_dim(a, size, dims)


@register_decomposition(aten.cat)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("tensors",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    def cat_compute_output_memory_format(inputs):
        format = None
        for t in inputs:
            f = utils.suggest_memory_format(t)
            if f == torch.contiguous_format:
                return f
            if format is not None and format != f:
                return torch.contiguous_format
            format = f
        assert format is not None
        return format

    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    for t in tensors:
        # match logic in legacy_cat_wrap_dim
        if t.ndim == 1 and t.size(0) == 0:
            continue
        dim = utils.canonicalize_dim(t.ndim, dim)
        utils.validate_idx(t.ndim, dim)
        break

    memory_format = cat_compute_output_memory_format(tensors)

    # Filters tensors with one dimension of length zero
    filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0))
    if len(filtered) == 0:
        t = tensors[0]

        # TODO: fix this to work with meta tensors
        try:
            requires_grad = any(x.requires_grad for x in tensors)
        except Exception:
            requires_grad = False

        return empty(
            (0,),
            dtype=t.dtype,
            device=t.device,
            requires_grad=requires_grad,
            memory_format=memory_format,
        )

    return prims.cat(filtered, dim).clone(memory_format=memory_format)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    aligned_tensors = tuple(
        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
    )
    return cat(aligned_tensors, 1)


def conj(input: TensorLikeType) -> TensorLikeType:
    if not utils.is_complex_dtype(input.dtype):
        return input
    if input.is_sparse:
        return torch.conj_physical(input)
    return prims.conj(input)


# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
@register_decomposition(aten.constant_pad_nd)
def constant_pad_nd(
    input: TensorLikeType, pad: List[int], value: NumberType = 0
) -> TensorLikeType:
    check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )

    input_sizes = input.shape
    l_inp = len(input_sizes)

    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad

    check(
        l_inp >= l_pad,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{l_inp} dimensions.",
    )

    c_input = input
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] < 0:
            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])

        if pad[pad_idx + 1] < 0:
            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])

    # if none of the pads are positive we can just return the result
    if builtins.all(p <= 0 for p in pad):
        return c_input.clone()

    new_shape = list(input_sizes[:l_diff])

    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        check(
            new_dim > 0,
            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
            f"which is invalid. Check dimension {l_diff + i} of your input.",
        )
        new_shape.append(new_dim)

    memory_format = utils.suggest_memory_format(input)
    output = torch.empty(
        new_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )

    if value == 0 and input.dtype == torch.bool:
        value = False
    # torch.fill isn't typed to allow complex values
    output = torch.fill(output, value)  # type: ignore[arg-type]

    c_output = output
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] > 0:
            c_output = c_output.narrow(
                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
            )
        if pad[pad_idx + 1] > 0:
            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])

    prims.copy_to(c_output, c_input)
    return output
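

# Illustrative example of constant_pad_nd (added note; assumes eager
# semantics): `pad` pairs up as (left, right) per dimension, starting from the
# last dimension:
#
# >>> torch.constant_pad_nd(torch.ones(1, 2), [1, 1], 9)
# tensor([[9., 1., 1., 9.]])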

def contiguous(
    a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor:
    check(
        memory_format != torch.preserve_format,
        lambda: "preserve memory format is unsupported by the contiguous operator",
    )

    if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
        return a

    return torch.clone(a, memory_format=memory_format)


@out_wrapper()
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
    check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
    aligned_tensors = atleast_3d(*tensors)
    return cat(aligned_tensors, 2)


@register_decomposition(aten.expand)
def expand(a: Tensor, *shape) -> Tensor:
    # NOTE: cannot use utils.extract_shape_from_varargs here
    # because that also validates the shape, but the shape
    # given to expand may be "invalid"
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = tuple(shape[0])

    check(
        len(shape) >= len(a.shape),
        lambda: "expand: the requested shape has too few dimensions!",
    )

    offset = len(shape) - len(a.shape)
    shape_ = list(shape)
    for idx, x in enumerate(a.shape):
        offset_idx = idx + offset
        requested_length = shape[offset_idx]
        check(
            requested_length == x or x == 1 or requested_length == -1,
            lambda: f"expand: attempting to expand a dimension of length {x}!",
        )

        shape_[offset_idx] = requested_length if requested_length != -1 else x

    # At this point shape must be valid
    utils.validate_shape(shape_)

    return prims.broadcast_in_dim(
        a, shape_, tuple(range(offset, len(a.shape) + offset))
    )


# CompositeImplicitAutograd - don't register decomp
def expand_as(a: Tensor, b: Tensor) -> Tensor:
    return a.expand(b.shape)


def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
    if chunks <= 0:
        msg = "Expected at least one chunk, but got {0}!".format(chunks)
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]
    chunk_size = math.ceil(length / chunks)
    full_chunks = math.floor(length / chunk_size)
    tail_chunk_size = length % chunk_size

    result = []
    for i in range(full_chunks):
        result.append(narrow(a, dim, i * chunk_size, chunk_size))

    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
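

# Illustrative example of the chunk-size arithmetic above (added note; assumes
# eager semantics): length 5 with chunks=2 gives chunk_size ceil(5 / 2) == 3,
# one full chunk, and a tail chunk of 2:
#
# >>> torch.arange(5).chunk(2)
# (tensor([0, 1, 2]), tensor([3, 4]))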

# Note: flatten, unlike prim.collapse and prim.collapse_view has an inclusive end_dim
# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
# a 0D tensor is flattened, in which case it's returned in 1D)
# CompositeImplicitAutograd - don't register decomp
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Short-circuits on no-op
    if start_dim == end_dim and a.ndim != 0:
        return a

    # Tries to take a view
    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
    new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim + 1)
    if new_shape is not None:
        return prims.collapse_view(a, start_dim, end_dim + 1)

    # Makes a copy if it can't make a view
    return prims.collapse(a, start_dim, end_dim + 1)


@register_decomposition(aten.flip)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    if not isinstance(dims, tuple) and not isinstance(dims, list):
        raise ValueError("dims has to be a sequence of ints")
    dims = utils.canonicalize_dims(a.ndim, dims)  # type: ignore[assignment]
    utils.validate_no_repeating_dims(dims)
    return prims.rev(a, dims)


# CompositeImplicitAutograd - don't register decomp
def fliplr(a: TensorLikeType) -> TensorLikeType:
    if a.ndim < 2:
        raise RuntimeError("Input must be >= 2-d.")

    return flip(a, (1,))


# CompositeImplicitAutograd - don't register decomp
def flipud(a: TensorLikeType) -> TensorLikeType:
    if a.ndim < 1:
        raise RuntimeError("Input must be >= 1-d.")

    return flip(a, (0,))


# CompositeImplicitAutograd - don't register decomp
def narrow(
    a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
) -> TensorLikeType:
    # Supports the Tensor overload that was added for XLA:
    # https://github.com/pytorch/pytorch/issues/31558
    if isinstance(start, TensorLike):
        check(
            start.dim() == 0 and utils.is_integer_dtype(start.dtype),
            lambda: "start must be an 0-dim integral Tensor.",
        )
        start = start.item()  # type: ignore[assignment]
    check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
    check(length >= 0, lambda: "narrow(): length must be non-negative.")
    dim = utils.canonicalize_dim(a.ndim, dim)
    dim_length = a.size(dim)
    # Start being the end is usually invalid since it's out of bounds. So it's
    # not allowed by canonicalize_dim. But for narrow it's valid as long as
    # the length is 0, which is handled by the check below.
    if start != dim_length:
        # Negative start means indexing from the end of dim.
        # Note: a dimension isn't being canonicalized here, this reuses
        # canonicalize_dim because the semantics are similar.
        start = utils.canonicalize_dim(dim_length, start)  # type: ignore[arg-type]
    check(
        start <= dim_length - length,  # type: ignore[arg-type]
        lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
    )
    return prims.slice_in_dim(a, start, start + length, axis=dim)
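

# Illustrative example of narrow with a negative start (added note; assumes
# eager semantics): start == -2 is canonicalized against the dimension length
# to 2:
#
# >>> a = torch.arange(8).reshape(2, 4)
# >>> torch.narrow(a, 1, -2, 2)
# tensor([[2, 3],
#         [6, 7]])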
  2400. # TODO: This must return a sparse tensor if the input is sparse, but refs have
  2401. # no sparse support. See narrow_copy_sparse in core.
  2402. narrow_copy = _make_copy_from_view(narrow)
def _normalize(
    a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
    """Computes mean and 1/std of a tensor along norm_dims.

    Used as a helper function for normalization layers.

    Args:
        a (Tensor): input tensor
        norm_dims (DimsType): dimensions to normalize over
        eps (float): epsilon for numerical stability

    Returns:
        out (Tensor): normalized tensor.
        mean (Tensor): mean of the tensor along norm_dims.
        rstd (Tensor): 1/std of the tensor along norm_dims.
    """
    norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = _maybe_convert_to_dtype(a, computation_dtype)
    assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
    biased_var, mean = torch.var_mean(
        a_acc, dim=norm_dims, unbiased=False, keepdim=True
    )
    rstd = torch.rsqrt(biased_var + eps)
    out = (a - mean) * rstd
    return out, mean, rstd

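# A minimal illustrative sketch, not part of the original module: it spells
# out the mean/rstd computation that _normalize performs, checked against an
# explicit calculation. The helper name `_normalize_example` is hypothetical.
def _normalize_example() -> None:
    import torch

    a = torch.randn(2, 5)
    eps = 1e-5
    mean = a.mean(dim=1, keepdim=True)
    rstd = torch.rsqrt(a.var(dim=1, unbiased=False, keepdim=True) + eps)
    out = (a - mean) * rstd
    # Each row now has (approximately) zero mean and unit variance.
    assert torch.allclose(out.mean(dim=1), torch.zeros(2), atol=1e-5)
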
# Adds all specified dimensions
def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
    for dim in sorted(dimensions):
        x = torch.unsqueeze(x, dim)
    return x

@register_decomposition(aten.native_group_norm.default)
def native_group_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    batch_size: int,
    num_channels: int,
    flattened_inner_size: int,
    num_groups: int,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    utils.check(
        input.ndim >= 2,
        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
    )
    utils.check(
        num_channels % num_groups == 0,
        lambda: "Expected number of channels in input to be divisible by num_groups, "
        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
    )

    # num_channels / num_groups and the flattened inner dimension are the reduction axes
    reduction_dims = [2, 3]
    input_reshaped = torch.reshape(
        input,
        [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
    )
    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
    out = out.view(input.shape)

    broadcast_dims = [0] + list(range(2, input.ndim))
    unsqueeze_bias = None
    if bias is not None:
        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
    unsqueeze_weight = None
    if weight is not None:
        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)

    if unsqueeze_weight is not None:
        out = out * unsqueeze_weight
    if unsqueeze_bias is not None:
        out = out + unsqueeze_bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
    rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]

    # Removes broadcast dimensions from mean and rstd
    mean = torch.squeeze(mean, reduction_dims)
    rstd = torch.squeeze(rstd, reduction_dims)
    return (out, mean, rstd)

@register_decomposition(aten.native_layer_norm)
def native_layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    normalized_ndim = len(normalized_shape)
    utils.check(
        normalized_ndim >= 1,
        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
        + "containing at least one element, but got normalized_shape = "
        + str(normalized_shape),
    )
    # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
    # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
    # therefore we use tuple(normalized_shape)
    utils.check(
        weight is None or weight.shape == tuple(normalized_shape),
        lambda: "Expected weight to be of same shape as normalized_shape, but got "
        + "weight of shape "
        + str(weight.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    utils.check(
        bias is None or bias.shape == tuple(normalized_shape),
        lambda: "Expected bias to be of same shape as normalized_shape, but got "
        + "bias of shape "
        + str(bias.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    utils.check(
        input.ndim >= normalized_ndim
        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
        lambda: "Given normalized_shape="
        + str(normalized_shape)
        + ", expected input with shape "
        + str(normalized_shape)
        + ", but got input of size "
        + str(input.shape),
    )

    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    axis = input.ndim - normalized_ndim
    reduction_dims = list(range(axis, input.ndim))
    out, mean, rstd = _normalize(input, reduction_dims, eps)

    if weight is None and bias is not None:
        out = out + bias
    elif weight is not None and bias is None:
        out = out * weight
    elif weight is not None and bias is not None:
        out = out * weight + bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    if input.device.type == "cpu":
        mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
        rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]
    return (out, mean, rstd)

# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(aten.permute)
def permute(a: TensorLikeType, *dims) -> TensorLikeType:
    _permutation = utils.canonicalize_dims(
        a.ndim, utils.extract_dims_from_varargs(dims)
    )
    return prims.transpose(a, _permutation)

# Gets the new shape and stride after applying unfold to an input tensor
def _get_unfold_shape_stride(
    a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
):
    a_ndim = len(a_shape)
    dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
    max_size = 1 if a_ndim == 0 else a_shape[dim]
    last_stride = 1 if a_ndim == 0 else a_stride[dim]

    utils.check(
        size <= max_size,
        lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
    )

    utils.check(
        step > 0,
        lambda: f"Step is {step} but must be > 0",
    )

    shape = list(a_shape)
    strides = list(a_stride)
    shape.append(size)
    strides.append(last_stride)
    if dim < a_ndim:
        shape[dim] = (shape[dim] - size) // step + 1
        strides[dim] *= step
    return shape, strides

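# A minimal illustrative sketch, not part of the original module: it works
# through the shape/stride arithmetic above for a contiguous 1-D tensor of
# length 7, window size 3, step 2: the windowed dim has (7 - 3) // 2 + 1 == 3
# entries with stride 1 * 2, and the appended window dim reuses stride 1.
# The helper name `_unfold_shape_example` is hypothetical.
def _unfold_shape_example() -> None:
    import torch

    a = torch.arange(7)
    windows = a.unfold(0, 3, 2)
    assert windows.shape == (3, 3)
    assert windows.stride() == (2, 1)
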
@register_decomposition(aten.repeat)
def repeat(a: Tensor, *repeat_shape) -> Tensor:
    repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
    utils.check(
        len(repeat_shape) >= len(a.shape),
        lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    if len(repeat_shape) == 0:
        return torch.clone(a)

    num_new_dimensions = len(repeat_shape) - a.ndim
    padded_shape = [1] * num_new_dimensions
    for dim_size in a.shape:
        padded_shape.append(dim_size)

    target_shape = tuple(
        padded_size * repeat_size
        for padded_size, repeat_size in zip(padded_shape, repeat_shape)
    )

    # Returns an empty tensor if one of the repeat_shape dimensions is zero
    if 0 in repeat_shape:
        return torch.empty(
            target_shape,
            dtype=a.dtype,
            device=a.device,
            requires_grad=a.requires_grad,
            memory_format=utils.suggest_memory_format(a),
        )

    urtensor_shape = target_shape
    urtensor_stride = utils.make_contiguous_strides_for(target_shape)
    for dim, dim_size in enumerate(padded_shape):
        # Repeats each dimension by using the unfold_copy operation
        urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
            urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
        )

    # Derives the permute order by sorting urtensor's strides
    enumerated_stride = list(enumerate(urtensor_stride))
    enumerated_stride.sort(key=lambda item: item[1], reverse=True)
    permute_order, sorted_stride = zip(*enumerated_stride)

    # Adds new and expands existing dimensions according to urtensor
    repeat_xtensor = a.expand(urtensor_shape)

    # Clones the tensor to concretize expanded dimensions
    cloned_result = torch.clone(repeat_xtensor)

    # Transposes axes so strides are in sorted order
    permuted_result = cloned_result.permute(permute_order)

    # Reshapes to get a contiguous tensor with the correct target shape
    return permuted_result.reshape(target_shape)

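# A minimal illustrative sketch, not part of the original module: it checks
# the end-to-end behavior of the expand/permute/reshape decomposition above
# against Tensor.repeat for a simple case. The helper name `_repeat_example`
# is hypothetical.
def _repeat_example() -> None:
    import torch

    a = torch.tensor([[1, 2], [3, 4]])
    out = a.repeat(2, 3)  # target shape (2 * 2, 3 * 2) == (4, 6)
    assert out.shape == (4, 6)
    # The repeated tensor tiles the original block.
    assert torch.equal(out[:2, :2], a)
    assert torch.equal(out[2:, 4:], a)
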
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    # Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    shape = utils.infer_size(shape, a.numel())

    # Short-circuits if shape is the same
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles the general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            # NOTE: using split_dim instead of unsqueeze may seem silly here,
            # but it's necessary to get the strides correct
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape
                )
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_

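# A minimal illustrative sketch, not part of the original module: it shows the
# consequence of NOTE [Reshape Algorithm] above -- reshape returns a view when
# the strides permit and silently copies when they don't. The helper name
# `_reshape_view_example` is hypothetical.
def _reshape_view_example() -> None:
    import torch

    a = torch.randn(2, 3, 4)
    # A contiguous tensor can always be reshaped as a view (shared storage).
    v = a.reshape(6, 4)
    assert v.data_ptr() == a.data_ptr()

    # After a transpose the strides no longer permit collapsing all dims,
    # so reshape copies while view would raise.
    t = a.transpose(0, 2)
    c = t.reshape(24)
    assert c.data_ptr() != t.data_ptr()
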
# CompositeImplicitAutograd - don't register decomp
# NOTE: shape is a vararg because Tensor.reshape can be called as
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)). Function call
# torch.reshape doesn't support unpacked shapes
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=True)


# CompositeImplicitAutograd - don't register decomp
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.reshape(other.size())

@register_decomposition(aten.roll)
def roll(
    a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`."""
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Avoids modulo by zero
    if a.numel() == 0:
        # Keeping this as a ref for now, as FakeTensor runs into some issues with complex tensors
        return clone(a)

    len_shifts = len(shifts)
    len_dims = len(dims)
    if len_shifts != 1 or len_dims != 1:
        if len_shifts == 0:
            raise RuntimeError("`shifts` required")
        # Takes care of the case when dims is not specified (default).
        # By default, the tensor is flattened before shifting, after which the original shape is restored
        if len_dims == 0 and len_shifts == 1:
            return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
        if len_shifts != len_dims:
            raise RuntimeError(
                f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
            )
        assert len_dims > 1
        tail_shifts = shifts[1:]
        tail_dims = dims[1:]
        first_dim_rolled = torch.roll(a, shifts[0], dims[0])
        return torch.roll(first_dim_rolled, tail_shifts, tail_dims)

    # This path is taken when only one dimension is rolled,
    # for example to get `first_dim_rolled` above
    dim = dims[0]
    size = a.shape[dim]
    start = (size - shifts[0]) % size
    t0 = torch.narrow(a, dim, start, size - start)
    t1 = torch.narrow(a, dim, 0, start)
    return torch.cat((t0, t1), dim)

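# A minimal illustrative sketch, not part of the original module: it verifies
# the narrow-and-cat decomposition used for the single-dimension path above.
# The helper name `_roll_example` is hypothetical.
def _roll_example() -> None:
    import torch

    a = torch.arange(5)  # [0, 1, 2, 3, 4]
    size, shift = a.shape[0], 2
    start = (size - shift) % size  # == 3
    decomposed = torch.cat(
        (torch.narrow(a, 0, start, size - start), torch.narrow(a, 0, 0, start))
    )
    assert torch.equal(decomposed, torch.roll(a, shift))  # [3, 4, 0, 1, 2]
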
@register_decomposition(aten.rot90)
def rot90(
    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
) -> TensorLikeType:
    """Reference implementation of :func:`torch.rot90`."""
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if a.ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")

    # Do this after the initial checks to be compatible with the behavior in
    # core.
    dims = utils.canonicalize_dims(a.ndim, dims)

    if dims[0] == dims[1]:
        raise RuntimeError(
            f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
        )
    k = k % 4  # Rotation direction is from the second towards the first axis for k < 0
    if k == 1:
        return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
    elif k == 2:
        return torch.flip(a, dims)
    elif k == 3:
        return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
    else:
        return clone(a, memory_format=torch.contiguous_format)

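# A minimal illustrative sketch, not part of the original module: it checks
# that a single quarter turn equals flip-then-transpose, as in the k == 1
# branch above. The helper name `_rot90_example` is hypothetical.
def _rot90_example() -> None:
    import torch

    a = torch.tensor([[1, 2], [3, 4]])
    expected = torch.transpose(torch.flip(a, (1,)), 0, 1)
    assert torch.equal(torch.rot90(a, 1), expected)  # [[2, 4], [1, 3]]
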
def _check_stack_inputs(tensors: TensorSequenceType) -> None:
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )

@register_decomposition(aten.stack)
@out_wrapper()
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
    # Refs need sparse support to check the other condition
    if wrapped_dim < tensors[0].ndim:  # and not tensors[0].is_sparse:
        _check_stack_inputs(tensors)
        result_sizes = list(tensors[0].shape)
        result_sizes.insert(wrapped_dim, len(tensors))
        out = torch.cat(tensors, wrapped_dim)
        return out.view(result_sizes)

    # If dim == tensors[0].ndim, view cannot efficiently handle it
    return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)

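# A minimal illustrative sketch, not part of the original module: it shows the
# cat-then-view fast path and the unsqueeze-then-cat fallback used above. The
# helper name `_stack_example` is hypothetical.
def _stack_example() -> None:
    import torch

    ts = [torch.tensor([1, 2]), torch.tensor([3, 4])]
    # dim < ndim: cat along dim, then view with the inserted dimension.
    assert torch.equal(torch.stack(ts, 0), torch.cat(ts, 0).view(2, 2))
    # dim == ndim: unsqueeze each tensor, then cat.
    expected = torch.cat([t.unsqueeze(1) for t in ts], 1)
    assert torch.equal(torch.stack(ts, 1), expected)
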
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    if a.numel() == 0:
        a_exp = exp(a_)
    else:
        a_max = amax(a_, dim, keepdim=True)
        a_exp = exp(a_ - a_max)
    return _maybe_convert_to_dtype(
        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
    )  # type: ignore[return-value]

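# A minimal illustrative sketch, not part of the original module: it shows why
# the max is subtracted before exponentiating -- exp() of large logits
# overflows, while the shifted form is mathematically identical. The helper
# name `_softmax_example` is hypothetical.
def _softmax_example() -> None:
    import torch

    logits = torch.tensor([1000.0, 1001.0, 1002.0])
    naive = torch.exp(logits) / torch.exp(logits).sum()
    assert torch.isnan(naive).all()  # exp(1000.) overflows to inf; inf/inf is nan

    shifted = logits - logits.max()
    stable = torch.exp(shifted) / torch.exp(shifted).sum()
    assert torch.allclose(stable, torch.softmax(logits, 0))
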
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
    check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
    aligned_tensors = atleast_1d(*tensors)
    if aligned_tensors[0].ndim == 1:
        return cat(aligned_tensors, 0)
    return cat(aligned_tensors, 1)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
    check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
    aligned_tensors = atleast_2d(*tensors)
    return cat(aligned_tensors, 0)

# CompositeImplicitAutograd - don't register decomp
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
    dim = utils.canonicalize_dim(a.ndim, dim)
    utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
    return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))


@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    dim = utils.canonicalize_dim(t.ndim, dim)
    check(
        len(t.shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
        IndexError,
    )
    return tuple(
        torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
    )

@out_wrapper()
def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    return x.clone(memory_format=torch.contiguous_format).index_copy_(
        dim, index, tensor
    )


def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    # Treat scalars as elements of \R^1
    y = x.unsqueeze(0) if x.ndim == 0 else x
    idx = (slice(None),) * dim + (index,)
    # y aliases x, so this advanced-indexing write mutates x in place
    y[idx] = tensor
    return x

@register_decomposition(aten.index_fill)
def index_fill(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    return _index_fill(x, dim, index, value, inplace=False)


@register_decomposition(aten.index_fill_)
def index_fill_(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    return _index_fill(x, dim, index, value, inplace=True)


def _index_fill(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    value: Union[NumberType, TensorLike],
    *,
    inplace: bool,
):
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if isinstance(value, TensorLike):
        utils.check(
            value.ndim == 0,
            lambda: "Only supports 0-dimensional value tensor. "  # type: ignore[union-attr]
            f"Got a tensor with {value.ndim} dimensions.",
        )  # type: ignore[arg-type]
    else:
        value = torch.scalar_tensor(
            value, dtype=x.dtype, layout=x.layout, device=x.device  # type: ignore[arg-type]
        )

    # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them
    zero_dim = x.ndim == 0
    y = x.unsqueeze(0) if zero_dim else x
    # index_copy does not broadcast on value so we have to do it manually
    shape = list(y.shape)
    shape[dim] = index.numel()
    value = value.expand(shape)
    index_copy = Tensor.index_copy_ if inplace else torch.index_copy
    out = index_copy(y, dim, index, value)  # type: ignore[operator]
    if inplace:
        return x
    else:
        if zero_dim:
            # The clone is necessary so that it returns a fresh tensor rather than a view
            out = out.squeeze(0).clone()
        # index_fill preserves the strides. index_copy always returns contiguous tensors
        if out.stride() != x.stride():
            new_out = torch.empty_like(x)
            new_out.copy_(out)
            out = new_out
        return out

@out_wrapper()
def index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    # index_add always returns a new contiguous tensor
    return x.clone(memory_format=torch.contiguous_format).index_add_(
        dim, index, tensor, alpha=alpha  # type: ignore[arg-type]
    )


@register_decomposition(aten.index_select)
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if index.ndim == 0:
        index = index.unsqueeze(0)
    if x.ndim == 0:
        # Treat scalars as elements of \R^1
        # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
        return torch.empty_like(x).index_copy(0, index, x.expand_as(index))

    idx = (slice(None),) * dim + (index,)
    return x[idx]

@register_decomposition(aten.squeeze)
def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    if dim is None:
        dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
        return prims.squeeze(a, dims) if dims else prims.view_of(a)

    ndim = a.ndim
    dim = utils.canonicalize_dims(ndim, dim)
    dims = (dim,) if isinstance(dim, Dim) else dim
    # Short-circuits if the tensor has no dimensions
    if ndim == 0:
        assert len(dims) == 0 or dims == (0,)
        return prims.view_of(a)

    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
    dims = tuple(d for d in dims if a.shape[d] == 1)
    return prims.squeeze(a, dims) if dims else prims.view_of(a)

# Note: does not work with TensorMetas because of data-dependent control-flow
# CompositeImplicitAutograd - don't register decomp
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if not indices_or_sections.device.type == "cpu":
            msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
                indices_or_sections.device
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                "but received one with dtype {0}".format(indices_or_sections.dtype)
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, IntLike) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = "tensor_split: number of sections must be greater than 0, but was {0}".format(
                sections
            )
            raise ValueError(msg)

        splits = []
        dim_size = a.shape[_dim]
        min_split_size = math.floor(dim_size / sections)
        num_splits_one_extra = dim_size % sections
        start_idx = 0
        for split_idx in range(sections):
            split_size = (
                min_split_size + 1
                if (split_idx < num_splits_one_extra)
                else min_split_size
            )
            s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
            splits.append(s)
            start_idx = start_idx + split_size

        return tuple(splits)
    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    "but received a tensor with {0} dimensions".format(
                        indices_or_sections.ndim
                    )
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        splits = []
        start_idx = 0
        for x in indices:
            splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
            start_idx = x
        splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
        return tuple(splits)

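# A minimal illustrative sketch, not part of the original module: it shows the
# "equal-ish" split rule from Case 0 (the first dim_size % sections pieces get
# one extra element) and the index-based Case 1. The helper name
# `_tensor_split_example` is hypothetical.
def _tensor_split_example() -> None:
    import torch

    a = torch.arange(5)
    pieces = torch.tensor_split(a, 3)  # 5 = 2 + 2 + 1
    assert [p.numel() for p in pieces] == [2, 2, 1]
    # A list of indices gives the split points instead.
    pieces = torch.tensor_split(a, [1, 4])
    assert [p.tolist() for p in pieces] == [[0], [1, 2, 3], [4]]
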
# CompositeImplicitAutograd - don't register decomp
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    dim = 0 if a.ndim == 1 else 1
    if isinstance(indices_or_sections, IntLike):
        split_size = indices_or_sections
        check(
            (split_size != 0 and a.shape[dim] % split_size == 0),
            lambda: (
                "torch.hsplit attempted to split along dimension "
                + str(dim)
                + ", but the size of the dimension "
                + str(a.shape[dim])
                + " is not divisible by the split_size "
                + str(split_size)
                + "!"
            ),
        )
        return tensor_split(a, split_size, dim)

    check(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "hsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
        exc_type=TypeError,
    )

    split_sizes = indices_or_sections
    return tensor_split(a, split_sizes, dim)

# CompositeImplicitAutograd - don't register decomp
def vsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    check(
        a.ndim >= 2,
        lambda: (
            "torch.vsplit requires a tensor with at least 2 dimensions, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    if isinstance(indices_or_sections, IntLike):
        split_size = indices_or_sections
        check(
            (split_size != 0 and a.shape[0] % split_size == 0),
            lambda: (
                "torch.vsplit attempted to split along dimension 0, "
                f"but the size of the dimension {a.shape[0]} "
                f"is not divisible by the split_size {split_size}!"
            ),
        )
        return tensor_split(a, split_size, 0)

    check(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "vsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
        exc_type=TypeError,
    )

    split_sizes = indices_or_sections
    return tensor_split(a, split_sizes, 0)

@register_decomposition(aten.diag.out)
@out_wrapper()
def diag(
    self: TensorLikeType,
    offset: int = 0,
) -> TensorLikeType:
    ndim = self.dim()
    utils.check(
        ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
    )
    if ndim == 1:
        return torch.diag_embed(self, offset)
    else:
        return torch.diagonal_copy(self, offset)

@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    out = utils.clone_preserve_strides(input)
    diag = out.diagonal(offset, dim1, dim2)
    check(
        diag.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input. "
        f"Got {src.shape} for a diagonal of shape {diag.shape}",
    )
    copy_to(diag, src)
    return out

@register_decomposition(aten.diagonal)
def diagonal(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diagonal
    """
    num_dims = self.dim()
    dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    storage_offset = self.storage_offset()

    if offset >= 0:
        diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
    else:
        diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)

    if diag_size > 0:
        if offset >= 0:
            storage_offset += offset * self.stride()[dim2]
        else:
            storage_offset -= offset * self.stride()[dim1]

    sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
    sizes.append(diag_size)

    strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
    strides.append(self.stride()[dim1] + self.stride()[dim2])

    result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)

    return result

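# A minimal illustrative sketch, not part of the original module: it verifies
# the stride trick above -- stepping by stride[dim1] + stride[dim2] walks the
# diagonal. The helper name `_diagonal_example` is hypothetical.
def _diagonal_example() -> None:
    import torch

    a = torch.arange(12).reshape(3, 4)  # strides (4, 1)
    d = a.diagonal()                    # elements a[0, 0], a[1, 1], a[2, 2]
    assert d.stride() == (5,)           # 4 + 1
    assert torch.equal(d, torch.tensor([0, 5, 10]))
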
diagonal_copy = _make_copy_from_view(diagonal)

@register_decomposition(aten.diag_embed)
@out_wrapper()
def diag_embed(
    t: TensorLikeType,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diag_embed
    """
    # As per the docs, exchanging dims is equivalent to changing the sign of
    # offset
    if dim1 > dim2:
        dim1, dim2 = dim2, dim1
        offset = -offset

    # Converts from negative dims
    rank = t.ndim + 1
    dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
    dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    # As per the docs, the size of the last dim is placed at dim1 and dim2
    last_dim = t.size(-1)

    if offset != 0:
        # Adds padding to match the new size
        t_shape = list(t.shape)
        t_shape[-1] = builtins.abs(offset)
        z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
        pair = (z, t) if offset > 0 else (t, z)
        t = torch.cat(pair, dim=-1)
        # Makes sure the diagonal always has the same size
        last_dim += builtins.abs(offset)

    # Preserves original data, but places 1 at dim1 and moves the last dim to dim2
    t = t.unsqueeze(dim1).movedim(-1, dim2)

    # Generates ranges shifting indices based on offset
    a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
    b_range = torch.arange(
        offset, last_dim + offset, device=t.device, dtype=torch.int64
    )

    # Broadcast
    cond = a_range == b_range.unsqueeze(-1)
    cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
    cond = cond.reshape(cond_shape)

    # aten.diag_embed always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(cond, t).contiguous()

# CompositeImplicitAutograd - don't register decomp
def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
    if a.ndim < 3:
        raise RuntimeError(
            f"torch.dsplit requires a tensor with at least 3 dimensions, but got a tensor with {a.ndim} dimensions!"
        )
    if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
        raise RuntimeError(
            "torch.dsplit attempted to split along dimension 2, "
            + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
        )
    return tensor_split(a, sections, 2)

@register_decomposition(aten.t.default)
def t(a: TensorLikeType):
    # TODO: Add sparse support
    # if a.is_sparse:
    #     sparse_dim = a.sparse_dim()
    #     dense_dim = a.dense_dim()
    #     if not (sparse_dim <= 2 and dense_dim == 0):
    #         raise RuntimeError(
    #             f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and "
    #             f"{dense_dim} dense dimensions"
    #         )
    if a.ndim > 2:
        raise RuntimeError(
            f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
        )
    return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)

# CompositeImplicitAutograd - don't register decomp
def T(a: TensorLikeType) -> TensorLikeType:
    # n != 2 && n != 0 is deprecated in regular PyTorch.
    check(
        a.ndim in (0, 2),
        lambda: (
            "The use of `x.T` on tensors of dimension other than 0 or 2 "
            "to reverse their shape is not supported."
        ),
    )
    return a.t()


@register_decomposition(aten.alias)
def alias(a: TensorLikeType) -> TensorLikeType:
    return prims.view_of(a)

@register_decomposition(aten.transpose)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    if a.ndim <= 1 or dim0 == dim1:
        return aten.alias.default(a)

    _permutation = list(range(0, a.ndim))
    _permutation[_dim0] = _dim1
    _permutation[_dim1] = _dim0
    return torch.permute(a, _permutation)


# Aliases for transpose
swap_axes = transpose

@register_decomposition(aten.unfold)
def unfold(
    self: TensorLikeType, dimension: int, size: int, step: int
) -> TensorLikeType:
    shape, strides = _get_unfold_shape_stride(
        self.shape, self.stride(), dimension, size, step
    )
    return self.as_strided(shape, strides)


@register_decomposition(aten.unfold_copy)
@out_wrapper()
def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
    return self.unfold(dimension, size, step).clone(
        memory_format=torch.contiguous_format
    )

@register_decomposition(aten.cumsum)
def cumsum(
    a: TensorLikeType,
    dim: int,
    *,
    keepdim: bool = False,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # We implement all the kwargs of a reduction. ATen just handles dtype
    # nb. This decomposition may not be as efficient as a backend-specific implementation
    ndim = a.ndim
    dim = utils.canonicalize_dim(ndim, dim)
    if ndim == 0:
        return sum(a.unsqueeze(0), dim=0, keepdim=keepdim, dtype=dtype, out=out)
    a = a.unsqueeze(dim + 1)
    rg = torch.arange(a.shape[dim], device=a.device)
    mask = rg.unsqueeze(1) <= rg
    for _ in range(ndim - dim - 1):
        mask = mask.unsqueeze(-1)
    masked_a = utils.mask_tensor(mask, a)
    return sum(masked_a, dim=dim, keepdim=keepdim, dtype=dtype, out=out)

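# A minimal illustrative sketch, not part of the original module: it shows the
# masked-sum identity the decomposition above relies on for the 1-D case,
# cumsum(a)[j] = sum over i <= j of a[i], i.e. a triangular mask followed by a
# reduction. The helper name `_cumsum_example` is hypothetical.
def _cumsum_example() -> None:
    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    rg = torch.arange(3)
    mask = rg.unsqueeze(1) <= rg  # mask[i, j] = i <= j
    decomposed = (mask * a.unsqueeze(1)).sum(dim=0)
    assert torch.allclose(decomposed, torch.cumsum(a, 0))  # [1., 3., 6.]
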
# Note: although squeeze is documented as having the out= kwarg it doesn't
@register_decomposition(aten.unsqueeze)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    # Note that unsqueeze canonicalizes with rank + 1 because it allows
    # a new innermost dimension to be specified
    ndim = a.ndim + 1
    dim = utils.canonicalize_dim(ndim, dim)
    return prims.expand_dims(a, (dim,), ndim=ndim)

# NOTE: shape is a vararg because Tensor.view can be called as
# Tensor.view(a, b, c) or Tensor.view((a, b, c)). Function call torch.view
# doesn't support unpacked shapes
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
@register_decomposition(aten.view)
def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=False)


# CompositeImplicitAutograd - don't register decomp
def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.view(other.size())


# CompositeImplicitAutograd - don't register decomp
def ravel(a: TensorLikeType) -> TensorLikeType:
    return reshape(a, (-1,))

@register_decomposition(aten.empty.memory_format)
@out_wrapper()
def empty(
    *shape,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
    memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType:
    check(
        memory_format != torch.preserve_format,
        lambda: "torch.empty: the Preserve memory format is not supported",
    )

    shape = utils.extract_shape_from_varargs(shape)

    if memory_format == torch.contiguous_format:
        strides = utils.make_contiguous_strides_for(shape)
    elif memory_format == torch.channels_last_3d:
        strides = utils.make_channels_last_3d_strides_for(shape)
    else:  # memory_format == torch.channels_last
        check(
            memory_format == torch.channels_last,
            lambda: f"torch.empty: received an unknown memory format {memory_format}!",
        )
        strides = utils.make_channels_last_2d_strides_for(shape)

    return torch.empty_strided(
        shape,
        strides,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )

@register_decomposition(aten.new_empty)
def new_empty(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.empty(
        size,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )


@register_decomposition(aten.new_empty_strided)
def new_empty_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.Tensor.new_empty_strided
    """
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.empty_strided(
        size,
        stride,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )

@register_decomposition(aten.zeros.default)
@out_wrapper()
def zeros(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    return torch.full(
        size,
        False if dtype == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_zeros)
def new_zeros(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        # dtype was already resolved above, so it is never None here
        False if dtype == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )

@register_decomposition(aten.ones.default)
@out_wrapper()
def ones(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    return torch.full(
        size,
        True if dtype == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_ones)
def new_ones(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        # dtype was already resolved above, so it is never None here
        True if dtype == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )

@register_decomposition(aten.new_full)
def new_full(
    a: TensorLikeType,
    size: ShapeType,
    fill_value: Union[int, float, bool],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )

@register_decomposition(aten.empty_like)
def empty_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: Optional[torch.layout] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    strides: Tuple[int, ...]

    if memory_format != torch.preserve_format:
        return torch.empty(
            a.shape,
            dtype=dtype,
            layout=layout,
            device=device,
            requires_grad=requires_grad,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    # memory_format == torch.preserve_format
    strides = utils.compute_elementwise_output_strides(a)
    return torch.empty_strided(
        a.shape,
        strides,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )

@register_decomposition(aten.arange)
@out_wrapper()
def arange(
    start: NumberType = 0,
    end: Optional[NumberType] = None,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    device = torch.device(utils.device_or_default(device))

    assert not isinstance(start, complex)
    assert not isinstance(end, complex)
    assert not isinstance(step, complex)

    # Case: torch.arange(5)
    if end is None:
        end = start
        start = 0
    utils.check(step != 0, lambda: "step must be nonzero")
    utils.check(
        (step > 0 and end >= start) or (step < 0 and end <= start),
        lambda: "upper bound and lower bound inconsistent with step sign",
    )

    def is_finite(x):
        return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)

    utils.check(
        is_finite(start) and is_finite(end),
        lambda: f"unsupported range: {start} -> {end}",
    )
    utils.check(
        is_finite(step),
        lambda: f"step must be finite but got {step}",
    )

    if dtype is None:
        args = (start, end, step)
        integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
        dtype = torch.int64 if integer_args else torch.get_default_dtype()

    is_integer = utils.is_integer_dtype(dtype)
    if is_integer:
        xstart = sym_int(start)
        xend = sym_int(end)
        xstep = sym_int(step)

    # For int64 we truncate arguments to int before calculating length, but
    # for other integral dtypes we don't. Weird... but needed to match ATen shapes.
    if dtype == torch.int64:
        length = math.ceil((xend - xstart) / xstep)
    else:
        length = math.ceil((end - start) / step)

    if is_integer:
        return prims.iota(
            length,
            start=xstart,
            step=xstep,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

    computation_dtype = utils.get_acc_type(dtype, device)
    index = prims.iota(
        length,
        start=0,
        step=1,
        dtype=torch.int64,
        device=device,
        requires_grad=False,
    )
    index = _maybe_convert_to_dtype(index, computation_dtype)
    result = start + step * index
    result = _maybe_convert_to_dtype(result, dtype)

    if requires_grad:
        result.requires_grad_(True)
    return result

@register_decomposition(aten.lerp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("start", "end", "weight"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
    inputs = [start, end]
    if isinstance(weight, Number):
        weight = start.new_full((), weight)  # type: ignore[arg-type]
    else:
        inputs.append(weight)
        assert isinstance(weight, Tensor)  # mypy
    # We implement it this way for numerical stability. We assume (in the stability optimisation)
    # that 0 <= weight <= 1. We take the abs to deal with complex numbers.
    # We want to perform operations near zero, which is where floating points are most precise;
    # thus, we perform the following optimisation:
    # If weight.abs() >= 0.5:
    #     return (1 - weight) * (start - end) + end
    mask = weight.abs() >= 0.5
    coeff = torch.where(mask, weight - 1, weight)
    base = torch.where(mask, end, start)
    output = coeff * (end - start) + base
    # Makes sure the decomposition output's stride is the same as on the non-decomposition path
    stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
    if output.stride() != stride:
        return prims.copy_strided(output, stride)
    return output

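# A minimal illustrative sketch, not part of the original module: it verifies
# that the branch-free form above agrees with the textbook definition
# start + weight * (end - start) on both sides of the 0.5 cutoff. The helper
# name `_lerp_example` is hypothetical.
def _lerp_example() -> None:
    import torch

    start, end = torch.tensor(1.0), torch.tensor(5.0)
    for w in (0.25, 0.75):
        weight = torch.tensor(w)
        mask = weight.abs() >= 0.5
        coeff = torch.where(mask, weight - 1, weight)
        base = torch.where(mask, end, start)
        decomposed = coeff * (end - start) + base
        assert torch.allclose(decomposed, torch.lerp(start, end, weight))
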
@register_decomposition(aten.linspace)
@out_wrapper()
def linspace(
    start: NumberType,
    end: NumberType,
    steps: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    else:
        dtype = dtype or torch.get_default_dtype()
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    check(
        isinstance(steps, IntLike),
        lambda: "steps must be int, not float",
        exc_type=TypeError,
    )
    assert isinstance(steps, IntLike)  # for mypy
    check(steps >= 0, lambda: "number of steps must be non-negative")

    factory_kwargs = {
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "requires_grad": requires_grad,
    }
    if steps == 0:
        return torch.full((0,), 0, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
    if steps == 1:
        return torch.full((1,), start, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
    if start == end:
        return torch.full((steps,), start, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]

    # Performs the arange in int64 because some backends like ATen or Triton do not support all the dtypes
    rg = torch.arange(0, steps, **factory_kwargs)  # type: ignore[arg-type]

    # Small types need to be computed in higher precision as this is, at heart, an associative scan
    dtype_red = (
        torch.int64
        if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
        else dtype
    )
    computation_dtype, _ = utils.reduction_dtypes(
        rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
    )
    cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)

    # We implement torch.lerp without performing rg / (steps - 1) explicitly.
    # With this we get out[0] == start, out[-1] == end
    step = (end - start) / (steps - 1)
    out = torch.where(
        rg < steps / 2,
        start + step * cast_rg(rg),  # type: ignore[arg-type,operator]
        end - step * cast_rg((steps - 1) - rg),  # type: ignore[arg-type,operator]
    )
    return _maybe_convert_to_dtype(out, dtype)  # type: ignore[return-value]

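# A minimal illustrative sketch, not part of the original module: it shows the
# symmetric two-sided formula above -- points in the first half step forward
# from start, points in the second half step backward from end, so both
# endpoints are hit exactly. The helper name `_linspace_example` is
# hypothetical.
def _linspace_example() -> None:
    import torch

    start, end, steps = 0.0, 1.0, 5
    rg = torch.arange(0, steps)
    step = (end - start) / (steps - 1)
    out = torch.where(
        rg < steps / 2,
        start + step * rg,
        end - step * ((steps - 1) - rg),
    )
    assert torch.allclose(out, torch.linspace(start, end, steps))
    assert out[0].item() == start and out[-1].item() == end
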
@register_decomposition(aten.logspace)
@out_wrapper()
def logspace(
    start: NumberType,
    end: NumberType,
    steps: NumberType,
    base: NumberType = 10,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    if dtype is None:
        dtype = torch.get_default_dtype()

    # NB: NumPy doesn't have this cast
    if prims.utils.is_integer_dtype(dtype):
        if isinstance(start, FloatLike):
            start = sym_int(start)
        if isinstance(end, FloatLike):
            end = sym_int(end)

    assert not isinstance(base, complex)  # for mypy
    if base < 0:
        raise NotImplementedError
    ret = torch.linspace(
        start,
        end,
        steps,
        dtype=torch.float64,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)



@overload
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
    pass


@overload
def meshgrid(*tensors: TensorLikeType, indexing: str):
    pass


@register_decomposition(aten.meshgrid)
def meshgrid(
    *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
    indexing: str,
) -> List[TensorLikeType]:
    # This ref simultaneously handles two overloads (see stubs above)
    # The `indexing` argument is currently optional for torch.meshgrid, but we
    # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
    if isinstance(tensors[0], (list, tuple)):
        assert len(tensors) == 1
        tensors = tuple(tensors[0])

    check(
        py_all(isinstance(a, TensorLike) for a in tensors),
        lambda: "meshgrid expects its inputs to be tensors",
    )

    check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")

    for i in range(len(tensors) - 1):
        check(
            tensors[i].dtype == tensors[i + 1].dtype,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same dtype",
        )
        check(
            tensors[i].device == tensors[i + 1].device,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same device",
        )

    swap_first_and_second_tensors = False
    if indexing == "xy":
        swap_first_and_second_tensors = len(tensors) >= 2
        if swap_first_and_second_tensors:
            tensors = (tensors[1], tensors[0], *tensors[2:])
    else:
        check(
            indexing == "ij",
            lambda: (
                'torch.meshgrid: indexing must be one of "xy" or "ij", '
                f"but received: {indexing}"
            ),
        )

    result_shape: List[int] = []
    for t in tensors:
        assert isinstance(t, TensorLike)  # mypy
        check(
            t.ndim == 0 or t.ndim == 1,
            lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
        )
        result_shape.append(t.numel())

    grids: List[TensorLikeType] = []
    for i, t in enumerate(tensors):
        assert isinstance(t, TensorLike)  # mypy
        if t.ndim == 0:
            t = t.view((1,))
        grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))

    if swap_first_and_second_tensors:
        # Swap outputs if we originally swapped at the beginning
        grids[0], grids[1] = grids[1], grids[0]

    return grids
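

# Illustrative usage (shapes only; "ij" keeps the input order, while "xy"
# swaps the first two grids, matching NumPy):
#   a, b = torch.arange(3), torch.arange(4)
#   gi, gj = torch.meshgrid(a, b, indexing="ij")  # both have shape (3, 4)
#   gx, gy = torch.meshgrid(a, b, indexing="xy")  # both have shape (4, 3)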


# CompositeImplicitAutograd - don't register decomp
def movedim(
    input: TensorLikeType,
    source: Union[int, DimsSequenceType],
    destination: Union[int, DimsSequenceType],
) -> TensorLikeType:
    """
    Reference implementation of torch.movedim
    """
    if type(source) is int:
        source = (source,)
    if type(destination) is int:
        destination = (destination,)

    # Converts to list to produce an error message compatible with core
    # PyTorch, which prints sequences in square brackets.
    utils.check(
        len(source) == len(destination),  # type: ignore[arg-type]
        lambda: (
            "movedim: Invalid source or destination dims: source "  # type: ignore[arg-type]
            f"({list(source)} dims) should contain the same number "  # type: ignore[arg-type]
            f"of dims as destination ({list(destination)} dims)"  # type: ignore[arg-type]
        ),
    )

    rank = input.ndim
    ss = tuple(utils.canonicalize_dims(rank=rank, indices=source))  # type: ignore[arg-type]
    ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination))  # type: ignore[arg-type]

    sss = set(ss)
    dss = set(ds)

    # See above on why this converts to list in error messages.
    utils.check(
        len(ss) == len(sss),
        lambda: f"movedim: repeated dim in `source` ({list(source)})",  # type: ignore[arg-type]
    )
    utils.check(
        len(ds) == len(dss),
        lambda: f"movedim: repeated dim in `destination` ({list(destination)})",  # type: ignore[arg-type]
    )

    m = dict(zip(ds, ss))
    dims = []
    si = 0  # source index
    for di in range(rank):
        # check if the destination index is in the mapping
        s = m.get(di)
        if s is not None:
            # insert the source index if found
            dims.append(s)
        else:
            # insert source indices sequentially, skipping indices from the mapping
            while si in sss:
                si += 1
            dims.append(si)
            si += 1

    result = torch.permute(input, tuple(dims))
    return result
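

# Illustrative usage: moving dim 0 to the end computes the permutation
# (1, 2, 0) via the mapping logic above:
#   t = torch.empty(2, 3, 4)
#   torch.movedim(t, 0, -1).shape  # -> (3, 4, 2)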


# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
@register_decomposition(aten.empty_strided)
def empty_strided(
    shape: Union[ShapeType, Tuple[ShapeType]],
    strides: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    # Layout == strided, pin_memory is False
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    shape = utils.extract_shape_from_varargs(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device("cpu") if device is None else device

    return prims.empty_strided(
        shape,
        strides,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.eye)
@out_wrapper()
def eye(
    n: int,
    m: Optional[int] = None,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,  # TODO: unused
) -> TensorLikeType:
    """
    Reference implementation of torch.eye
    """
    if m is None:
        m = n

    check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
    check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")

    range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
    range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)

    cond = range_n.unsqueeze(-1) == range_m
    if dtype is torch.bool:
        return cond
    else:
        one = torch.ones(
            (1,),
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=False,
        )
        return torch.where(cond, one, 0)

    # TODO: Use requires_grad. All refs taking the requires_grad kwarg must
    # return a leaf tensor.
    # result.requires_grad_(requires_grad)
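

# Illustrative usage: the identity pattern comes from broadcasting the row
# arange against the column arange (entries where row index == column index
# compare equal):
#   torch.eye(2, 3)
#   # [[1., 0., 0.],
#   #  [0., 1., 0.]]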


@register_decomposition(aten.full)
@out_wrapper()
def full(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    device = device if device is not None else torch.device("cpu")

    e = empty(
        shape,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return torch.fill(e, fill_value)  # type: ignore[arg-type]


def full_like(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    e = torch.empty_like(
        a,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )
    return fill(e, fill_value)


@register_decomposition(aten.zeros_like)
def zeros_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    return torch.full_like(
        a,
        False if (dtype or a.dtype) == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )


@register_decomposition(aten.ones_like)
def ones_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    return torch.full_like(
        a,
        True if (dtype or a.dtype) == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )


@register_decomposition(aten.randn.default)
@out_wrapper()
def randn(
    *shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: Optional[torch.layout] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    utils.check_pin_memory(pin_memory)

    shape_ = utils.extract_shape_from_varargs(shape)

    dtype = utils.dtype_or_default(dtype)
    device = utils.device_or_default(device)

    return prims.normal(
        shape_,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


def scalar_tensor(
    a: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
    device = device if device is not None else torch.device("cpu")
    return prims.scalar_tensor(a, dtype=dtype, device=device)


#
# Randomness References
#


def _uniform_helper(
    shape: ShapeType,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
    *,
    dtype: torch.dtype,
    device: DeviceLikeType,
) -> TensorLikeType:
    utils.validate_shape(shape)

    assert isinstance(low, Number)
    assert isinstance(high, Number)
    low = sym_float(low)
    high = sym_float(high)

    assert isinstance(dtype, torch.dtype)
    device = utils.canonicalize_device(device)

    return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)


@register_decomposition(aten.masked_fill)
def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
    python_type = utils.dtype_to_type(a.dtype)
    if isinstance(value, Number):
        value_type = type(value)
    else:
        # NOTE: Could not use value = item(value) as it resulted in
        # RuntimeError: Cannot cast FakeTensor(cpu) to number
        value_ndim = value.ndim
        check(
            value_ndim == 0,
            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
        )
        # `masked_fill` allows a cpu scalar to be moved to cuda, but not the other way around.
        is_cpu_scalar = a.device.type == "cuda" and value.device.type == "cpu"
        check(
            is_cpu_scalar or value.device == a.device,
            lambda: "Expected `value` to be on same device as `a`",
        )
        value_type = utils.dtype_to_type(value.dtype)

    if value_type is complex:
        # Downcasting from complex to a lower type is the only disallowed case;
        # casting `value` to a lower type is allowed otherwise, e.g. float -> int.
        # Ref: https://github.com/pytorch/pytorch/issues/79195
        check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # Since `where` allows type-promotion,
    # cast `value` to the correct type before passing it to `where`
    value = _maybe_convert_to_dtype(value, a.dtype)
    r = torch.where(mask, value, a)  # type: ignore[arg-type]

    # aten.masked_fill always returns a new contiguous tensor;
    # contiguous() is needed to correctly model the output stride
    return r.contiguous()


@register_decomposition(aten.masked_fill_)
def masked_fill_(
    a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
) -> TensorLikeType:
    b = torch.masked_fill(a, mask, value)  # type: ignore[arg-type]
    a.copy_(b)
    return a
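

# Illustrative usage (the boolean mask selects the positions to overwrite):
#   a = torch.zeros(4)
#   mask = torch.tensor([True, False, True, False])
#   torch.masked_fill(a, mask, 1.0)  # -> 1.0, 0.0, 1.0, 0.0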


# CompositeImplicitAutograd - don't register decomp
def allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
    )


# TODO: add OpInfo for torch.equal and refs.equal
def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape check
    if a.ndim != b.ndim:
        return False

    for x, y in zip(a.shape, b.shape):
        if x != y:
            return False

    # Short-circuits if there are no elements to validate
    if a.numel() == 0:
        return True

    return item(all(eq(a, b)))  # type: ignore[return-value]


@register_decomposition(aten.norm)
@out_wrapper(exact_dtype=True)
def norm(
    input: TensorLikeType,
    p: Optional[Union[float, str]] = "fro",
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # In these cases we compute the "Frobenius norm"
    if (
        p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
    ) or p is None:
        p = 2

    if isinstance(dim, Dim):
        dim = [dim]

    if isinstance(p, str):
        # Here we either call the nuclear norm, or we call matrix_norm with some arguments
        # that will throw an error
        if dim is None:
            dim = tuple(range(input.ndim))
        return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
    else:
        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)


@register_decomposition(aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType:
    utils.check(
        self.ndim == 2,
        lambda: f"expected a matrix, but got tensor with dim {self.ndim}",
    )
    return torch.sum(torch.diag(self, 0))


def _make_r_binary_op(base_op):
    def rop(
        a: Union[TensorLikeType, NumberType],
        b: Union[TensorLikeType, NumberType],
    ) -> TensorLikeType:
        return base_op(b, a)

    return rop


rtruediv = _make_r_binary_op(true_divide)
rfloordiv = _make_r_binary_op(floor_divide)
rpow = _make_r_binary_op(pow)


@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    utils.check(
        a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) >= diagonal

    # aten.triu always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(mask, a).contiguous()


@register_decomposition(aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    utils.check(
        a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) <= diagonal

    # aten.tril always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(mask, a).contiguous()
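

# Illustrative mask construction for the triu/tril refs above: for a 3x3
# input, the broadcasted difference (col - row) is
#   [[ 0,  1,  2],
#    [-1,  0,  1],
#    [-2, -1,  0]]
# so (col - row) >= diagonal keeps the upper triangle and <= keeps the lower.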


# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
# The components of the matrix that belong to the lower triangle with offset
# form a pentagon that can be broken down into a top trapezoid and a bottom
# rectangle. For the implementation of tril_indices, we need the sizes of
# both of these, as well as the length of the top side of the trapezoid.
def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
    if row == 0 or col == 0:
        return 0, 0, 0

    m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
    m_last_row = max(0, min(col, row + offset))
    n_row_all = max(0, min(row, row + offset))
    n_row_trapezoid = m_last_row - m_first_row + 1

    # Number of elements in top trapezoid
    trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
    # Number of elements in bottom rectangle
    diff_row = n_row_all - n_row_trapezoid
    rectangle_size = max(0, diff_row * col)

    return trapezoid_size, rectangle_size, m_first_row
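

# Worked example for _get_tril_sizes (row=4, col=3, offset=0): rows of the
# lower triangle have 1, 2, 3, 3 elements. The first three rows form the
# trapezoid (size (1 + 3) * 3 // 2 == 6) and the last full row is the
# rectangle (size 1 * 3 == 3), with m_first_row == 1.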


def _trilu_checks(
    name: str,
    row: int,
    col: int,
    dtype: torch.dtype,
    layout: torch.layout,
    pin_memory: bool,
):
    check(row >= 0, lambda: f"row must be non-negative, got {row}")
    check(col >= 0, lambda: f"col must be non-negative, got {col}")
    check(
        dtype in (torch.int32, torch.int64),
        lambda: f"\"{name}\" not implemented for '{dtype}'",
    )


# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
@register_decomposition(aten.tril_indices)
def tril_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)

    trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
    row_offset = max(0, -offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # First we do the indices for the top trapezoid
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = m_first_row - 0.5
    row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
    col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
    row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)

    # Then the bottom rectangle
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
    col_inds2 = xs2 % col

    return torch.stack(
        (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
    )
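

# Illustrative usage: torch.tril_indices(3, 3) returns the 2 x 6 tensor
#   [[0, 1, 1, 2, 2, 2],
#    [0, 0, 1, 0, 1, 2]]
# i.e. the (row, col) coordinates of the lower triangle in row-major order.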


# Similar to _get_tril_sizes above, but here there is a top rectangle and
# a bottom trapezoid instead. Note that you can't reduce this to
# _get_tril_sizes(col, row, -offset) because that would correspond to
# decomposing into a left trapezoid and right rectangle.
def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
    if row == 0 or col == 0:
        return 0, 0, 0

    m_first_row = max(0, col - offset) if offset > 0 else col

    # Number of elements in top rectangle
    rectangle_size = max(0, min(row, -offset) * col)

    # Number of elements in bottom trapezoid
    trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
    triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
    trapezoid_size = triu_size - rectangle_size

    return trapezoid_size, rectangle_size, m_first_row


@register_decomposition(aten.triu_indices)
def triu_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)

    trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
    col_offset = max(0, offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # Indices for the top rectangle
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    row_inds2 = xs2 // col
    col_inds2 = xs2 % col

    # Then the bottom trapezoid
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = -0.5 - m_first_row
    row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
    col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
    row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)

    if col:
        row_inds1 = row_inds1 + (rectangle_size // col)
    col_inds1 = col_inds1 + col_offset

    return torch.stack(
        (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
    )


@register_decomposition(aten.bucketize)
@out_wrapper(exact_dtype=True)
def bucketize(
    a: TensorLikeType,
    boundaries: TensorLikeType,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    utils.check(
        boundaries.dim() == 1,
        lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
    )

    out_dtype = torch.int32 if out_int32 else torch.int64
    n_boundaries = boundaries.shape[-1]
    if n_boundaries == 0:
        return torch.zeros_like(a)
    # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
    # each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
    # but each step of the search is done "in parallel" over all elements of `a`.
    # We can't use int32 as indices, so we do all computations in int64 and convert at the end.
    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
    end = start + n_boundaries
    # Since we can't break out of the loop at different points for different elements of `a`,
    # we just run the max number of iterations that binary search requires and use a condition
    # tensor (cond_update below) to stop updating once the search terminates.

    # The first iteration through the loop can skip some checks, so it has a
    # separate implementation.
    mid = start + (end - start) // 2
    mid_val = boundaries[mid]
    if right:
        cond_mid = mid_val > a
    else:
        cond_mid = mid_val >= a
    start = torch.where(cond_mid, start, mid + 1)

    if n_boundaries > 1:
        cond_update = torch.ones_like(a, dtype=torch.bool)
        # Max depth of the binary search
        niters = int(math.log2(n_boundaries))
        for _ in range(niters):
            end = torch.where(cond_mid & cond_update, mid, end)
            cond_update = start < end
            # start might end up pointing 1 past the end, so we guard against that
            mid = torch.where(cond_update, start + (end - start) // 2, 0)
            mid_val = boundaries[mid]
            # If right is true, the buckets are closed on the *left*
            # (i.e., we are doing the equivalent of std::upper_bound in C++)
            # Otherwise they are closed on the right (std::lower_bound)
            if right:
                cond_mid = mid_val > a
            else:
                cond_mid = mid_val >= a
            start = torch.where((~cond_mid) & cond_update, mid + 1, start)

    return start.to(dtype=out_dtype)
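

# Illustrative usage (buckets are defined by consecutive boundaries; `right`
# flips which side of each bucket is closed):
#   boundaries = torch.tensor([1, 3, 5])
#   torch.bucketize(torch.tensor([0, 1, 4, 9]), boundaries)              # -> 0, 0, 2, 3
#   torch.bucketize(torch.tensor([0, 1, 4, 9]), boundaries, right=True)  # -> 0, 1, 2, 3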


@register_decomposition(aten.cauchy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def cauchy(self, median=0, sigma=1, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: "Cauchy distribution is a continuous probability distribution. "
        f"dtype must be a floating point but you specified {self.dtype}",
    )
    utils.check(
        sigma > 0.0,
        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
    )
    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))


@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: "Exponential distribution is a continuous probability distribution. "
        f"dtype must be a floating point but you specified {self.dtype}",
    )
    utils.check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )
    return -1 / rate * torch.log1p(-torch.rand_like(self))
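

# The exponential, geometric, and log_normal refs in this block all sample by
# transforming uniform (or normal) noise. For example, the exponential draw
# above is inverse-CDF sampling: if u ~ Uniform(0, 1), then
# -log(1 - u) / rate ~ Exponential(rate), and log1p(-u) computes log(1 - u)
# accurately for small u.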


@register_decomposition(aten.geometric)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def geometric(self, p, generator=None):
    assert generator is None
    # TODO: fix inductor rand_like for integer, bool dtypes
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"geometric not implemented for {self.dtype}",
    )
    utils.check(
        0 < p and p < 1,
        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
    )
    return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1


@register_decomposition(aten.log_normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def log_normal(self, mean=1, std=2, generator=None):
    assert generator is None
    utils.check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"log_normal not implemented for {self.dtype}",
    )
    utils.check(
        0 < std,
        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
    )
    return torch.exp(std * torch.randn_like(self) + mean)


# inplace
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)

# Views
# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function
# given that it does not reshape the input (it just copies the result into it)

# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)

import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special