common_methods_invocations.py 947 KB

430914310143111431214313143141431514316143171431814319143201432114322143231432414325143261432714328143291433014331143321433314334143351433614337143381433914340143411434214343143441434514346143471434814349143501435114352143531435414355143561435714358143591436014361143621436314364143651436614367143681436914370143711437214373143741437514376143771437814379143801438114382143831438414385143861438714388143891439014391143921439314394143951439614397143981439914400144011440214403144041440514406144071440814409144101441114412144131441414415144161441714418144191442014421144221442314424144251442614427144281442914430144311443214433144341443514436144371443814439144401444114442144431444414445144461444714448144491445014451144521445314454144551445614457144581445914460144611446214463144641446514466144671446814469144701447114472144731447414475144761447714478144791448014481144821448314484144851448614487144881448914490144911449214493144941449514496144971449814499145001450114502145031450414505145061450714508145091451014511145121451314514145151451614517145181451914520145211452214523145241452514526145271452814529145301453114532145331453414535145361453714538145391454014541145421454314544145451454614547145481454914550145511455214553145541455514556145571455814559145601456114562145631456414565145661456714568145691457014571145721457314574145751457614577145781457914580145811458214583145841458514586145871458814589145901459114592145931459414595145961459714598145991460014601146021460314604146051460614607146081460914610146111461214613146141461514616146171461814619146201462114622146231462414625146261462714628146291463014631146321463314634146351463614637146381463914640146411464214643146441464514646146471464814649146501465114652146531465414655146561465714658146591466014661146621466314664146651466614667146681466914670146711467214673146741467514676146771467814679146801468114682146831468414685146861468714688146891469014691146921469314694146951469614697146981469914700147011470214703147041470514706147071470814709147101471114712147131471414715147161471714718147191472014721147221472314724147251472614727147281472914730147311473214733147341473514736147371473814739147401474114742147431474414745147461474714748147491475014751147521475314754147551475614757147581475914760147611476214763147641476514766147671476814769147701477114772147731477414775147761477714778147791478014781147821478314784147851478614787147881478914790147911479214793147941479514796147971479814799148001480114802148031480414805148061480714808148091481014811148121481314814148151481614817148181481914820148211482214823148241482514826148271482814829148301483114832148331483414835148361483714838148391484014841148421484314844148451484614847148481484914850148511485214853148541485514856148571485814859148601486114862148631486414865148661486714868148691487014871148721487314874148751487614877148781487914880148811488214883148841488514886148871488814889148901489114892148931489414895148961489714898148991490014901149021490314904149051490614907149081490914910149111491214913149141491514916149171491814919149201492114922149231492414925149261492714928149291493014931149321493314934149351493614937149381493914940149411494214943149441494514946149471494814949149501495114952149531495414955149561495714958149591496014961149621496314964149651496614967149681496914970149711497214973149741497514976149771497814979149801498114982149831498414985149861498714988149891499014991149921499314994149951499614997149981499915000150011500215003150041500515006150071500815009150101501115012150131501415015150161501715018150191
502015021150221502315024150251502615027150281502915030150311503215033150341503515036150371503815039150401504115042150431504415045150461504715048150491505015051150521505315054150551505615057150581505915060150611506215063150641506515066150671506815069150701507115072150731507415075150761507715078150791508015081150821508315084150851508615087150881508915090150911509215093150941509515096150971509815099151001510115102151031510415105151061510715108151091511015111151121511315114151151511615117151181511915120151211512215123151241512515126151271512815129151301513115132151331513415135151361513715138151391514015141151421514315144151451514615147151481514915150151511515215153151541515515156151571515815159151601516115162151631516415165151661516715168151691517015171151721517315174151751517615177151781517915180151811518215183151841518515186151871518815189151901519115192151931519415195151961519715198151991520015201152021520315204152051520615207152081520915210152111521215213152141521515216152171521815219152201522115222152231522415225152261522715228152291523015231152321523315234152351523615237152381523915240152411524215243152441524515246152471524815249152501525115252152531525415255152561525715258152591526015261152621526315264152651526615267152681526915270152711527215273152741527515276152771527815279152801528115282152831528415285152861528715288152891529015291152921529315294152951529615297152981529915300153011530215303153041530515306153071530815309153101531115312153131531415315153161531715318153191532015321153221532315324153251532615327153281532915330153311533215333153341533515336153371533815339153401534115342153431534415345153461534715348153491535015351153521535315354153551535615357153581535915360153611536215363153641536515366153671536815369153701537115372153731537415375153761537715378153791538015381153821538315384153851538615387153881538915390153911539215393153941539515396153971539815399154001540115402154031540415405154061540715408154091541015411154121541315414154151541615417154181541915420154211542215423154241542515426154271542815429154301543115432154331543415435154361543715438154391544015441154421544315444154451544615447154481544915450154511545215453154541545515456154571545815459154601546115462154631546415465154661546715468154691547015471154721547315474154751547615477154781547915480154811548215483154841548515486154871548815489154901549115492154931549415495154961549715498154991550015501155021550315504155051550615507155081550915510155111551215513155141551515516155171551815519155201552115522155231552415525155261552715528155291553015531155321553315534155351553615537155381553915540155411554215543155441554515546155471554815549155501555115552155531555415555155561555715558155591556015561155621556315564155651556615567155681556915570155711557215573155741557515576155771557815579155801558115582155831558415585155861558715588155891559015591155921559315594155951559615597155981559915600156011560215603156041560515606156071560815609156101561115612156131561415615156161561715618156191562015621156221562315624156251562615627156281562915630156311563215633156341563515636156371563815639156401564115642156431564415645156461564715648156491565015651156521565315654156551565615657156581565915660156611566215663156641566515666156671566815669156701567115672156731567415675156761567715678156791568015681156821568315684156851568615687156881568915690156911569215693156941569515696156971569815699157001570115702157031570415705157061570715708157091571015711157121571315714157151571615717157181571915720157211572215723157241572515726157271572815729157301
573115732157331573415735157361573715738157391574015741157421574315744157451574615747157481574915750157511575215753157541575515756157571575815759157601576115762157631576415765157661576715768157691577015771157721577315774157751577615777157781577915780157811578215783157841578515786157871578815789157901579115792157931579415795157961579715798157991580015801158021580315804158051580615807158081580915810158111581215813158141581515816158171581815819158201582115822158231582415825158261582715828158291583015831158321583315834158351583615837158381583915840158411584215843158441584515846158471584815849158501585115852158531585415855158561585715858158591586015861158621586315864158651586615867158681586915870158711587215873158741587515876158771587815879158801588115882158831588415885158861588715888158891589015891158921589315894158951589615897158981589915900159011590215903159041590515906159071590815909159101591115912159131591415915159161591715918159191592015921159221592315924159251592615927159281592915930159311593215933159341593515936159371593815939159401594115942159431594415945159461594715948159491595015951159521595315954159551595615957159581595915960159611596215963159641596515966159671596815969159701597115972159731597415975159761597715978159791598015981159821598315984159851598615987159881598915990159911599215993159941599515996159971599815999160001600116002160031600416005160061600716008160091601016011160121601316014160151601616017160181601916020160211602216023160241602516026160271602816029160301603116032160331603416035160361603716038160391604016041160421604316044160451604616047160481604916050160511605216053160541605516056160571605816059160601606116062160631606416065160661606716068160691607016071160721607316074160751607616077160781607916080160811608216083160841608516086160871608816089160901609116092160931609416095160961609716098160991610016101161021610316104161051610616107161081610916110161111611216113161141611516116161171611816119161201612116122161231612416125161261612716128161291613016131161321613316134161351613616137161381613916140161411614216143161441614516146161471614816149161501615116152161531615416155161561615716158161591616016161161621616316164161651616616167161681616916170161711617216173161741617516176161771617816179161801618116182161831618416185161861618716188161891619016191161921619316194161951619616197161981619916200162011620216203162041620516206162071620816209162101621116212162131621416215162161621716218162191622016221162221622316224162251622616227162281622916230162311623216233162341623516236162371623816239162401624116242162431624416245162461624716248162491625016251162521625316254162551625616257162581625916260162611626216263162641626516266162671626816269162701627116272162731627416275162761627716278162791628016281162821628316284162851628616287162881628916290162911629216293162941629516296162971629816299163001630116302163031630416305163061630716308163091631016311163121631316314163151631616317163181631916320163211632216323163241632516326163271632816329163301633116332163331633416335163361633716338163391634016341163421634316344163451634616347163481634916350163511635216353163541635516356163571635816359163601636116362163631636416365163661636716368163691637016371163721637316374163751637616377163781637916380163811638216383163841638516386163871638816389163901639116392163931639416395163961639716398163991640016401164021640316404164051640616407164081640916410164111641216413164141641516416164171641816419164201642116422164231642416425164261642716428164291643016431164321643316434164351643616437164381643916440164411
644216443164441644516446164471644816449164501645116452164531645416455164561645716458164591646016461164621646316464164651646616467164681646916470164711647216473164741647516476164771647816479164801648116482164831648416485164861648716488164891649016491164921649316494164951649616497164981649916500165011650216503165041650516506165071650816509165101651116512165131651416515165161651716518165191652016521165221652316524165251652616527165281652916530165311653216533165341653516536165371653816539165401654116542165431654416545165461654716548165491655016551165521655316554165551655616557165581655916560165611656216563165641656516566165671656816569165701657116572165731657416575165761657716578165791658016581165821658316584165851658616587165881658916590165911659216593165941659516596165971659816599166001660116602166031660416605166061660716608166091661016611166121661316614166151661616617166181661916620166211662216623166241662516626166271662816629166301663116632166331663416635166361663716638166391664016641166421664316644166451664616647166481664916650166511665216653166541665516656166571665816659166601666116662166631666416665166661666716668166691667016671166721667316674166751667616677166781667916680166811668216683166841668516686166871668816689166901669116692166931669416695166961669716698166991670016701167021670316704167051670616707167081670916710167111671216713167141671516716167171671816719167201672116722167231672416725167261672716728167291673016731167321673316734167351673616737167381673916740167411674216743167441674516746167471674816749167501675116752167531675416755167561675716758167591676016761167621676316764167651676616767167681676916770167711677216773167741677516776167771677816779167801678116782167831678416785167861678716788167891679016791167921679316794167951679616797167981679916800168011680216803168041680516806168071680816809168101681116812168131681416815168161681716818168191682016821168221682316824168251682616827168281682916830168311683216833168341683516836168371683816839168401684116842168431684416845168461684716848168491685016851168521685316854168551685616857168581685916860168611686216863168641686516866168671686816869168701687116872168731687416875168761687716878168791688016881168821688316884168851688616887168881688916890168911689216893168941689516896168971689816899169001690116902169031690416905169061690716908169091691016911169121691316914169151691616917169181691916920169211692216923169241692516926169271692816929169301693116932169331693416935169361693716938169391694016941169421694316944169451694616947169481694916950169511695216953169541695516956169571695816959169601696116962169631696416965169661696716968169691697016971169721697316974169751697616977169781697916980169811698216983169841698516986169871698816989169901699116992169931699416995169961699716998169991700017001170021700317004170051700617007170081700917010170111701217013170141701517016170171701817019170201702117022170231702417025170261702717028170291703017031170321703317034170351703617037170381703917040170411704217043170441704517046170471704817049170501705117052170531705417055170561705717058170591706017061170621706317064170651706617067170681706917070170711707217073170741707517076170771707817079170801708117082170831708417085170861708717088170891709017091170921709317094170951709617097170981709917100171011710217103171041710517106171071710817109171101711117112171131711417115171161711717118171191712017121171221712317124171251712617127171281712917130171311713217133171341713517136171371713817139171401714117142171431714417145171461714717148171491715017151171521
715317154171551715617157171581715917160171611716217163171641716517166171671716817169171701717117172171731717417175171761717717178171791718017181171821718317184171851718617187171881718917190171911719217193171941719517196171971719817199172001720117202172031720417205172061720717208172091721017211172121721317214172151721617217172181721917220172211722217223172241722517226172271722817229172301723117232172331723417235172361723717238172391724017241172421724317244172451724617247172481724917250172511725217253172541725517256172571725817259172601726117262172631726417265172661726717268172691727017271172721727317274172751727617277172781727917280172811728217283172841728517286172871728817289172901729117292172931729417295172961729717298172991730017301173021730317304173051730617307173081730917310173111731217313173141731517316173171731817319173201732117322173231732417325173261732717328173291733017331173321733317334173351733617337173381733917340173411734217343173441734517346173471734817349173501735117352173531735417355173561735717358173591736017361173621736317364173651736617367173681736917370173711737217373173741737517376173771737817379173801738117382173831738417385173861738717388173891739017391173921739317394173951739617397173981739917400174011740217403174041740517406174071740817409174101741117412174131741417415174161741717418174191742017421174221742317424174251742617427174281742917430174311743217433174341743517436174371743817439174401744117442174431744417445174461744717448174491745017451174521745317454174551745617457174581745917460174611746217463174641746517466174671746817469174701747117472174731747417475174761747717478174791748017481174821748317484174851748617487174881748917490174911749217493174941749517496174971749817499175001750117502175031750417505175061750717508175091751017511175121751317514175151751617517175181751917520175211752217523175241752517526175271752817529175301753117532175331753417535175361753717538175391754017541175421754317544175451754617547175481754917550175511755217553175541755517556175571755817559175601756117562175631756417565175661756717568175691757017571175721757317574175751757617577175781757917580175811758217583175841758517586175871758817589175901759117592175931759417595175961759717598175991760017601176021760317604176051760617607176081760917610176111761217613176141761517616176171761817619176201762117622176231762417625176261762717628176291763017631176321763317634176351763617637176381763917640176411764217643176441764517646176471764817649176501765117652176531765417655176561765717658176591766017661176621766317664176651766617667176681766917670176711767217673176741767517676176771767817679176801768117682176831768417685176861768717688176891769017691176921769317694176951769617697176981769917700177011770217703177041770517706177071770817709177101771117712177131771417715177161771717718177191772017721177221772317724177251772617727177281772917730177311773217733177341773517736177371773817739177401774117742177431774417745177461774717748177491775017751177521775317754177551775617757177581775917760177611776217763177641776517766177671776817769177701777117772177731777417775177761777717778177791778017781177821778317784177851778617787177881778917790177911779217793177941779517796177971779817799178001780117802178031780417805178061780717808178091781017811178121781317814178151781617817178181781917820178211782217823178241782517826178271782817829178301783117832178331783417835178361783717838178391784017841178421784317844178451784617847178481784917850178511785217853178541785517856178571785817859178601786117862178631
786417865178661786717868178691787017871178721787317874178751787617877178781787917880178811788217883178841788517886178871788817889178901789117892178931789417895178961789717898178991790017901179021790317904179051790617907179081790917910179111791217913179141791517916179171791817919179201792117922179231792417925179261792717928179291793017931179321793317934179351793617937179381793917940179411794217943179441794517946179471794817949179501795117952179531795417955179561795717958179591796017961179621796317964179651796617967179681796917970179711797217973179741797517976179771797817979179801798117982179831798417985179861798717988179891799017991179921799317994179951799617997179981799918000180011800218003180041800518006180071800818009180101801118012180131801418015180161801718018180191802018021180221802318024180251802618027180281802918030180311803218033180341803518036180371803818039180401804118042180431804418045180461804718048180491805018051180521805318054180551805618057180581805918060180611806218063180641806518066180671806818069180701807118072180731807418075180761807718078180791808018081180821808318084180851808618087180881808918090180911809218093180941809518096180971809818099181001810118102181031810418105181061810718108181091811018111181121811318114181151811618117181181811918120181211812218123181241812518126181271812818129181301813118132181331813418135181361813718138181391814018141181421814318144181451814618147181481814918150181511815218153181541815518156181571815818159181601816118162181631816418165181661816718168181691817018171181721817318174181751817618177181781817918180181811818218183181841818518186181871818818189181901819118192181931819418195181961819718198181991820018201182021820318204182051820618207182081820918210182111821218213182141821518216182171821818219182201822118222182231822418225182261822718228182291823018231182321823318234182351823618237182381823918240182411824218243182441824518246182471824818249182501825118252182531825418255182561825718258182591826018261182621826318264182651826618267182681826918270182711827218273182741827518276182771827818279182801828118282182831828418285182861828718288182891829018291182921829318294182951829618297182981829918300183011830218303183041830518306183071830818309183101831118312183131831418315183161831718318183191832018321183221832318324183251832618327183281832918330183311833218333183341833518336183371833818339183401834118342183431834418345183461834718348183491835018351183521835318354183551835618357183581835918360183611836218363183641836518366183671836818369183701837118372183731837418375183761837718378183791838018381183821838318384183851838618387183881838918390183911839218393183941839518396183971839818399184001840118402184031840418405184061840718408184091841018411184121841318414184151841618417184181841918420184211842218423184241842518426184271842818429184301843118432184331843418435184361843718438184391844018441184421844318444184451844618447184481844918450184511845218453184541845518456184571845818459184601846118462184631846418465184661846718468184691847018471184721847318474184751847618477184781847918480184811848218483184841848518486184871848818489184901849118492184931849418495184961849718498184991850018501185021850318504185051850618507185081850918510185111851218513185141851518516185171851818519185201852118522185231852418525185261852718528185291853018531185321853318534185351853618537185381853918540185411854218543185441854518546185471854818549185501855118552185531855418555185561855718558185591856018561185621856318564185651856618567185681856918570185711857218573185741
857518576185771857818579185801858118582185831858418585185861858718588185891859018591185921859318594185951859618597185981859918600186011860218603186041860518606186071860818609186101861118612186131861418615186161861718618186191862018621186221862318624186251862618627186281862918630186311863218633186341863518636186371863818639186401864118642186431864418645186461864718648186491865018651186521865318654186551865618657186581865918660186611866218663186641866518666186671866818669186701867118672186731867418675186761867718678186791868018681186821868318684186851868618687186881868918690186911869218693186941869518696186971869818699187001870118702187031870418705187061870718708187091871018711187121871318714187151871618717187181871918720187211872218723187241872518726187271872818729187301873118732187331873418735187361873718738187391874018741187421874318744187451874618747187481874918750187511875218753187541875518756187571875818759187601876118762187631876418765187661876718768187691877018771187721877318774187751877618777187781877918780187811878218783187841878518786187871878818789187901879118792187931879418795187961879718798187991880018801188021880318804188051880618807188081880918810188111881218813188141881518816188171881818819188201882118822188231882418825188261882718828188291883018831188321883318834188351883618837188381883918840188411884218843188441884518846188471884818849188501885118852188531885418855188561885718858188591886018861188621886318864188651886618867188681886918870188711887218873188741887518876188771887818879188801888118882188831888418885188861888718888188891889018891188921889318894188951889618897188981889918900189011890218903189041890518906189071890818909189101891118912189131891418915189161891718918189191892018921189221892318924189251892618927189281892918930189311893218933189341893518936189371893818939189401894118942189431894418945189461894718948189491895018951189521895318954189551895618957189581895918960189611896218963189641896518966189671896818969189701897118972189731897418975189761897718978189791898018981189821898318984189851898618987189881898918990189911899218993189941899518996189971899818999190001900119002190031900419005190061900719008190091901019011190121901319014190151901619017190181901919020190211902219023190241902519026190271902819029190301903119032190331903419035190361903719038190391904019041190421904319044190451904619047190481904919050190511905219053190541905519056190571905819059190601906119062190631906419065190661906719068190691907019071190721907319074190751907619077190781907919080190811908219083190841908519086190871908819089190901909119092190931909419095190961909719098190991910019101191021910319104191051910619107191081910919110191111911219113191141911519116191171911819119191201912119122191231912419125191261912719128191291913019131191321913319134191351913619137191381913919140191411914219143191441914519146191471914819149191501915119152191531915419155191561915719158191591916019161191621916319164191651916619167191681916919170191711917219173191741917519176191771917819179191801918119182191831918419185191861918719188191891919019191191921919319194191951919619197191981919919200192011920219203192041920519206192071920819209192101921119212192131921419215192161921719218192191922019221192221922319224192251922619227192281922919230192311923219233192341923519236192371923819239192401924119242192431924419245192461924719248192491925019251192521925319254192551925619257192581925919260192611926219263192641926519266192671926819269192701927119272192731927419275192761927719278192791928019281192821928319284192851
928619287192881928919290192911929219293192941929519296192971929819299193001930119302193031930419305193061930719308193091931019311193121931319314193151931619317193181931919320193211932219323193241932519326193271932819329193301933119332193331933419335193361933719338193391934019341193421934319344193451934619347193481934919350193511935219353193541935519356193571935819359193601936119362193631936419365193661936719368193691937019371193721937319374193751937619377193781937919380193811938219383193841938519386193871938819389193901939119392193931939419395193961939719398193991940019401194021940319404194051940619407194081940919410194111941219413194141941519416194171941819419194201942119422194231942419425194261942719428194291943019431194321943319434194351943619437194381943919440194411944219443194441944519446194471944819449194501945119452194531945419455194561945719458194591946019461194621946319464194651946619467194681946919470194711947219473194741947519476194771947819479194801948119482194831948419485194861948719488194891949019491194921949319494194951949619497194981949919500195011950219503195041950519506195071950819509195101951119512195131951419515195161951719518195191952019521195221952319524195251952619527195281952919530195311953219533195341953519536195371953819539195401954119542195431954419545195461954719548195491955019551195521955319554195551955619557195581955919560195611956219563195641956519566195671956819569195701957119572195731957419575195761957719578195791958019581195821958319584195851958619587195881958919590195911959219593195941959519596195971959819599196001960119602196031960419605196061960719608196091961019611196121961319614196151961619617196181961919620196211962219623196241962519626196271962819629196301963119632196331963419635196361963719638196391964019641196421964319644196451964619647196481964919650196511965219653196541965519656196571965819659196601966119662196631966419665196661966719668196691967019671196721967319674196751967619677196781967919680196811968219683196841968519686196871968819689196901969119692196931969419695196961969719698196991970019701197021970319704197051970619707197081970919710197111971219713197141971519716197171971819719197201972119722197231972419725197261972719728197291973019731197321973319734197351973619737197381973919740197411974219743197441974519746197471974819749197501975119752197531975419755197561975719758197591976019761197621976319764197651976619767197681976919770197711977219773197741977519776197771977819779197801978119782197831978419785197861978719788197891979019791197921979319794197951979619797197981979919800198011980219803198041980519806198071980819809198101981119812198131981419815198161981719818198191982019821198221982319824198251982619827198281982919830198311983219833198341983519836198371983819839198401984119842198431984419845198461984719848198491985019851198521985319854198551985619857198581985919860198611986219863198641986519866198671986819869198701987119872198731987419875198761987719878198791988019881198821988319884198851988619887198881988919890198911989219893198941989519896198971989819899199001990119902199031990419905199061990719908199091991019911199121991319914199151991619917199181991919920199211992219923199241992519926199271992819929199301993119932199331993419935199361993719938199391994019941199421994319944199451994619947199481994919950199511995219953199541995519956199571995819959199601996119962199631996419965199661996719968199691997019971199721997319974199751997619977199781997919980199811998219983
from functools import wraps, partial
from itertools import product, chain, islice
import itertools
import functools
import copy
import operator
import random
import unittest
import math
import enum

import torch
import numpy as np
from torch import inf, nan
from typing import Any, Dict, List, Tuple, Union, Sequence

from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
    all_types, empty_types, complex_types_and, integral_types
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
    SM53OrLater, SM60OrLater, with_tf32_off, TEST_CUDNN,
    _get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_SDPA,
    SM80OrLater
)
from torch.testing._internal.common_utils import (
    make_fullrank_matrices_with_distinct_singular_values,
    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
    torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
    GRADCHECK_NONDET_TOL, freeze_rng_state, slowTest, TEST_WITH_SLOW
)

import torch._refs as refs  # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims  # noqa: F401

from torch.utils._pytree import tree_flatten
from distutils.version import LooseVersion
from torch.testing._internal.opinfo.core import (  # noqa: F401
    L,
    M,
    S,
    XS,
    _NOTHING,
    _getattr_qual,
    DecorateInfo,
    SampleInput,
    ErrorInput,
    AliasInfo,
    NumericsFilter,
    OpInfo,
    _generate_reduction_inputs,
    _generate_reduction_kwargs,
    sample_inputs_reduction,
    ReductionOpInfo,
    reference_inputs_elementwise_binary,
    make_error_inputs_elementwise_binary,
    generate_elementwise_binary_tensors,
    generate_elementwise_binary_arbitrarily_strided_tensors,
    generate_elementwise_binary_small_value_tensors,
    generate_elementwise_binary_large_value_tensors,
    generate_elementwise_binary_extremal_value_tensors,
    generate_elementwise_binary_broadcasting_tensors,
    generate_elementwise_binary_with_scalar_samples,
    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
    generate_elementwise_binary_noncontiguous_tensors,
    sample_inputs_elementwise_binary,
    BinaryUfuncInfo,
    sample_inputs_elementwise_unary,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
    reference_inputs_elementwise_unary,
    UnaryUfuncInfo,
    sample_inputs_spectral_ops,
    SpectralFuncType,
    SpectralFuncInfo,
    ShapeFuncInfo,
    sample_inputs_foreach,
    ForeachFuncInfo,
    gradcheck_wrapper_hermitian_input,
    gradcheck_wrapper_triangular_input,
    gradcheck_wrapper_triangular_input_real_positive_diagonal,
    gradcheck_wrapper_masked_operation,
    gradcheck_wrapper_masked_pointwise_operation,
    clone_sample,
)
from torch.testing._internal.opinfo.refs import (  # NOQA: F401
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
    ReductionPythonRefInfo,
    ElementwiseUnaryPythonRefInfo,
    ElementwiseBinaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
    np_unary_ufunc_integer_promotion_wrapper,
    reference_reduction_numpy,
    prod_numpy
)
from torch.testing._internal import opinfo
from torch.testing._internal.opinfo.definitions.linalg import (
    sample_inputs_linalg_cholesky,
    sample_inputs_linalg_cholesky_inverse,
    sample_inputs_cross,
    sample_inputs_linalg_qr_geqrf,
    sample_inputs_linalg_invertible,
    sample_inputs_lu_solve,
    sample_inputs_legacy_solve,
    sample_inputs_svd,
    sample_inputs_linalg_det_logdet_slogdet,
    sample_inputs_linalg_lu,
)
from torch.testing._internal.opinfo.definitions.special import (
    sample_inputs_i0_i1,
    sample_inputs_polygamma,
    reference_polygamma,
)
from torch.testing._internal.opinfo.definitions._masked import (
    sample_inputs_softmax_variant,
)

if TEST_SCIPY:
    from scipy import stats
    import scipy.spatial
    import scipy.special

# test if a tensor is close to an integer
def close_to_int(x, eps=0.1):
    if x.is_complex():
        y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x))))
    else:
        y = torch.abs(torch.frac(x))
    return (y < eps) | (y > (1 - eps))
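# Illustrative sketch (not exercised by the test suite): with the default eps=0.1,
#   close_to_int(torch.tensor([1.02, 2.5, 2.99]))
# is expected to yield tensor([True, False, True]), since only the fractional parts
# 0.02 and 0.99 lie within eps of an integer.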
def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    yield SampleInput(make_input(3), 0)
    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2)
    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3)
    yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2)
def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    args_cases = (
        # Cases with tensor indices.
        (torch.tensor([1, 2, 3]),),
        (torch.tensor(1),),
        (torch.tensor([1, 2, 3]), 1),
        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
        # Cases with list of indices.
        ((2, 4),),
        ((2, 4), 1),
        ((2, 4), -1),
        # Cases with integer section.
        (3,),
        (3, 1),
        (3, -1),
    )

    for args in args_cases:
        yield SampleInput(make_input((S, S, S)), args=args)
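# For reference (assumed torch.tensor_split behavior, shown only as a sketch): an integer
# argument splits into that many nearly equal chunks, e.g.
#   torch.tensor_split(torch.arange(8), 3)       ->  ([0, 1, 2], [3, 4, 5], [6, 7])
# while an index sequence splits at those positions, e.g.
#   torch.tensor_split(torch.arange(8), (2, 4))  ->  ([0, 1], [2, 3], [4, 5, 6, 7])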
def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6, S), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
    yield SampleInput(make_arg(S, S, 6), 2)
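# For reference (assumed torch behavior, sketch only): hsplit splits 1-D inputs along
# dim 0 and higher-dimensional inputs along dim 1, e.g.
#   torch.hsplit(torch.arange(6.), 2)  ->  (tensor([0., 1., 2.]), tensor([3., 4., 5.]))
# vsplit and dsplit behave analogously but require at least 2 and 3 dimensions
# respectively, which the error-input generators below exercise.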
def error_inputs_hsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, "
                "but got a tensor with 0 dimensions!")
    yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.hsplit attempted to split along dimension 1, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2)

    # Incorrect type for indices_or_section argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(
        SampleInput(make_arg((S, S, S)), "abc"),
        error_type=TypeError, error_regex=err_msg3)

def error_inputs_vsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.vsplit attempted to split along dimension 0, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0),
                     error_regex=err_msg2)

    # Incorrect type for indices_or_section argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"),
                     error_type=TypeError, error_regex=err_msg3)

def error_inputs_dsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.dsplit attempted to split along dimension 2, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2)
def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = (
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
    )

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        kwargs = dict(storage_offset=storage_offset)
        yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)
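# A minimal sketch of what one of these cases exercises (illustrative only):
#   torch.as_strided(torch.arange(9.).reshape(3, 3), (2, 2), (1, 2))
# is expected to view the same storage as tensor([[0., 2.], [1., 3.]]), and passing
# storage_offset=1 shifts every element of the view one slot further into the storage.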
def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
    def make_arg():
        base = make_tensor((20,), device=device, dtype=dtype)
        return base[5:15].requires_grad_(requires_grad)

    # as_strided on offset, partial views
    yield SampleInput(make_arg(), (2, 2), (1, 2))
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)
def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = [
        ((1,), (), (), 0),
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((3, 3), (2, 2), (2, 1), 0),
        # Scatter to larger dimensions
        ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
        # Scatter to larger dimensions with strides inverted
        ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
    ]

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        input_src = make_arg(output_shape)
        yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)
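# A minimal sketch of the scatter semantics being sampled above (illustrative only):
#   torch.as_strided_scatter(torch.zeros(4), torch.ones(2), (2,), (2,))
# is expected to return tensor([1., 0., 1., 0.]): a copy of the input in which the
# elements selected by the strided view (size/stride/offset) are replaced by src.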
def error_inputs_as_strided_scatter(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)

    # Create a small tensor and try to scatter it out of bounds
    input_t = make_arg([4, 4])
    input_src = make_arg([2, 2])
    # size (2, 2) with stride (200, 200) reaches linear index 400, i.e. 401 elements
    # (1604 bytes at itemsize 4), but the 4x4 float32 base only has 64 bytes of storage
    yield ErrorInput(
        SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0),
        error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
    )
def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
    inputs = (
        (0,),
        (0, 1),
        (0, 1, 2, 3),
    )

    rvals = [1, 2, 4]

    products = product(inputs, rvals, [False, True])

    for input_data, r, with_replacement in products:
        input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(input_t, r=r, with_replacement=with_replacement)
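# For reference (assumed torch.combinations behavior, sketch only):
#   torch.combinations(torch.tensor([0, 1, 2]), r=2)
# is expected to give tensor([[0, 1], [0, 2], [1, 2]]); with with_replacement=True the
# pairs (0, 0), (1, 1) and (2, 2) are included as well.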
def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # constructs 1-D tensors with varying number of elements
    a = make_arg((0,))
    b = make_arg((0, 1))
    c = make_arg((0, 1, 2, 3))

    # sample with only 1 tensor
    yield SampleInput(a)

    # sample with 2 tensors
    yield SampleInput(a, b)

    # sample with 3 tensors
    yield SampleInput(a, b, c)
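# For reference (assumed torch.cartesian_prod behavior, sketch only):
#   torch.cartesian_prod(torch.tensor([1, 2]), torch.tensor([3, 4]))
# is expected to give tensor([[1, 3], [1, 4], [2, 3], [2, 4]]), i.e. one row per element
# of the Cartesian product of the 1-D inputs.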
def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input_shape, dict of dim and eps
    cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
        ((S, S), {'dim': 1}),
        ((S, 2), {'dim': -1}),
        ((S,), {'dim': 0, 'eps': 0.5}),
        ((), {'dim': 0}),
        ((S, S, M), {'dim': 2}),
        ((S, S), {})
    )

    for input_shape, kwargs in cases:
        yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
    # Test for Broadcasting
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
    yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
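# A sketch of why the broadcasting cases above are interesting (illustrative only):
#   torch.nn.functional.cosine_similarity(torch.randn(1, 2, 3), torch.randn(2, 1, 3), dim=-1)
# is expected to broadcast the operands to (2, 2, 3) and reduce over dim=-1, giving a
# (2, 2) result.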
def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    # Ordered as: input shape, kwargs for training, momentum, eps
    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
        ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}),
        ((3, 2, 4), {'training': False, 'momentum': -1.2}),
        ((3, 1), {'training': True, 'momentum': 0.0}),
        ((0,), {'training': True}),
        ((0,), {'training': False}),
        ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}),
        ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}),
        ((2, 1), {}),
    )

    for input_shape, kwargs in cases:
        # args: running mean, running var, weight and bias should necessarily be of shape: (channels,)
        channels = input_shape[1] if len(input_shape) > 1 else 0
        weight = make_arg(channels) if channels > 0 else None
        bias = make_arg(channels) if channels > 0 else None
        running_mean = make_arg_without_requires_grad(channels, low=0)
        running_var = make_arg_without_requires_grad(channels, low=0)

        yield SampleInput(
            make_arg(input_shape),
            args=(
                running_mean,
                running_var,
                weight,
                bias
            ),
            kwargs=kwargs
        )

    # Checking for permutations of weights and biases as `None`
    weights = [channels, None, None]
    biases = [None, channels, None]
    is_training = [True, False, False]

    for weight, bias, training in zip(weights, biases, is_training):
        yield SampleInput(
            make_arg(input_shape),
            args=(
                running_mean,
                running_var,
                make_arg(channels),
                make_arg(channels)
            ),
            kwargs={'training': training}
        )

    # Test case for no optional kwargs
    # running_mean and running_var are required in evaluation mode (training: False) but not in training mode
    yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})
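# The positional args above are assumed to follow torch.nn.functional.batch_norm's
# signature (input, running_mean, running_var, weight, bias, training, momentum, eps);
# a minimal illustrative call would be
#   torch.nn.functional.batch_norm(torch.randn(3, 2, 4), torch.zeros(2), torch.ones(2), training=True)
# which returns a tensor of the same (3, 2, 4) shape.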
def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    cases = [
        ((S,), 0),
        ((S, S), 0),
        ((S, M, S), -1),
    ]
    input_dtypes = [dtype]
    if dtype == torch.float and device == 'cuda':
        input_dtypes += [torch.float16]

    for (shape, dim), input_dtype in product(cases, input_dtypes):
        input = make_arg(shape)
        output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype)
        yield SampleInput(make_arg(shape), output, dim, input_dtype)
def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        # Reorder the (running_mean, running_var, weight, bias) args produced by
        # sample_inputs_batch_norm into (weight, bias, running_mean, running_var, ...)
        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))

def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        if args[0] is not None and args[1] is not None:
            yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
        else:
            yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))
def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )
    for shape in cases:
        yield SampleInput(make_arg(shape))
def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs):
    op_kwargs = op_info.sample_kwargs(device, dtype, None)[0]
    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad,
                                               op_kwargs=op_kwargs)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )
    for shape in cases:
        for weight in [-1., 0., 0.8, 1.]:
            weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad)
            yield SampleInput(make_arg(shape), args=(weight_tensor,))

        channel_size = shape[1] if len(shape) >= 2 else 1
        yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),))

    weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,))
    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),))
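# For reference (assumed F.prelu behavior, sketch only): prelu keeps non-negative
# entries and scales negative ones by the weight, e.g.
#   torch.nn.functional.prelu(torch.tensor([-2., 3.]), torch.tensor(0.5))
# is expected to give tensor([-1., 3.]); a 1-D weight must have one entry per channel
# (input.shape[1]), which is what the channel_size branch above generates.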
def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs)
    yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs)

def sample_kwargs_prelu_scalar_weight(device, dtype, input):
    weight = torch.rand(tuple(), device=device, dtype=dtype)
    # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case
    if dtype == torch.bfloat16:
        weight_cpu = weight.to(dtype=torch.float32, device="cpu")
    else:
        weight_cpu = weight.cpu()
    np_weight = weight_cpu.numpy()
    return ({'weight': weight}, {'weight': np_weight})
def error_inputs_prelu(op, device):
    # Weight has numel != 1, but the input is a zero-dim tensor
    inp = make_tensor(tuple(), device=device, dtype=torch.float32)
    weight = make_tensor((2,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Not allow zero-dim input tensor.")

    # Weight has numel != 1, but numel does not match channel size
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((9,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Mismatch of parameter numbers and input channel size.")

    # Weight is neither a scalar nor 1-D tensor
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((2, 4), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2")

# src and index tensors must have the same # of dimensions
def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # ord = inf is tested in inputs_norm_inf as it fails on some tests
    cases = [
        ((S, S), (2,), '2'),
        ((S, S), (0,), '0'),
        ((S, S), (0.5,), '0_5'),
        ((S, S), (1,), '1'),
        ((S, S), (3,), '3'),
        ((S, S), (-1,), 'neg_1'),
        ((S, S), (-2,), 'neg_2'),
        ((S, S), (-0.5,), 'neg_0_5'),
        ((S, S), (-1.5,), 'neg_1_5'),
    ]

    cases_nonzero_input = (
        ((S, S, S), (1.5,), '1_5_default'),
        ((S, S, S), (1.5, 1), '1_5_dim'),
        ((S, S, S), (1.5, -1), '1_5_neg_dim'),
        ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'),
        ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'),
    )

    cases_posdim = (
        ((S, S), (-2, 1,), 'neg_2_dim'),
        ((S, S), (-1, 1,), 'neg_1_dim'),
        ((S, S), (0, 1,), '0_dim'),
        ((S, S), (1, 1,), '1_dim'),
        ((S, S), (2, 1,), '2_dim'),
        ((S, S), (3, 1,), '3_dim'),
        ((S, S, S), (2, 1), '2_dim'),
        ((S, S, S), (3, 1), '3_dim'),
        ((S, S, S), (2, 1, True), 'keepdim_2_dim'),
        ((S, S, S), (3, 1, True), 'keepdim_3_dim'),
        ((), (2, 0), '2_dim_scalar'),
        ((), (3, 0), '3_dim_scalar'),
        ((), (2, 0, True), 'keepdim_2_dim_scalar'),
        ((), (3, 0, True), 'keepdim_3_dim_scalar'),
    )

    cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim"))
                    for shape, args, name in cases_posdim)

    for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim):
        yield SampleInput(make_arg(shape), args=args, name=name)

    for shape, args, name in cases_nonzero_input:
        yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name)
  501. def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs):
  502. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  503. cases = (
  504. ((S, S), (), 'default'),
  505. ((S, S), ('fro',), 'fro_default'),
  506. ((S, S), ('fro', [0, 1],), 'fro'),
  507. )
  508. for shape, args, name in cases:
  509. yield SampleInput(make_arg(shape), args=args, name=name)
  510. def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs):
  511. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  512. cases = (
  513. ((S, S), ('nuc',), 'nuc'),
  514. ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'),
  515. )
  516. for shape, args, name in cases:
  517. yield SampleInput(make_arg(shape), args=args, name=name)
  518. def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs):
  519. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  520. cases = (
  521. ((S, S), (-inf,), '-inf'),
  522. ((S, S), (inf,), 'inf'),
  523. ((S, S), (inf, 1,), 'inf_2_dim'),
  524. ((S, S), (inf, -1,), 'inf_2_neg_dim'),
  525. )
  526. for shape, args, name in cases:
  527. yield SampleInput(make_arg(shape), args=args, name=name)
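# Note on the broadcasts_input flag used below and throughout this file: a sample is
# flagged when its first argument has to be broadcast to match the other operand, e.g.
# shape_lhs=(S, 1) vs. shape_rhs=(S,) broadcasts to (S, S) != (S, 1). In-place variants
# cannot accept such samples, so the flag lets those tests skip them or expect an error.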
def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shapes = (
        ((), ()),
        ((S,), ()),
        ((), (S,)),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, S), (S, S))
    )
    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)
        rhs = make_arg(shape_rhs)
        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)
        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
        if shape_lhs == shape_rhs:
            yield SampleInput(lhs, args=(lhs.clone().detach_(),))
def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )
    num_inputs = kwargs.get('num_inputs')
    sample_kwargs = kwargs.get('sample_kwargs', {})
    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)
        args = []
        for i in range(num_inputs - 1):
            args.append(make_arg(shape_rhs))
        broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))
        yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)
def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((S, 1), S),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )
    for shape in shapes:
        inp, *arg0 = shape
        yield SampleInput(inp, args=tuple(arg0))
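# Note: torch.add/torch.sub take an alpha keyword that scales the second operand,
# e.g. torch.add(lhs, rhs, alpha=2) == lhs + 2 * rhs, which is why the generator below
# appends alpha cases on top of the generic elementwise binary samples.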
def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
    # Adds alpha kwarg cases
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True})
    neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})
def error_inputs_arange(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
    yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')
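# Note: the (start, end, step) triples below exercise both directions and the
# default-argument forms, e.g. torch.arange(-1, 2, 2) -> tensor([-1, 1]) and
# torch.arange(3) -> tensor([0, 1, 2]); to_float() shifts the integer cases by 0.1
# to get equivalent floating-point ranges.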
def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
    int_samples = (
        # positive direction
        (-1, 2, 2),
        # negative direction
        (2, -3, -1),
        # start == end
        (1, 1, 1),
        (1, 1, -1),
        # divides evenly
        (0, -8, -4),
        (1, 5, 2),
        # bool
        (False, True, True),
        # default step
        (0, 1, None),
        # default start
        (None, 3, None),
    )
    def to_float(start, end, step):
        start = start + 0.1 if start is not None else None
        end = end + 0.1
        step = float(step) if step is not None else None
        return start, end, step
    float_samples = (
        # includes endpoint
        (0., -8. - 1e-6, -4.),
        (1., 5. + 1e-6, 2.),
        (0., -8., -4.),
        (1., 5., 2.),
        *(to_float(start, end, step) for (start, end, step) in int_samples),
    )
    large_samples = (
        (0, 10000, None),
    )
    samples = int_samples + float_samples
    if dtype not in (torch.int8, torch.uint8):
        samples += large_samples
    for start, end, step in samples:
        if start is None:
            assert step is None
            # Pass end as positional arg
            yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
            # (Similar to) calling torch.arange(end=3)
            yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
        elif step is None:
            yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device})
    yield SampleInput(2)
    yield SampleInput(1, args=(3, 1))
def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    shapes = (
        (M,),
        (S, S)
    )
    for shape in shapes:
        yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad))
def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.5),
        ((S, S), 0, 1),
        ((S, S, S), -2, 1),
    )
    for shape, median, gamma in samples:
        yield SampleInput(make_arg(shape), args=(median, gamma))
def error_inputs_cauchy(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_scale = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_scale,)),
        error_type=RuntimeError,
        error_regex=r"cauchy_ expects sigma > 0.0, but found sigma={}".format(invalid_scale),
    )
def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.5),
        ((S, S), 1),
        ((S, S, S), 1.5),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))
def error_inputs_exponential(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_rate = 0
    yield ErrorInput(
        SampleInput(t, args=(invalid_rate,)),
        error_type=RuntimeError,
        error_regex=r"exponential_ expects lambda > 0.0, but found lambda={}".format(invalid_rate),
    )
def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.2),
        ((S, S), 0.5),
        ((S, S, S), 0.8),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))
def error_inputs_geometric(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    neg_prob = -1
    yield ErrorInput(
        SampleInput(t, args=(neg_prob,)),
        error_type=RuntimeError,
        error_regex=r"geometric_ expects p to be in \(0, 1\), but got p={}".format(neg_prob),
    )
def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.25),
        ((S, S), 0.5, 1),
        ((S, S, S), 0, 0.5),
    )
    for shape, mean, std in samples:
        yield SampleInput(make_arg(shape), args=(mean, std))
def error_inputs_log_normal(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_std = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_std)),
        error_type=RuntimeError,
        error_regex=r"log_normal_ expects std > 0.0, but found std={}".format(invalid_std),
    )
def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), -100, 100),
        ((S, S), 0, 1),
        ((S, S, S), 1, 2),
    )
    for shape, lo, hi in samples:
        yield SampleInput(make_arg(shape), args=(lo, hi))
def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
    # this is a bit messy, as we want the args to be tuples
    # so if we pass size as a tuple, we have a tuple containing a tuple
    sizes = (
        (M,),
        (S, S),
    )
    for size in sizes:
        yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})
def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], dtype=dtype, device="cpu").item()
    sizes = (
        (M,),
        (S, S),
    )
    fill_values = [get_val(dtype), get_val(torch.int)]
    for size, fill_value in product(sizes, fill_values):
        yield SampleInput(size, fill_value, dtype=dtype, device=device)
def error_inputs_uniform(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    yield ErrorInput(
        SampleInput(t, args=(3, -1)),
        error_type=RuntimeError,
        error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1",
    )
def error_inputs_linspace(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative')
    yield ErrorInput(SampleInput(0, args=(3, 1.)), error_type=TypeError, error_regex='must be int, not float')
def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1, 4, 50)
    starts = (-2., 0, 4.3, 50)
    nsteps = (0, 1, 50)
    # Extra case to replicate off-by-one issue on CUDA
    cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)]
    for start, end, nstep in cases:
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
    yield SampleInput(1, args=(3, 1))
def sample_inputs_logpace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1.2, 2, 4)
    starts = (-2., 0, 1, 2, 4.3)
    nsteps = (0, 1, 2, 4)
    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
    for start, end, nstep, base in product(starts, ends, nsteps, bases):
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
            # https://github.com/pytorch/pytorch/issues/82242
            continue
        if base is None:
            yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device})
    yield SampleInput(1, args=(3, 1, 2.))
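# Note: torch.isclose(a, b, rtol, atol, equal_nan) checks |a - b| <= atol + rtol * |b|
# elementwise, so the extra samples below sweep rtol/atol (including 0) and both
# equal_nan settings on top of the generic elementwise binary samples.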
def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
    # Creates additional inputs to test the rtol, atol, and equal_nan params
    rtols = [0., 1e-7]
    atols = [0., 1e-7]
    equal_nans = [False, True]
    products = product(rtols, atols, equal_nans)
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for rtol, atol, equal_nan in products:
        lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
        rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
        yield SampleInput(lhs, args=(rhs,),
                          kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))
def error_inputs_isclose(op, device, **kwargs):
    make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='rtol must be greater than or equal to zero')
    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='atol must be greater than or equal to zero')
def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg((1, 2)))
    yield SampleInput(make_arg((2,)))
    yield SampleInput(make_arg(()))
def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)
    first_shape, second_shape = (S, M), (M, S)
    yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),))
    if dtype.is_complex:
        yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),))
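# Note: torch.addmm(input, mat1, mat2, beta=b, alpha=a) computes b * input + a * (mat1 @ mat2);
# the broadcasting cases below feed a (1,) or scalar input that must broadcast against the
# (2, 3) matrix product.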
def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
    beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2)
    tests_list = [
        ((2, 3), (2, 2), (2, 3), False)
    ]
    tests_with_lhs_broadcasting = [
        ((1,), (2, 2), (2, 3), True),
        ((), (2, 2), (2, 3), True)
    ]
    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
    kwargs = dict(alpha=alpha_val, beta=beta_val)
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
        yield SampleInput(
            make_arg(shape_a),
            make_arg(shape_b),
            make_arg(shape_c),
            **kwargs,
        ).with_metadata(broadcasts_input=broadcasts_input)
    if dtype.is_complex:
        shape = (3, 3)
        yield SampleInput(
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            make_arg(shape),
            **kwargs,
        )
        yield SampleInput(
            make_arg(shape),
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            **kwargs,
        )
def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha = 2 + 3j if dtype.is_complex else 0.6
    beta = 1 + 2j if dtype.is_complex else 0.2
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C
    for m, n, k in itertools.product([0, 5], repeat=3):
        yield SampleInput(
            torch.eye(m, n, device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            make_arg((k, n)),
            alpha=alpha,
            beta=beta,
        )
def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    reductions = ["sum", "mean", "amax", "amin"]
    for m, k, reduce in product([5, 7], [3, 11], reductions):
        yield SampleInput(
            torch.eye(m, m)
            .to(device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            reduce,
        )
def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, M), make_arg(M))
def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(M, S, M), make_arg(M, M, S))
def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)
    yield SampleInput(make_arg((S, )), make_arg((S, )))
    if dtype.is_complex:
        # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor)
        # is tested in test_conj_view (which tests operations with only conjugated input tensor
        # -- not conjugated arg tensors)
        yield SampleInput(make_arg((S, )), make_arg_conj((S, )))
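# Note: torch.addmv(input, mat, vec, beta=b, alpha=a) computes b * input + a * (mat @ vec),
# e.g. beta=0.2, alpha=0.6 in the cases below.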
def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    test_cases = (((S,), (S, M), (M,), 1, 1, False),
                  ((S,), (S, M), (M,), 0.2, 0.6, False),
                  )
    test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True),
                                 ((1,), (S, M), (M,), 0.2, 0.6, True),
                                 ((), (S, M), (M,), 1, 1, True),
                                 ((), (S, M), (M,), 0.2, 0.6, True),
                                 )
    cases = test_cases + test_cases_with_broadcast
    # addmv performs: beta * M + alpha * (mat @ vec)
    for size, mat, vec, beta, alpha, broadcasts_input in cases:
        yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input)
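# Note: torch.addbmm reduces over the batch, i.e. it computes
# beta * input + alpha * sum_i(batch1[i] @ batch2[i]), which is why a 2-D (S, M) input
# pairs with 3-D (S, S, S) and (S, S, M) batches below.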
def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting
    test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]
    for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases:
        if dtype.is_complex:
            beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j)
            yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                              kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting)
        yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting)
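# Note: addcmul/addcdiv compute input + value * tensor1 * tensor2 and
# input + value * tensor1 / tensor2 respectively, so the shared generator below only
# varies the shapes and the scalar `value`.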
def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    test_cases = [(((S, S), (S, S), (S, S)), False),
                  (((S, S), (S, 1), (1, S)), False),
                  (((1,), (S, S, 1), (1, S)), True),
                  (((), (), ()), False),
                  (((S, S), (), ()), True),
                  (((), (S, S, 1), (1, S)), True)
                  ]
    for input_args, broadcasts_input in test_cases:
        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input)
        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(
            *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3
        ).with_metadata(broadcasts_input=broadcasts_input)
def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_addcmul_addcdiv(
        op_info, device, dtype, requires_grad, **kwargs)
    # type promotion cases
    supported_dtypes = op_info.supported_dtypes(device)
    make_arg = partial(make_tensor, device=device, requires_grad=requires_grad)
    types = (
        (torch.float64, torch.complex128),
        (torch.bfloat16, torch.float32),
    )
    values = (
        None,
        True, False,
        3.14, 3,
        1.0, 1,
        0.0, 0,
        -3.14, -3,
        3.14 + 2.71j,
    )
    for (type2, type3), value in product(types, values):
        if (type2 not in supported_dtypes or
                type3 not in supported_dtypes):
            continue
        # RuntimeError: value cannot be converted without overflow
        if (type(value) is complex and
                type2 is not torch.complex128):
            continue
        arg1 = make_arg([5, 5], dtype=dtype)
        arg2 = make_arg([5, 5], dtype=type2)
        arg3 = make_arg([1, 5], dtype=type3)
        # TypeError: addcdiv(): argument 'value' must be Number, not NoneType
        if value is not None:
            yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value))
        else:
            yield SampleInput(arg1, args=(arg2, arg3))
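# Note: unlike addbmm, torch.baddbmm keeps the batch dimension and computes
# beta * input + alpha * (batch1 @ batch2) per batch, hence the 3-D (S, S, M) inputs below.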
def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases:
        yield SampleInput(
            make_arg(input_shape),
            make_arg(batch1_shape),
            make_arg(batch2_shape),
            beta=beta,
            alpha=alpha
        ).with_metadata(broadcasts_input=broadcasts_input)
        if dtype.is_complex:
            yield SampleInput(
                make_arg(input_shape),
                make_arg(batch1_shape),
                make_arg(batch2_shape),
                beta=beta * (1 + 2j),
                alpha=alpha * (2 + 3j),
            ).with_metadata(broadcasts_input=broadcasts_input)
    if dtype.is_complex:
        shapes = [(S, S, S), (S, M, S), (S, S, M)]
        args = tuple(make_arg(s) for s in shapes)
        yield SampleInput(
            args[0].transpose_(-1, 1),
            args[1].transpose(-1, 1).conj().requires_grad_(requires_grad),
            args[2].transpose(-1, 1).conj().requires_grad_(requires_grad),
            beta=beta * (1 + 2j),
            alpha=alpha * (2 + 3j),
        )
# TODO: add reduction kwargs
def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shapes = (
        (S,),
        (S, S),
    )
    for shape in shapes:
        # Produce one with weight and one without.
        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={})
        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),),
                          kwargs={'weight': _make_tensor(shape, requires_grad=False)})
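# Note: torch.addr(input, vec1, vec2, beta=b, alpha=a) computes
# b * input + a * torch.outer(vec1, vec2), so input broadcasts against an (S, M)
# outer product in the samples below.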
def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None
    )
    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M))
    yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True)
    if dtype.is_complex:
        alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
    elif dtype.is_floating_point:
        alpha, beta = 0.2, 0.6
    else:
        alpha, beta = 2, 3
    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha)
    yield SampleInput(
        make_arg(),
        make_arg(S),
        make_arg(M),
        beta=beta,
        alpha=alpha,
    ).with_metadata(broadcasts_input=True)
    # These samples fail gradcheck
    if dtype.is_floating_point and not requires_grad:
        tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(
            torch.tensor([[math.nan]], **tensor_options),
            torch.tensor([0.0], **tensor_options),
            torch.tensor([0.0], **tensor_options),
            beta=0.0,
            alpha=0.0,
        ).with_metadata(broadcasts_input=True)
        yield SampleInput(
            torch.tensor([[0.0]], **tensor_options),
            torch.tensor([math.nan], **tensor_options),
            torch.tensor([math.nan], **tensor_options),
            beta=0.0,
            alpha=0.0,
        ).with_metadata(broadcasts_input=True)
def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    cases = ((), (S, S, S), (S,))
    for shape in cases:
        yield SampleInput(make_arg(shape))
# TODO: add reduction kwargs
def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
    inputs = (
        ((), make_target([], low=0, high=1), {}),
        ((S,), make_target([], low=0, high=S), {"p": 1}),
        ((S,), make_target([1], low=0, high=S), {"p": 2}),
        ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
        ((M, S), make_target([M], low=0, high=S), {"weight": None}),
    )
    for input_shape, target, kwargs in inputs:
        yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
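# Note: logsumexp(x, dim) = log(sum(exp(x), dim)) and is expected to be evaluated in a
# numerically stable way (e.g. by subtracting the max before exponentiating), which is
# what the large-magnitude samples below probe.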
def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
    inputs = (
        ((), (0,), True),
        ((S, S), (1,), True),
        ((S, S), (1,), False),
        ((S, S), (-2,), False),
        ((S, S), (0, 1), False),
    )
    # Test large inputs to check numerical stability
    lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64) else (None,)
    for low in lows:
        high = low * 2 if low is not None else None
        for shape, dim, keepdim in inputs:
            t = make_tensor(shape, dtype=dtype, device=device,
                            low=low, high=high,
                            requires_grad=requires_grad)
            yield SampleInput(t, dim, keepdim)
def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs)
    # https://github.com/pytorch/pytorch/issues/91843
    t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(t, 0, False)
    t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(t, 0, False)
    # tests masking
    # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073
    t = torch.tensor(float("inf"))
    yield SampleInput(t, 0, True)
def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
    inputs = [
        ((), {}),
        ((S, S), {}),
        ((0, S, 0), {}),
        ((S,), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), {'dtype': torch.double}),
        ((S,), {'device': 'cpu'}),
        ((S,), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), {'device': 'cuda'}))
    for shape, kwargs in inputs:
        t = make_tensor(shape, dtype=dtype, device=device,
                        low=None, high=None,
                        requires_grad=requires_grad)
        yield SampleInput(t, **kwargs)
def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs)
    # shape
    cases = (
        (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1)
    )
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape in cases:
        yield SampleInput(make_arg(shape))
        yield SampleInput(make_arg(shape).transpose(0, -1))
        yield SampleInput(make_arg(shape, noncontiguous=True))
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))
# TODO: add reduction kwargs
def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
    inputs = (
        ([], make_target([], low=0, high=1)),
        ([S], make_target([S], low=0, high=S)),
        ([M, S], make_target([M, S], low=0, high=S)),
    )
    for shape, target in inputs:
        yield SampleInput(_make_tensor(shape), args=(target,))
def get_independent_tensor(tensor):
    return tensor.clone().requires_grad_(tensor.requires_grad)
def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
    low = 2
    high = 10
    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
        # With high
        yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)
        # With low and high
        yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)
def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
    low = 2
    high = 10
    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
        # With high
        yield SampleInput(
            sample.input,
            high,
            *sample.args,
            **sample.kwargs)
        # With low and high
        yield SampleInput(
            get_independent_tensor(sample.input),
            low,
            high,
            *sample.args,
            **sample.kwargs)
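# Note: margin_ranking_loss(input1, input2, target, margin) computes
# max(0, -target * (input1 - input2) + margin) per element, so the samples pair
# equally-shaped input1/input2/target tensors with margin and reduction sweeps.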
def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shapes = (
        (),
        (S,),
        (S, S),
        (S, S, S),
    )
    margins = (0., 1.)
    reductions = ('sum', 'mean', 'none')
    for shape in shapes:
        for margin, reduction in product(margins, reductions):
            kwargs = {'margin': margin, 'reduction': reduction}
            yield SampleInput(_make_tensor(shape),
                              args=(_make_tensor(shape, requires_grad=False),
                                    _make_tensor(shape, requires_grad=False)),
                              kwargs=kwargs)
def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs)
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for reduction in ('sum', 'mean', 'none'):
        if dtype.is_floating_point:  # only supports ints and floats
            # NaN propagation
            inp1 = make_input((10, ))
            inp1[2] = float('nan')
            inp2 = make_input((10, ))
            inp2[4] = float('nan')
            target = make_input((10, ))
            target[9] = float('nan')
            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
            # Inf handling
            inp1 = make_input((10, ))
            inp1[1] = float('inf')
            inp2 = make_input((10, ))
            inp2[4] = float('inf')
            target = make_input((10, ))
            target[7] = float('inf')
            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
            # Broadcasting
            inp1 = make_input((5, 2))
            inp2 = make_input((5, 1))
            target = make_input((1, 2))
            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
def error_inputs_margin_ranking_loss(op, device, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)
    # invalid reduction value.
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}),
                     error_type=ValueError, error_regex='is not a valid value')
    # invalid input shapes
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)),
                     error_regex='margin_ranking_loss : All input tensors should')
def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs):
    # input_shape, output_shape, strides, kwargs
    # lengths of output_shape and strides must be equal
    inputs = [
        ((), (), (), {}),
        ((S, S), (2, 0), (3, 4), {}),
        ((0, S, 0), (3, 2, 2), (1, 2, 3), {}),
        ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), (10,), (S,), {'dtype': torch.double}),
        ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}),
        ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'}))
    for input_shape, output_shape, strides, kwargs in inputs:
        t = make_tensor(input_shape, dtype=dtype, device=device,
                        low=None, high=None,
                        requires_grad=requires_grad)
        if is_strided:
            yield SampleInput(t, output_shape, strides, **kwargs)
        else:
            yield SampleInput(t, output_shape, **kwargs)
def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
    # shape
    cases = (
        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
    )
    for case in cases:
        yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)
def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
    # Not including a scalar tensor in vals because meta tests start failing due to
    # lack of meta support for _local_scalar_dense
    # torch.tensor(2, device=device)
    vals = (-5, 0, 1)
    for item in vals:
        yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad)
def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs):
    # only ints >= 0 are allowed for both arguments, unless m is omitted
    sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S)
    for n, m in product(sizes, sizes):
        if n is None:
            continue
        # TODO: no layout
        _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad}
        if m is None:
            yield SampleInput(n, args=(), kwargs=_kwargs)
        else:
            yield SampleInput(n, args=(m,), kwargs=_kwargs)
def error_inputs_eye(op_info, device, **kwargs):
    # TODO: no layout
    _kwargs = {'device': device, 'dtype': torch.float32}
    yield ErrorInput(
        SampleInput(-1, args=(), kwargs=_kwargs),
        error_regex="n must be greater or equal to 0, got -1"
    )
    yield ErrorInput(
        SampleInput(-7, args=(42,), kwargs=_kwargs),
        error_regex="n must be greater or equal to 0, got -7"
    )
    yield ErrorInput(
        SampleInput(0, args=(-3,), kwargs=_kwargs),
        error_regex="m must be greater or equal to 0, got -3"
    )
def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], dtype=dtype, device="cpu").item()
    for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
        # The scalar we are passing to new_full must be the same dtype
        # as the one of the resulting tensor
        use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
        yield SampleInput(
            sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], dtype=dtype, device="cpu").item()
    inputs = [
        ((), get_val(dtype), {}),
        ((S, S), get_val(dtype), {}),
        ((0, S, 0), get_val(dtype), {}),
        ((S,), get_val(dtype), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), get_val(torch.double), {'dtype': torch.double}),
        ((S,), get_val(dtype), {'device': 'cpu'}),
        ((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), get_val(dtype), {'device': 'cuda'}))
    for shape, fill_value, kwargs in inputs:
        t = make_tensor(shape, dtype=dtype, device=device,
                        low=None, high=None,
                        requires_grad=requires_grad)
        yield SampleInput(t, fill_value, **kwargs)
def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs):
    cases = [
        ([3], 3, {}),
        ([10], 3, {}),
        ([3, 10], 3, {}),
        ([3], 3, dict(replacement=False)),
        ([3], 3, dict(replacement=True)),
        ([3, 4], 4, dict(replacement=True)),
        ([3, 4], 4, dict(replacement=False)),
    ]
    for shape, num_samples, kwargs in cases:
        t = make_tensor(shape, dtype=dtype, device=device,
                        low=0, high=None,
                        requires_grad=requires_grad)
        yield SampleInput(t, num_samples, **kwargs)
def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs):
    def get_value_or_make_tensor(value_or_shape):
        if isinstance(value_or_shape, list):
            return make_tensor(value_or_shape, dtype=dtype, device=device,
                               low=0, high=None,
                               requires_grad=requires_grad)
        return value_or_shape
    for value_or_mean_shape, value_or_std_shape, kwargs in cases:
        mean = get_value_or_make_tensor(value_or_mean_shape)
        std = get_value_or_make_tensor(value_or_std_shape)
        yield SampleInput(mean, std, **kwargs)
def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs):
    # value_or_size, value_or_size, kwargs
    cases = [
        ([], [], {}),
        ([3], [3], {}),
        ([3, 4, 2], [3, 4, 2], {}),
        ([2, 3], 1.1, {}),
        ([1, 2, 3], [5, 2, 3], {}),  # broadcasting
    ]
    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
    cases = [
        ([3, 4], 0.3, {}),
    ]
    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
    shapes = [
        [3],
        [],
        [0, 3],
        [2, 3, 4],
    ]
    for shape in shapes:
        t = make_tensor(shape, dtype=dtype, device=device,
                        low=0, high=1,
                        requires_grad=requires_grad)
        yield SampleInput(t)
def error_inputs_bernoulli(op_info, device, **kwargs):
    # more than one element of the written-to tensor refers to a single memory location
    x = torch.rand((1,), device=device).expand((6,))
    err_msg = 'unsupported operation'
    yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}),
                     error_regex=err_msg)
def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs):
    inputs = (
        ((S, S, S), 0),
        ((S, S, S), 1),
        ((), 0),
    )
    for large_number in (True, False):
        for shape, dim in inputs:
            t = make_tensor(shape, dtype=dtype, device=device,
                            low=None, high=None,
                            requires_grad=requires_grad)
            if large_number and t.dim() > 0:
                t[0] = 10000
            yield SampleInput(t, dim)
def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
    yield SampleInput((make_tensor((S, S), dtype=dtype, device=device,
                                   low=None, high=None,
                                   requires_grad=requires_grad)))
def error_inputs_trace(op, device):
    yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")
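# Note: torch.renorm(input, p, dim, maxnorm) rescales each slice along `dim` whose
# p-norm exceeds maxnorm so that its norm equals maxnorm; the (p, dim, maxnorm) triples
# below exercise that, including p=inf.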
def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    cases = (((S, S, S), (2, 1, 0.5)),
             ((S, S, S), (2, -1, 0.5)),
             ((S, S, S), (1, 2, 3)),
             ((S, S, S), (float('inf'), 2, 0.5)),
             )
    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)
def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    cases = (((1, 2, 3), (-1, -2)),
             ((1, 2, 3), (-1, 2)),
             ((1, 2, 3), (1, -2)),
             ((1, 2, 3), (1, 2)),
             ((), (0, 0)),
             ((1, ), (0, 0)),
             ((M, M), (0, 1)),
             ((S, S, S), (2, 0)), )
    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)
def _numpy_ref_transpose(a, dim0, dim1):
    if a.ndim <= 1:
        return a
    return np.swapaxes(a, dim0, dim1)
def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
    return (SampleInput(make_arg(shape)) for shape in shapes)
def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    shapes = ((M, M), (M, L))
    return (SampleInput(make_arg(shape)) for shape in shapes)
def error_inputs_T(self, device, has_ndims_error=False):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # Deprecated behavior in regular PyTorch, but throws an error in primTorch:
    # https://github.com/pytorch/pytorch/issues/86968
    if has_ndims_error:
        # ndims == 1
        yield ErrorInput(SampleInput(make_arg(M)),
                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
                                      r'to reverse their shape is not supported\.'))
        # ndims > 2
        yield ErrorInput(SampleInput(make_arg(M, S, L)),
                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
                                      r'to reverse their shape is not supported\.'))
def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False, **kwargs):
    """
    This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n).
    Their matrix product can be used to generate a tensor of shape (*, m, n) of rank k.
    """
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    batches = [(), (0, ), (2, ), (1, 1)]
    size = [1, 5, 10]
    for batch, m, n in product(batches, size, size):
        for k in range(min(3, min(m, n))):
            a = make_arg((*batch, m, k))
            b = make_arg((*batch, n, k))
            yield SampleInput(a, b, **kwargs)
def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
    for sample in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad, **kwargs):
        *batch, m, k = sample.input.shape
        *_, n, _ = sample.args[0].shape
        # NOTE: since svd_lowrank relies on a non-rank-revealing SVD,
        # it inherits the problem of unstable behavior with repeated
        # singular values, including zeros.
        # Since we want to avoid (repeated) zeros as singular values,
        # we can only use k for q.
        # These issues could be resolved by using a rank-revealing SVD
        # which does not include "zero" singular values.
        op_kwargs = {
            'q': k,
            'M': None
        }
        # without M specified
        yield clone_sample(sample, **op_kwargs)
        # now with M
        # TODO: fix bug in the documentation for svd_lowrank:
        # M has to be (*, m, n), and not (*, 1, n) as written
        # in the documentation
        op_kwargs['M'] = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad)
        yield clone_sample(sample, **op_kwargs)
def chunk_iter(iterable, size):
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            break
        yield chunk
def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
    # we reuse samples from svd_lowrank, which come in groups of two with
    # kwarg['M'] = None and with kwarg['M'] = <some tensor>
    samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs)
    for s1, s2 in chunk_iter(samples, 2):
        del s1.kwargs['M']
        del s2.kwargs['M']
        s1.kwargs['center'] = False
        s2.kwargs['center'] = True
        yield s1
        yield s2
def np_sinc_with_fp16_as_fp32(x):
    # Wraps numpy's sinc function so that fp16 values are promoted to fp32
    # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated
    # at 0 for fp16.
    if x.dtype == np.float16:
        return np.sinc(x.astype(np.float32))
    else:
        return np.sinc(x)
def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = (
        ((S, 1, 1), (S, S, S)),
        ((S, 1, S), (S, S, S)),
        ((S, 1), (S, S, S)),
        ((1,), (S, S, S)),
        ((1, S), (1, 1, S)),
        ((), ()),
        ((), (1, 3, 2)),
    )
    return (
        SampleInput(
            make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad),
            shape,
        ) for size, shape in test_cases)
def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    test_cases: Tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),)
    for shape, *other_shapes in test_cases:
        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs)
    m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
    cases = (
        ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)),
        ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6))
    )
    for a, b, c, d in cases:
        yield SampleInput(m(a), args=(m(b), m(c), m(d)))
        yield SampleInput(n(a), args=(n(b), n(c), n(d)))
def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    test_cases: Tuple[tuple] = (
        ((1, S), (2, S), (3, S),),
        ((S, 1), (S, 2), (S, 3),),
        ((1,), (2,), (3,),),
        ((2, S), (S,))
    )
    for shape, *other_shapes in test_cases:
        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
        # We also want to test mixed complex-non-complex inputs to block_diag
        if dtype == torch.complex32 or dtype == torch.complex64:
            non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64
            make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad)
            yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes))
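# Note: torch.cdist(x1, x2, p, compute_mode) returns the pairwise p-norm distances between
# the rows of x1 and x2 (including batched and size-0 cases below); the cm strings are the
# compute_mode values that toggle the matmul-based Euclidean (p=2) path.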
def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):
    small_S = 2
    test_cases = (
        ((S, S, 2), (S, S + 1, 2)),
        ((S, S), (S, S)),
        ((S, S, S), (S, S, S)),
        ((3, 5), (3, 5)),
        ((2, 3, 5), (2, 3, 5)),
        ((1, 2, 3), (1, 2, 3)),
        ((1, 1), (S, 1)),
        ((0, 5), (4, 5)),
        ((4, 5), (0, 5)),
        ((0, 4, 5), (3, 5)),
        ((4, 5), (0, 3, 5)),
        ((0, 4, 5), (1, 3, 5)),
        ((1, 4, 5), (0, 3, 5)),
        # Using S here would make this one test take 9s
        ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)),
        ((small_S, 1, 1, small_S), (1, small_S, small_S)),
        ((1, 1, small_S), (small_S, 1, small_S, small_S)),
    )
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
        # FIXME add an override for JIT and revert 0. back to 0
        # since it's accepted by eager
        for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]:
            for t1_size, t2_size in test_cases:
                # The args should never be non-contiguous as this is not supported in the backward
                yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm)
def sample_inputs_fill_(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       low=None, high=None, requires_grad=requires_grad)
    cases = (((S, S, S), (1,)),
             ((), (1,)),
             ((S, S, S), (make_arg(()),)))
    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)
def _fill_np(a, value):
    a = a.copy()
    a.fill(value)
    return a
def _fill_sample_kwargs(device, dtype, input):
    if dtype is torch.bool:
        value = True
    else:
        value = 3
    return ({'value': value}, {'value': value})
def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
    # Adds a sample input where both tensors have the same values
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    lhs = make_arg((S, S))
    yield SampleInput(lhs, args=(lhs.clone(),))
def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # shape x number of tensors
    cases = (
        ((3, 4), 1),
        ((1, 2, 1, 4), 3),
        ((0, 1, 0), 2),)
    for shape, num_tensors in cases:
        tensors = []
        for _ in range(num_tensors):
            tensors.append(make_arg(shape))
        for dim in range(-1, len(shape) - 1):
            yield SampleInput(tensors, args=(dim,))
def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    cases: Tuple[tuple, tuple, dict] = (  # type: ignore[assignment]
        ((S, S), (S, S), {'dim': -1}),
        ((S, S), (S, S), {'dim': 1}),
        ((M, S), (S, S), {'dim': 0}),  # different shapes
        ((1, 2, 3), (1, 2, 3), {'dim': -2}),
        ((0,), (0,), {'dim': 0}),  # empty tensor
        ((0,), (S, S), {'dim': 1}),  # empty tensor with unempty and dim=1 (special case for legacy_cat_wrap_dim)
        ((0, S), (S, S), {'dim': 0}),
        ((1,), (1,), {})  # dim not passed, fallback to default
    )
    for input_shape1, input_shape2, kwargs in cases:
        yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs)
    # from coat_lite_mini
    yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),)
def error_inputs_cat(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    # error inputs for more than one element of the written-to tensor refer to a single memory location
    yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))],
                                 kwargs={'out': make_arg((1, S)).expand((2 * S, S))}),
                     error_regex='unsupported operation')
    # error inputs for empty tensors
    yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
                     error_regex='non-empty list of Tensors')
    # error inputs for different sizes
    yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
                     error_regex='Sizes of tensors must match except in dimension')
    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}),
                     error_regex='Sizes of tensors must match except in dimension')
    # error inputs for different dimensions
    yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
                     error_regex='Tensors must have same number of dimensions')
    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}),
                     error_regex='Tensors must have same number of dimensions')
    # error inputs for same memory locations
    x = torch.zeros((0), device=device)
    y = torch.randn((4, 6), device=device)
    err_msg = "the written-to tensor refer to a single memory location"
    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}),
                     error_regex=err_msg)
    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}),
                     error_regex=err_msg)
    z = torch.zeros((4, 6), device=device)
    yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}),
                     error_regex=err_msg)
    # error inputs for different devices
    if torch.device(device).type == 'cuda':
        x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32)
        y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32)
        yield ErrorInput(SampleInput((x_cuda, y_cpu)),
                         error_regex='Expected all tensors to be on the same device')
    # error inputs for different input sizes for more than 2 tensors
    yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]),
                     error_regex='Tensors must have same number of dimensions')
    yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))],
                                 kwargs={'dim': 1}),
                     error_regex='Sizes of tensors must match')
    # error inputs for None input
    yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError,
                     error_regex='got None')
    # error inputs for zero-dimensional tensors
    yield ErrorInput(SampleInput([make_arg(()), make_arg(())]),
                     error_regex='zero-dimensional.*cannot be concatenated')
    # error inputs for different dtype of out tensors
    d = make_tensor((2, 3), device=device, dtype=torch.double)
    x = make_tensor((2, 3), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError,
                     error_regex='invalid combination of arguments')
def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs)
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Noncontiguous type promoting tensors
    a = make_arg((3, 4, 2))
    b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double)
    c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2)
    yield SampleInput((a, b, c), kwargs={'dim': 1})
    # Special 1D tensor with dim length of 0 case
    a = make_arg((0,))
    b = make_arg((3, 2, 2))
    yield SampleInput((a, b, a))
    yield SampleInput((a, a, a))
  1727. def _elementwise_type_promo_np(*args, type_promotion_kind):
  1728. def _maybe_torch(x):
  1729. if isinstance(x, np.ndarray):
  1730. return torch.from_numpy(x)
  1731. return x
  1732. flattened = tree_flatten(args)[0]
  1733. transformed = tuple(_maybe_torch(a) for a in flattened)
  1734. result_dtype, _ = prims.utils.elementwise_dtypes(
  1735. *transformed,
  1736. type_promotion_kind=type_promotion_kind)
  1737. return torch_to_numpy_dtype_dict[result_dtype]
  1738. def _cat_np(input_seq, dim=0):
  1739. inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0))
  1740. if len(inputs) == 0:
  1741. np_dtype = _elementwise_type_promo_np(
  1742. input_seq,
  1743. type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH)
  1744. return np.empty(0, dtype=np_dtype)
  1745. return np.concatenate(inputs, axis=dim)
  1746. def _floor_divide_np(a, b):
  1747. dtype = _elementwise_type_promo_np(
  1748. a,
  1749. b,
  1750. type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
  1751. if isinstance(a, np.ndarray):
  1752. a = a.astype(dtype)
  1753. if isinstance(b, np.ndarray):
  1754. b = b.astype(dtype)
  1755. return np.floor_divide(a, b)
  1756. def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs):
  1757. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  1758. tensor_shapes = (
  1759. # First Tensor being 1-D is special
  1760. # case for hstack
  1761. ((S,), (S,), (S,)),
  1762. ((S, S), (S, S), (S, S)),
  1763. )
  1764. for s1, s2, s3 in tensor_shapes:
  1765. tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
  1766. yield SampleInput(tensors)
  1767. def error_inputs_hstack_dstack_vstack(op, device):
  1768. make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
  1769. tensor_shapes = (
  1770. ((S,), (S, S, S, S), (S,)),
  1771. )
  1772. for s1, s2, s3 in tensor_shapes:
  1773. tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
  1774. # Different dimension tensor
  1775. yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions")
  1776. # empty tensor list
  1777. yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList")
  1778. def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
  1779. # Note: we don't do any tests where we unbind along 0-length dims
  1780. # because in that case unbind returns and empty tuple, and that breaks
  1781. # some asumptions in some backward tests in test_ops.py
  1782. shape_dims = (((S,), 0),
  1783. ((S, S), 0),
  1784. ((S, S), 1),
  1785. ((S, S), -1),
  1786. ((S, 0, S), 0),
  1787. ((S, S, S), 1),
  1788. )
  1789. for shape, dim in shape_dims:
  1790. yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
  1791. requires_grad=requires_grad),
  1792. args=(dim,))
  1793. def error_inputs_unbind(op_info, device):
  1794. make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
  1795. yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
  1796. error_regex="Dimension specified as 0 but tensor has no dimensions")
  1797. yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
  1798. error_regex="Dimension out of range")
  1799. def reference_unbind(t, dim):
  1800. """A numpy implementation of torch.unbind"""
  1801. return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim))
  1802. def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
  1803. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
  1804. yield SampleInput(
  1805. make_arg((M, S)),
  1806. 0,
  1807. gather_variable((S, S), 1, M, True, device=device))
  1808. yield SampleInput(
  1809. make_arg((M, S)),
  1810. 1,
  1811. gather_variable((M, S // 2), 0, S, True, device=device))
  1812. yield SampleInput(
  1813. make_arg(),
  1814. 0,
  1815. torch.tensor([0], dtype=torch.int64, device=device))
  1816. # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
  1817. yield SampleInput(
  1818. make_arg((S,)),
  1819. 0,
  1820. torch.tensor([], dtype=torch.uint8, device=device))
  1821. yield SampleInput(
  1822. make_arg(()),
  1823. 0,
  1824. torch.tensor(0, dtype=torch.int64, device=device))
  1825. def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o):
  1826. for i in range(1 if dim == 0 else m):
  1827. for j in range(1 if dim == 1 else n):
  1828. for k in range(1 if dim == 2 else o):
  1829. ii = [i, j, k]
  1830. ii[dim] = slice(0, idx.size(dim) + 1)
  1831. idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
  1832. def error_inputs_gather(op_info, device, **kwargs):
  1833. # src is [1, 2]
  1834. # [3, 4]
  1835. src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
  1836. # idx is [0, 0]
  1837. # [1, 0]
  1838. idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
  1839. # Index should be smaller than self except on dimesion 1
  1840. bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
  1841. yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
  1842. error_regex="Size does not match at dimension 0")
  1843. # Index must have long dtype
  1844. bad_idx = idx.to(torch.int32)
  1845. yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
  1846. error_regex="Expected dtype int64 for index")
  1847. # TODO: FIXME
  1848. # out.dtype must match src.dtype
  1849. # Creates new src & idx since SampleInputs can't share tensors
  1850. src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
  1851. idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
  1852. out = torch.empty((2, 2), device=device, dtype=torch.float64)
  1853. yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
  1854. error_regex="Expected out tensor to have dtype")
  1855. # src and index tensors must have the same # of dimensions
  1856. # idx too few dimensions
  1857. src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
  1858. idx = torch.tensor((0, 0), device=device, dtype=torch.long)
  1859. yield ErrorInput(SampleInput(src, args=(1, idx)),
  1860. error_regex="Index tensor must have the same number of dimensions")
  1861. # src too few dimensions
  1862. src = torch.tensor((1, 2), device=device, dtype=torch.float32)
  1863. idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
  1864. yield ErrorInput(SampleInput(src, args=(0, idx)),
  1865. error_regex="Index tensor must have the same number of dimensions")
  1866. # index out of bounds
  1867. # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
  1868. if torch.device(device).type == 'cpu':
  1869. src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
  1870. idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
  1871. yield ErrorInput(SampleInput(src, args=(1, idx,)),
  1872. error_regex="index 23 is out of bounds for dimension")
  1873. x = torch.rand((1,), device=device).expand((3,))
  1874. src = torch.rand((6,), device=device)
  1875. ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
  1876. yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)),
  1877. error_type=RuntimeError,
  1878. error_regex='unsupported operation')
  1879. yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)),
  1880. error_type=RuntimeError,
  1881. error_regex='unsupported operation')
  1882. yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])),
  1883. error_type=RuntimeError,
  1884. error_regex='unsupported operation')
  1885. def error_inputs_take(op_info, device, **kwargs):
  1886. x = torch.rand((1,), device=device).expand((3,))
  1887. src = torch.rand((6,), device=device)
  1888. ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
  1889. yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)),
  1890. error_type=RuntimeError,
  1891. error_regex='unsupported operation')
  1892. yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)),
  1893. error_type=RuntimeError,
  1894. error_regex='unsupported operation')
  1895. yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])),
  1896. error_type=RuntimeError,
  1897. error_regex='unsupported operation')
  1898. # Error inputs for scatter
  1899. def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
  1900. # Error when self.dtype != src.dtype (and src is not a scalar)
  1901. src = make_tensor((2, 5), device=device, dtype=torch.float32)
  1902. idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
  1903. dst = torch.zeros((3, 5), device=device, dtype=torch.double)
  1904. yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
  1905. error_regex="Expected self.dtype to be equal to src.dtype")
  1906. # Index dtype must be long
  1907. src = make_tensor((2, 5), device=device, dtype=torch.float32)
  1908. idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
  1909. dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
  1910. yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
  1911. error_regex="Expected dtype int64 for index")
  1912. # Index and destination must have the same number of dimensions
  1913. src = make_tensor((2, 5), device=device, dtype=torch.float32)
  1914. idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
  1915. dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
  1916. yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
  1917. error_regex="Index tensor must have the same number of dimensions as self tensor")
  1918. # Index and src must have the same number of dimensions when src is not a scalar
  1919. src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
  1920. idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
  1921. dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
  1922. yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
  1923. error_regex="Index tensor must have the same number of dimensions as src tensor")
  1924. # Index out of bounds
  1925. # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
  1926. if torch.device(device).type == 'cpu':
  1927. src = make_tensor((2, 5), device=device, dtype=torch.float32)
  1928. idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
  1929. dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
  1930. yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
  1931. error_regex="index 34 is out of bounds for dimension 0 with size 3")
  1932. def error_inputs_renorm(op_info, device, **kwargs):
  1933. zero_d = torch.randn((), device=device)
  1934. yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError,
  1935. error_regex="needs at least 2 dimensions, got 0 dimensions")
  1936. def error_inputs_ormqr(op_info, device, **kwargs):
  1937. zero_d = torch.randn((), device=device)
  1938. yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError,
  1939. error_regex="input must have at least 2 dimensions")
  1940. # https://github.com/pytorch/pytorch/issues/85218
  1941. tensor_0 = torch.full((5, 0,), 1, device=device)
  1942. tensor_1 = torch.full((5,), 1, device=device)
  1943. tensor_2 = torch.full((5, 5,), 1, device=device)
  1944. bool_3 = True
  1945. bool_4 = True
  1946. yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError,
  1947. error_regex=r"tau.shape\[-1\] must be less than or equal to input.shape\[-1\]")
  1948. def error_inputs_diag(op_info, device, **kwargs):
  1949. zero_d = torch.randn((), device=device)
  1950. yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
  1951. error_regex="1D or 2D")
  1952. zero_d = torch.randn(1, 1, 1, device=device)
  1953. yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
  1954. error_regex="1D or 2D")
  1955. def error_inputs_embedding(op_info, device, **kwargs):
  1956. indices = torch.rand(2, 2, device=device).long()
  1957. weights = [
  1958. torch.tensor(1.0, device=device),
  1959. torch.tensor(1.0, device=device).reshape(1, 1, 1),
  1960. ]
  1961. for weight in weights:
  1962. yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError,
  1963. error_regex="'weight' must be 2-D")
  1964. def error_inputs_t(op_info, device, **kwargs):
  1965. yield ErrorInput(
  1966. SampleInput(torch.randn(2, 3, 4, 5, device=device)),
  1967. error_regex="expects a tensor with <= 2",
  1968. )
  1969. def error_inputs_multinomial(op_info, device, **kwargs):
  1970. x = torch.empty(1, 2, 3, dtype=torch.double, device=device)
  1971. yield ErrorInput(SampleInput(x, args=(2,)),
  1972. error_regex="prob_dist must be 1 or 2 dim")
  1973. x = torch.empty(1, 2, dtype=torch.long, device=device)
  1974. yield ErrorInput(SampleInput(x, args=(2,)),
  1975. error_regex="multinomial only supports floating-point dtypes for input")
  1976. x = torch.empty(1, 2, dtype=torch.double, device=device)
  1977. y = torch.empty(1, 2, dtype=torch.double, device=device)
  1978. yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)),
  1979. error_regex="multinomial expects Long tensor out")
  1980. x = torch.empty(2, dtype=torch.double, device=device)
  1981. yield ErrorInput(SampleInput(x, args=(0,)),
  1982. error_regex="cannot sample n_sample <= 0 samples")
  1983. x = torch.empty(2, dtype=torch.double, device=device)
  1984. yield ErrorInput(SampleInput(x, args=(-1,)),
  1985. error_regex="cannot sample n_sample <= 0 samples")
  1986. x = torch.empty(2, dtype=torch.double, device=device)
  1987. yield ErrorInput(SampleInput(x, args=(3, False,)),
  1988. error_regex="cannot sample n_sample > prob_dist")
  1989. x = torch.empty(16777217, dtype=torch.double, device=device)
  1990. yield ErrorInput(SampleInput(x, args=(3,)),
  1991. error_regex="number of categories cannot exceed")
  1992. inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan))
  1993. err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0"
  1994. err_msg2 = "invalid multinomial distribution"
  1995. rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
  1996. for rep in rep_arg:
  1997. kwargs = {'num_samples': 2, 'replacement': rep}
  1998. for shape in inputs:
  1999. # error case when input tensor contains `inf`, `nan` or negative element
  2000. yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
  2001. error_regex=err_msg1 if rep is False else err_msg2)
  2002. # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
  2003. x = torch.zeros(3, device=device)
  2004. yield ErrorInput(SampleInput(x, kwargs=kwargs),
  2005. error_regex=err_msg2)
  2006. # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
  2007. x = torch.zeros(3, 3, device=device)
  2008. yield ErrorInput(SampleInput(x, kwargs=kwargs),
  2009. error_regex=err_msg2)
  2010. # error case for the invalid multinomial distribution
  2011. x[1, :] = 1
  2012. yield ErrorInput(SampleInput(x, kwargs=kwargs),
  2013. error_regex=err_msg2)
  2014. def error_inputs_gradient(op_info, device, **kwargs):
  2015. for dtype in [torch.long, torch.float32, torch.complex64]:
  2016. t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype)
  2017. dim = (1, 0)
  2018. spacing = [0.1]
  2019. yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
  2020. error_type=RuntimeError,
  2021. error_regex='torch.gradient expected spacing to be unspecified, a scalar ')
  2022. yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)),
  2023. error_type=RuntimeError,
  2024. error_regex='torch.gradient only supports edge_order=1 and edge_order=2.')
  2025. dim = (1, 1)
  2026. spacing = 0.1
  2027. yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
  2028. error_type=RuntimeError,
  2029. error_regex='dim 1 appears multiple times in the list of dims')
  2030. dim = (0, 1)
  2031. coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')]
  2032. yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)),
  2033. error_type=RuntimeError,
  2034. error_regex='torch.gradient expected each tensor to be on the same device,')
  2035. yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)),
  2036. error_type=IndexError, error_regex='')
  2037. t = torch.tensor([[1], [2], [3]])
  2038. yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)),
  2039. error_type=RuntimeError,
  2040. error_regex='torch.gradient expected each dimension size to be at least')
  2041. t = torch.tensor([[1, 2], [3, 4]])
  2042. yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)),
  2043. error_type=RuntimeError,
  2044. error_regex='torch.gradient expected each dimension size to be at least')
  2045. def error_inputs_rrelu(op_info, device, **kwargs):
  2046. input = make_tensor((S, S), device=device, dtype=torch.float32)
  2047. yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}),
  2048. error_regex='Lower bound should be less than or equal to the upper bound')
  2049. def error_inputs_masked_select(op_info, device, **kwargs):
  2050. x = torch.rand((1,), device=device).expand((3,))
  2051. y = torch.rand((6,), device=device)
  2052. mask = torch.tensor([True, False, True, True, False, False], device=device)
  2053. yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)),
  2054. error_type=RuntimeError,
  2055. error_regex='unsupported operation')
  2056. yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)),
  2057. error_type=RuntimeError,
  2058. error_regex='unsupported operation')
  2059. yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)),
  2060. error_type=RuntimeError,
  2061. error_regex='unsupported operation')
  2062. def error_inputs_median(op_info, device, **kwargs):
  2063. x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan],
  2064. [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device)
  2065. if device == 'cuda':
  2066. yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))),
  2067. error_type=RuntimeError,
  2068. error_regex='CUDA Tensors cannot have more than 25 dimensions')
  2069. else:
  2070. return
  2071. def error_inputs_index_select(op_info, device, **kwargs):
  2072. x = torch.rand((1, 6), device=device).expand((2, 6))
  2073. y = torch.rand((3, 6), device=device)
  2074. ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
  2075. yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)),
  2076. error_type=RuntimeError,
  2077. error_regex='unsupported operation')
  2078. def error_inputs_logcumsumexp(op_info, device, **kwargs):
  2079. dim = 3
  2080. srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)]
  2081. for src in srcs:
  2082. yield ErrorInput(SampleInput(src, args=(dim,)),
  2083. error_type=IndexError,
  2084. error_regex='Dimension out of range')
  2085. def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
  2086. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
  2087. yield SampleInput(
  2088. make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0)
  2089. # `indices` broadcast
  2090. yield SampleInput(
  2091. make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1)
  2092. # `self` broadcast
  2093. yield SampleInput(
  2094. make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1)
  2095. # without `dim` arg
  2096. yield SampleInput(
  2097. make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))
  2098. yield SampleInput(
  2099. make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))
  2100. def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
  2101. # Error Inputs for zero-dim tensors, when 'dim' arg is not provided.
  2102. shape = (S, 0, S)
  2103. err_msg_amax_amin = "reduction"
  2104. err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
  2105. if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
  2106. yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
  2107. elif op_info.name in ['aminmax']:
  2108. yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
  2109. # Error Inputs for tensors with more than 64 dimension
  2110. sizes = [1] * 65
  2111. err_msg1 = "only tensors with up to 64 dims are supported"
  2112. yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}),
  2113. error_regex=err_msg1)
  2114. yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}),
  2115. error_regex=err_msg1)
  2116. # Error Inputs for repeated 'dim'
  2117. if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
  2118. dims = [(0, 0), (0, -4)]
  2119. err_msg2 = "in the list of dims"
  2120. x = torch.randn(S, S, S, S, device=device)
  2121. for dim in dims:
  2122. yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)
  2123. # Error Input for illegal dtype
  2124. input5 = torch.randn(L, L, dtype=torch.float32, device=device)
  2125. max_values = torch.empty(L, dtype=torch.float32, device=device)
  2126. min_values = torch.empty(L, dtype=torch.double, device=device)
  2127. illegal_values = torch.empty(L, dtype=torch.int, device=device)
  2128. # Unlike regular PyTorch, amax and amin refs don't require input and out
  2129. # dtypes to match exactly:
  2130. # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824
  2131. if is_ref:
  2132. err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype "
  2133. "torch.int32, but this can't be cast because it is not safe!")
  2134. else:
  2135. err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float "
  2136. "for input's dtype and Int for out's dtype.")
  2137. err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"
  2138. if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
  2139. yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
  2140. error_regex=err_msg_amax_amin2)
  2141. elif op_info.name in ['aminmax']:
  2142. yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
  2143. error_regex=err_msg_aminmax2)
  2144. # Error Inputs for functions to raise an error on specified zero'd dimension as reduction dim
  2145. err_msg3 = "reduction"
  2146. # FIXME: eager and ref impl throw different types of errors
  2147. error_type = IndexError if 'refs' not in op_info.name else RuntimeError
  2148. yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
  2149. error_type=error_type, error_regex=err_msg3)
  2150. def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
  2151. test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment]
  2152. ((S, S, S), {}),
  2153. ((S, S, S), {'dim': 1}),
  2154. ((S, S, S), {'dim': 1, 'keepdim': True}),
  2155. ((), {'dim': 0}),
  2156. ((), {}),
  2157. ((), {'dim': 0, 'keepdim': True}),
  2158. )
  2159. for shape, kwargs in test_cases:
  2160. yield SampleInput(
  2161. make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad),
  2162. **kwargs)
  2163. def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs):
  2164. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2165. test_cases = (
  2166. ((1,), 0, None, None),
  2167. ((S,), 0, None, None),
  2168. ((S, 1), 0, None, None),
  2169. ((S, 1), 1, None, None),
  2170. ((S, S), 0, None, None),
  2171. ((S, S), 1, None, None),
  2172. ((S, S), 0, (1, S), (2, S)),
  2173. ((S, S), 0, None, (2, S)),
  2174. ((XS, XS, XS), 1, None, None),
  2175. ((XS, XS, XS), 2, None, None),
  2176. ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)),
  2177. ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)),
  2178. ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),)
  2179. sample_inputs = []
  2180. for size, dim, size_prepend, size_append in test_cases:
  2181. prepend_size = 0 if (size_prepend is None) else size_prepend[dim]
  2182. append_size = 0 if (size_append is None) else size_append[dim]
  2183. dim_size = size[dim] + prepend_size + append_size
  2184. for n in range(dim_size):
  2185. input_tensor = make_arg(size)
  2186. prepend = make_arg(size_prepend) if size_prepend else None
  2187. append = make_arg(size_append) if size_append else None
  2188. yield SampleInput(input_tensor, n, dim, prepend, append)
  2189. # add some samples with n > dim_size
  2190. yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1)
  2191. yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS)))
  2192. def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs):
  2193. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2194. sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
  2195. for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]):
  2196. input_tensor = make_arg(size)
  2197. weight_tensor = make_arg(size) if weighted else None
  2198. yield SampleInput(input_tensor, bin_ct,
  2199. weight=weight_tensor, density=density)
  2200. bins_tensor = make_arg((bin_ct + 1,))
  2201. yield SampleInput(input_tensor, bins_tensor,
  2202. weight=weight_tensor, density=density)
  2203. def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs):
  2204. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2205. sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S))
  2206. bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3))
  2207. for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]):
  2208. input_tensor = make_arg(size)
  2209. bin_ct = bin_ct_pattern[:size[-1]]
  2210. weight_tensor = make_arg(size[:-1]) if weighted else None
  2211. yield SampleInput(input_tensor, bin_ct,
  2212. weight=weight_tensor, density=density)
  2213. bins_tensor = [make_arg(ct + 1) for ct in bin_ct]
  2214. yield SampleInput(input_tensor, bins_tensor,
  2215. weight=weight_tensor, density=density)
  2216. def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs):
  2217. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2218. sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
  2219. for size, min, max in product(sizes, [0, -10], [0, 10]):
  2220. # construct sample input omitting bins arg
  2221. yield SampleInput(make_arg(size), min=min, max=max)
  2222. # construct sample inputs with a few different bins values
  2223. for bins in [1, 3, 10]:
  2224. yield SampleInput(make_arg(size), bins=bins, min=min, max=max)
  2225. def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs):
  2226. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2227. for size, weighted in product((S, M), [False, True]):
  2228. input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device)
  2229. weight_tensor = make_arg((size,)) if weighted else None
  2230. max_val = int(input_tensor.max().item())
  2231. for minlength in [0, max_val // 2, max_val, 2 * max_val]:
  2232. yield SampleInput(
  2233. input_tensor, weights=weight_tensor, minlength=minlength)
  2234. def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs):
  2235. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2236. sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S))
  2237. if reference_inputs_mode:
  2238. sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33))
  2239. for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]):
  2240. input_tensor = make_arg(input_shape)
  2241. boundaries = make_arg(nb).msort()
  2242. yield SampleInput(input_tensor, boundaries,
  2243. out_int32=out_int32, right=right)
  2244. reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True)
  2245. def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs):
  2246. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2247. sizes = ((0,), (M,), (0, 0), (M, M), (0, 0, 0), (M, M, M))
  2248. for size, noncontiguous, out_int32, right in product(sizes, [False, True], [False, True], [False, True]):
  2249. unsorted_tensor = make_arg(size, noncontiguous=noncontiguous)
  2250. input_tensor = make_arg(size, noncontiguous=noncontiguous)
  2251. if np.product(size) == 0:
  2252. boundary_tensor = unsorted_tensor
  2253. sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous)
  2254. else:
  2255. boundary_tensor, sorter = torch.sort(unsorted_tensor)
  2256. side = "right" if right else "left"
  2257. yield SampleInput(boundary_tensor, input_tensor, out_int32=out_int32, right=right)
  2258. yield SampleInput(boundary_tensor, input_tensor, out_int32=out_int32, side=side)
  2259. yield SampleInput(unsorted_tensor, input_tensor, out_int32=out_int32, right=right, sorter=sorter)
  2260. yield SampleInput(unsorted_tensor, input_tensor, out_int32=out_int32, side=side, sorter=sorter)
  2261. def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs):
  2262. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
  2263. test_cases_float = (
  2264. ((S,), None, None, 1),
  2265. ((S,), 2., None, 1),
  2266. ((S, S), None, None, 2),
  2267. ((S, S), [2.0, 2.1], None, 1),
  2268. ((S, S), [2.0, 2.1], (0, 1), 1),
  2269. ((4, 4, 4), [2., 1.], (0, 1), 2),
  2270. )
  2271. for size, spacing, dim, edge_order in test_cases_float:
  2272. t = make_arg(size)
  2273. yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order)
  2274. test_cases_tensor = (
  2275. ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1),
  2276. ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2),
  2277. )
  2278. for size, coordinates, dim, edge_order in test_cases_tensor:
  2279. t = make_arg(size)
  2280. coordinates_tensor_list = []
  2281. for coords in coordinates:
  2282. # `coords` will always contain floating point values and Python 3.10 does not support this
  2283. # implicit conversion to an integer using `__int__`
  2284. # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed
  2285. a = torch.tensor(coords, device=device)
  2286. coordinates_tensor_list.append(a.to(dtype))
  2287. yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order)
  2288. def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
  2289. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2290. test_args = [
  2291. ([1, 2],),
  2292. (slice(0, 3),),
  2293. ([slice(0, 3), 1],),
  2294. ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
  2295. ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
  2296. ([slice(None), slice(None), [0, 3]],),
  2297. ([slice(None), [0, 3], slice(None)],),
  2298. ([[0, 3], slice(None), slice(None)],),
  2299. ([[0, 3], [1, 2], slice(None)],),
  2300. ([[0, 3], ],),
  2301. ([[0, 3], slice(None)],),
  2302. ([[0, 3], Ellipsis],),
  2303. ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
  2304. (index_variable(2, S, device=device),),
  2305. (mask_not_all_zeros((S,)),),
  2306. ]
  2307. for args in test_args:
  2308. yield SampleInput(make_arg((S, S, S)), args=args)
  2309. yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],))
  2310. def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
  2311. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2312. for accumulate in [False, True]:
  2313. # Test with indices arg
  2314. yield SampleInput(
  2315. make_arg((S, S,)),
  2316. (index_variable(2, S, device=device),),
  2317. make_arg((2, S)),
  2318. accumulate=accumulate)
  2319. # Test with mask arg
  2320. mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
  2321. yield SampleInput(
  2322. make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)
  2323. def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
  2324. def small_3d_unique():
  2325. res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S)
  2326. res = res.to(dtype).requires_grad_(requires_grad)
  2327. return res
  2328. def large_1d_unique():
  2329. res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
  2330. res = res.to(dtype).requires_grad_(requires_grad)
  2331. return res
  2332. # Test case for large tensor.
  2333. yield SampleInput(large_1d_unique())
  2334. # Test cases for small 3d tensors.
  2335. # Imitates legacy tests from test/test_torch.py
  2336. dims = range(-3, 3)
  2337. flag = [True, False]
  2338. for dim, descending, stable in product(dims, flag, flag):
  2339. # default schema without stable sort
  2340. yield SampleInput(small_3d_unique(), dim, descending)
  2341. # schema with stable sort, no CUDA support yet
  2342. if torch.device(device).type == 'cpu':
  2343. yield SampleInput(
  2344. small_3d_unique(), dim=dim, descending=descending, stable=stable)
  2345. # Test cases for scalar tensor
  2346. tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad)
  2347. yield SampleInput(torch.tensor(1, **tensor_opt))
  2348. yield SampleInput(torch.tensor(1, **tensor_opt), 0)
  2349. yield SampleInput(torch.tensor(1, **tensor_opt), 0, True)
  2350. # Test cases for stable sort
  2351. yield SampleInput(small_3d_unique(), stable=True)
  2352. yield SampleInput(small_3d_unique(), dim=0, stable=True)
  2353. yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True)
  2354. def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs):
  2355. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  2356. sizes = ((), (S,), (S, S), (S, S, S))
  2357. for x_size in sizes:
  2358. # threshold and values args must be numbers
  2359. yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item())
  2360. def sample_inputs_argsort(*args, **kwargs):
  2361. return (sample_input for sample_input in sample_inputs_sort(*args, **kwargs)
  2362. if "stable" not in sample_input.kwargs)
  2363. def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs):
  2364. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2365. sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
  2366. for shape, sorted, return_inverse, return_counts, dim in \
  2367. product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]):
  2368. # torch.unique cannot be called if the input tensor has a zero dimension which isn't the selected dim
  2369. if 0 in shape and shape.index(0) is not dim:
  2370. continue
  2371. # skip invalid dim args
  2372. if dim is not None and (dim < -len(shape) or dim >= len(shape)):
  2373. continue
  2374. kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
  2375. # construct a test case with only one distinct value
  2376. input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
  2377. yield SampleInput(input_t, **kwargs)
  2378. # construct a test case with mixed 0s and 1s
  2379. input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\
  2380. .to(dtype).requires_grad_(requires_grad)
  2381. yield SampleInput(input_t, **kwargs)
  2382. # construct a test case with many different values
  2383. yield SampleInput(make_arg(shape), **kwargs)
  2384. def sample_inputs_unique_consecutive(*args, **kwargs):
  2385. for sample_input in sample_inputs_unique(*args, **kwargs):
  2386. if not sample_input.kwargs["sorted"]:
  2387. sample_input.kwargs.pop("sorted")
  2388. yield sample_input
  2389. def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs):
  2390. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2391. # Ordered as (input shape, output size)
  2392. cases = (
  2393. ((0, 8, 8), (5,)),
  2394. ((3, 8, 8), 5),
  2395. ((3, 8, 8), 1)
  2396. )
  2397. for input_shape, output_size in cases:
  2398. # Batched
  2399. yield SampleInput(make_arg(input_shape), args=(output_size,))
  2400. # Unbatched
  2401. yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
  2402. def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs):
  2403. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2404. # error inputs for empty output
  2405. yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
  2406. error_regex="'output_size' should contain one int")
  2407. # error inputs for output_size lesser than 0
  2408. yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
  2409. error_regex="elements of output_size must be greater than or equal to 0")
  2410. def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs):
  2411. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2412. # Ordered as (input shape, output size)
  2413. cases = (
  2414. ((1, 8, 8, 8), (5, 7)),
  2415. ((2, 8, 8, 8), (None, 7)),
  2416. ((1, 8, 4, 3), (5, None)),
  2417. ((1, 8, 4, 3), (None, None)),
  2418. ((1, 8, 4, 3), (5)),
  2419. )
  2420. for input_shape, output_size in cases:
  2421. # Batched
  2422. yield SampleInput(make_arg(input_shape), args=(output_size,))
  2423. # Unbatched
  2424. yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
  2425. def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs):
  2426. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2427. # error inputs for incorrect input dimension
  2428. yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
  2429. error_type=ValueError, error_regex="Input dimension should be at least 3")
  2430. # error inputs for empty output
  2431. yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
  2432. error_regex="output_size must be 2")
  2433. # error inputs for output_size lesser than 0
  2434. yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
  2435. error_regex="elements of output_size must be greater than or equal to 0")
  2436. def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs):
  2437. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2438. # Ordered as (input shape, output size)
  2439. cases = (
  2440. ((0, 8, 8, 8, 8), (5, 7, 4)),
  2441. ((1, 8, 4, 3, 7), (None, None, None)),
  2442. ((1, 8, 4, 3, 7), (1, 1, 1)),
  2443. ((3, 3, 8, 8, 6), (5, 7, None)),
  2444. ((1, 3, 8, 8, 6), (5, None, 2)),
  2445. ((3, 3, 8, 8, 6), (None, 3, 2)),
  2446. )
  2447. for input_shape, output_size in cases:
  2448. # Batched
  2449. yield SampleInput(make_arg(input_shape), args=(output_size,))
  2450. # Unbatched
  2451. yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
  2452. def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs):
  2453. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2454. # error inputs for incorrect input dimension
  2455. yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
  2456. error_type=ValueError, error_regex="Input dimension should be at least 4")
  2457. # error inputs for empty output
  2458. yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
  2459. error_regex="output_size must be 3")
  2460. # error inputs for output_size lesser than 0
  2461. yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
  2462. error_regex="elements of output_size must be greater than or equal to 0")
  2463. def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs):
  2464. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2465. # Ordered as (input shape, output size)
  2466. cases = (
  2467. # ((0, 8, 8), (5,)),
  2468. # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
  2469. ((3, 4, 4), 3),
  2470. ((3, 4, 4), 1)
  2471. )
  2472. for shapes, return_idx in product(cases, (True, False)):
  2473. # Batched
  2474. yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
  2475. # Unbatched
  2476. yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
  2477. def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs):
  2478. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2479. # error inputs for empty output
  2480. yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
  2481. error_regex="'output_size' should contain one int")
  2482. # error inputs for output_size lesser than 0
  2483. yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
  2484. error_regex="Trying to create tensor with negative dimension")
  2485. def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
  2486. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2487. # Ordered as (input shape, output size)
  2488. cases = (
  2489. # ((0, 8, 8, 8), (5, 7)),
  2490. # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
  2491. ((1, 4, 4, 4), (2, 3)),
  2492. ((2, 4, 4, 4), (None, 3)),
  2493. ((2, 4, 4, 4), (1, 1)),
  2494. ((1, 4, 4, 3), (3, None)),
  2495. ((1, 4, 4, 3), (None, None)),
  2496. ((1, 4, 4, 3), (3)),
  2497. )
  2498. for shapes, return_idx in product(cases, (True, False)):
  2499. # Batched
  2500. yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
  2501. # Unbatched
  2502. yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
  2503. def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs):
  2504. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2505. # error inputs for incorrect input dimension
  2506. yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
  2507. error_type=ValueError, error_regex="Input dimension should be at least 3")
  2508. # error inputs for empty output
  2509. yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
  2510. error_regex="internal error")
  2511. # error inputs for output_size lesser than 0
  2512. yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
  2513. error_regex="Trying to create tensor with negative dimension")
  2514. def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
  2515. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2516. # Ordered as (input shape, output size)
  2517. cases = (
  2518. # ((0, 8, 8, 8, 8), (5, 7, 4)),
  2519. # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
  2520. ((1, 4, 4, 3, 5), (None, None, None)),
  2521. ((1, 4, 4, 3, 5), (1, 1, 1)),
  2522. ((3, 3, 4, 4, 6), (2, 3, None)),
  2523. ((1, 3, 4, 4, 6), (3, None, 2)),
  2524. ((3, 3, 4, 4, 6), (None, 3, 2)),
  2525. )
  2526. for shapes, return_idx in product(cases, (True, False)):
  2527. # Batched
  2528. yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
  2529. # Unbatched
  2530. yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
  2531. def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs):
  2532. make_arg = partial(make_tensor, device=device, dtype=torch.float32)
  2533. # error inputs for incorrect input dimension
  2534. yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
  2535. error_type=ValueError, error_regex="Input dimension should be at least 4")
  2536. # error inputs for empty output
  2537. yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
  2538. error_regex="internal error")
  2539. # error inputs for output_size lesser than 0
  2540. yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
  2541. error_regex="Trying to create tensor with negative dimension")
  2542. def sample_inputs_reduction_sparse(op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs):
  2543. layout_name = str(layout).split('.', 1)[-1].rsplit('_coo', 1)[0]
  2544. op_supports_layout = getattr(op_info, 'supports_' + layout_name)
  2545. if not op_supports_layout:
  2546. return
  2547. for sample_input in sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
  2548. if sample_input.input.ndim == 0:
  2549. # scalar sparse tensors are not supported
  2550. continue
  2551. yield SampleInput(
  2552. sample_input.input.detach().to_sparse(layout=layout,
  2553. blocksize=blocksize).requires_grad_(requires_grad),
  2554. args=sample_input.args,
  2555. kwargs=sample_input.kwargs)
  2556. if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
  2557. # uncoalesced samples
  2558. inp = sample_input.input.detach().to_sparse(layout=layout)
  2559. inp = torch.sparse_coo_tensor(inp.indices().repeat(1, 2),
  2560. inp.values().repeat(2),
  2561. inp.shape,
  2562. dtype=inp.dtype,
  2563. device=inp.device)
  2564. assert not inp.is_coalesced()
  2565. yield SampleInput(inp.requires_grad_(requires_grad),
  2566. args=sample_input.args,
  2567. kwargs=sample_input.kwargs)
  2568. if sample_input.input.ndim > 2:
  2569. # hybrid samples
  2570. yield SampleInput(
  2571. sample_input.input.detach().to_sparse(layout=layout,
  2572. blocksize=blocksize,
  2573. dense_dim=sample_input.input.ndim - 2).requires_grad_(requires_grad),
  2574. args=sample_input.args,
  2575. kwargs=sample_input.kwargs)
  2576. class _TestParamsMaxPoolBase:
  2577. def __init__(self):
  2578. self.kwargs = {
  2579. 'kernel_size': [3],
  2580. 'stride': [2, None],
  2581. 'ceil_mode': [True, False],
  2582. 'padding': [0, 1],
  2583. 'dilation': [1],
  2584. 'return_indices': [True, False]
  2585. }
  2586. self.shapes = [
  2587. [1, 2, None], # batch
  2588. [2], # channels
  2589. [3, 6] # signal
  2590. ]
  2591. def _gen_shape(self):
  2592. for shape in product(*self.shapes):
  2593. # shape[0] is None indicates missing batch dimension
  2594. if shape[0] is None:
  2595. shape = shape[1:]
  2596. yield shape, torch.contiguous_format
  2597. # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format
  2598. if len(self.shapes) == 4 and len(shape) == 4:
  2599. yield shape, torch.channels_last
  2600. def _gen_kwargs(self):
  2601. keys = self.kwargs.keys()
  2602. for values in product(*self.kwargs.values()):
  2603. yield dict(zip(keys, values))
  2604. def gen_input_params(self):
  2605. yield from product(self._gen_shape(), self._gen_kwargs())
  2606. class _TestParamsMaxPool1d(_TestParamsMaxPoolBase):
  2607. def __init__(self):
  2608. super().__init__()
  2609. self.kwargs['kernel_size'] += [(3,)]
  2610. self.kwargs['stride'] += [(2,)]
  2611. self.kwargs['padding'] += [(1,)]
  2612. self.kwargs['dilation'] += [(1,)]
  2613. class _TestParamsMaxPool2d(_TestParamsMaxPoolBase):
  2614. def __init__(self):
  2615. super().__init__()
  2616. self.kwargs['kernel_size'] += [(3, 2)]
  2617. self.kwargs['stride'] += [(2, 1)]
  2618. self.kwargs['padding'] += [(1, 1)]
  2619. self.kwargs['dilation'] += [(1, 2)]
  2620. self.shapes.append([6])
  2621. class _TestParamsMaxPool3d(_TestParamsMaxPoolBase):
  2622. def __init__(self):
  2623. super().__init__()
  2624. self.kwargs['kernel_size'] += [(3, 2, 3)]
  2625. self.kwargs['stride'] += [(2, 1, 2)]
  2626. self.kwargs['dilation'] += [(1, 2, 1)]
  2627. self.shapes.append([6])
  2628. self.shapes.append([5])
  2629. def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs):
  2630. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
  2631. params_generator_type_dict = {
  2632. 'nn.functional.max_pool1d': _TestParamsMaxPool1d,
  2633. 'nn.functional.max_pool2d': _TestParamsMaxPool2d,
  2634. 'nn.functional.max_pool3d': _TestParamsMaxPool3d,
  2635. 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d,
  2636. }
  2637. params_generator = params_generator_type_dict[op_info.name]()
  2638. for (shape, memory_format), kwargs in params_generator.gen_input_params():
  2639. arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad)
  2640. yield SampleInput(arg, kwargs=kwargs)
  2641. def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs):
  2642. out, indices = torch.nn.functional.max_pool2d_with_indices(
  2643. *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True)
  2644. grad_out = torch.ones_like(out)
  2645. if stride is None:
  2646. stride = kernel_size
  2647. out_b = torch.ops.aten.max_pool2d_with_indices_backward.default(
  2648. grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices)
  2649. return out_b
  2650. def error_inputs_max_pool1d(op_info, device, **kwargs):
  2651. # Toggle requires_grad because `max_pool1d` has different path
  2652. # based on whether `requires_grad` is set or not.
  2653. for requires_grad in (True, False):
  2654. make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad)
  2655. # error inputs when pad is negative
  2656. x = make_arg((0, 1, 49))
  2657. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
  2658. error_regex='pad must be non-negative')
  2659. # error inputs when pad > kernel_size / 2
  2660. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
  2661. error_regex='pad should be at most half of kernel size')
  2662. # error inputs for input tensor
  2663. error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input'
  2664. yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}),
  2665. error_regex=error_msg)
  2666. # error inputs for empty input
  2667. yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad),
  2668. kwargs={'kernel_size': 1}),
  2669. error_regex=error_msg)
  2670. # error: unbatched input with 0 sized non-batch dims.
  2671. yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad),
  2672. kwargs={'kernel_size': 1}),
  2673. error_regex=error_msg)
  2674. # error: batched input with 0 sized non-batch dims.
  2675. yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad),
  2676. kwargs={'kernel_size': 1}),
  2677. error_regex=error_msg)
  2678. # error inputs for empty input with stride=0
  2679. error_msg = 'stride must be greater than zero, but got 0'
  2680. yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
  2681. error_regex=error_msg)
  2682. # error inputs for empty input with dilation=0
  2683. error_msg = 'dilation must be greater than zero, but got 0'
  2684. yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
  2685. kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
  2686. error_regex=error_msg)
  2687. # error inputs for invalid output size
  2688. error_msg = 'Invalid computed output size: -2'
  2689. yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
  2690. kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
  2691. error_regex=error_msg)
  2692. # error inputs when kernel_size=0
  2693. error_msg = 'kernel_size must be greater than zero'
  2694. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
  2695. error_regex=error_msg)
  2696. # error inputs for strides > 0
  2697. error_msg = 'stride must be greater than zero'
  2698. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
  2699. error_regex=error_msg)
  2700. def error_inputs_max_pool2d(op_info, device, **kwargs):
  2701. make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
  2702. # error inputs when pad is negative
  2703. x = make_arg((0, 1, 49))
  2704. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
  2705. error_regex='pad must be non-negative')
  2706. # 2-dimensional kernel
  2707. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}),
  2708. error_regex='pad must be non-negative')
  2709. # error inputs when pad > kernel_size / 2 (kernel_size : int)
  2710. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
  2711. error_regex='pad should be at most half of kernel size')
  2712. # error inputs when pad > kernel_size / 2 (kernel_size : tuple)
  2713. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}),
  2714. error_regex='pad should be at most half of kernel size')
  2715. # error: unbatched input with 0 sized non-batch dims.
  2716. err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input'
  2717. yield ErrorInput(SampleInput(make_arg((1, 0, 10)),
  2718. kwargs={'kernel_size': 1}),
  2719. error_regex=err_msg)
  2720. # error: batched input with 0 sized non-batch dims.
  2721. yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)),
  2722. kwargs={'kernel_size': 1}),
  2723. error_regex=err_msg)
  2724. def error_inputs_max_pool3d(op_info, device, **kwargs):
  2725. make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
  2726. # error inputs when pad is negative
  2727. x = make_arg((0, 1, 49, 50))
  2728. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
  2729. error_regex='pad must be non-negative')
  2730. # 3-dimensional kernel
  2731. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
  2732. 'padding': -1, 'return_indices': True}),
  2733. error_regex='pad must be non-negative')
  2734. # error inputs when pad > kernel_size / 2 (kernel_size: int)
  2735. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
  2736. error_regex='pad should be at most half of kernel size')
  2737. # error inputs when pad > kernel_size / 2 (kernel_size: tuple)
  2738. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
  2739. 'padding': 4, 'return_indices': True}),
  2740. error_regex='pad should be at most half of kernel size')
  2741. # error: unbatched input with 0 sized non-batch dims.
  2742. err_msg = r'Expected input\'s non-batch dimensions to have positive length'
  2743. yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)),
  2744. kwargs={'kernel_size': 1}),
  2745. error_regex=err_msg)
  2746. # error: batched inputs with 0 sized non-batch dims.
  2747. yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)),
  2748. kwargs={'kernel_size': 1}),
  2749. error_regex=err_msg)
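# A standalone sketch (illustrative only, not wired into the OpInfo machinery)
# of what these ErrorInput entries encode: calling the functional op with one
# of the offending configurations raises a RuntimeError whose message the
# error_regex above is matched against.
def _demo_max_pool_error_message():
    try:
        torch.nn.functional.max_pool2d(torch.rand(1, 1, 4, 4), kernel_size=2, padding=-1)
        raise AssertionError("expected max_pool2d to reject negative padding")
    except RuntimeError as e:
        assert 'pad must be non-negative' in str(e)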
  2750. def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
  2751. make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)
  2752. cases: Tuple[Tuple[int], dict] = ( # type: ignore[assignment]
  2753. ((2, 1, 4, 5), {'p': 1., 'dim': 2}),
  2754. ((2, 3, 4, 5), {'p': 2., 'dim': 1}),
  2755. ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}),
  2756. ((1, 3, 4, 5), {'p': -1., 'dim': 1}),
  2757. ((1, 3, 4, 5), {'p': 0., 'dim': -1}),
  2758. ((), {'p': 1.2, 'dim': 0}),
  2759. ((2, 3, 4, 5), {}),
  2760. ((2, 3, 4, 5), {'eps': 1e-4}))
  2761. for input_shape, kwargs in cases:
  2762. yield SampleInput(make_arg(input_shape), kwargs=kwargs)
  2763. def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups):
    # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
    # with
    #   a = conv(Wr, xr, br)
    #   b = conv(Wi, xi, 0)
    #   c = conv(Wr + Wi, xr + xi, br + bi)
    # this gives conv(W, x, b) = a - b + i(c - a - b)
  2769. grad_output_ = torch.view_as_real(grad_output)
  2770. grad_output_r = grad_output_[..., 0]
  2771. grad_output_i = grad_output_[..., 1]
  2772. weight_ = torch.view_as_real(weight)
  2773. weight_r = weight_[..., 0]
  2774. weight_i = weight_[..., 1]
  2775. a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups)
  2776. b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups)
  2777. c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups)
  2778. return (a - b) + 1j * (c - a - b)
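# A minimal standalone check (plain Python, illustrative only) of the
# Karatsuba-style identity `complex_conv` relies on: for any map that is
# bilinear in (weight, grad_output) -- ordinary multiplication stands in for
# the real convolution here -- three real evaluations a = f(Wr, Gr),
# b = f(Wi, Gi), c = f(Wr + Wi, Gr + Gi) assemble the complex result as
# (a - b) + i * (c - a - b).
def _demo_complex_from_three_real_products():
    w, g = complex(2.0, 3.0), complex(-1.0, 4.0)
    a = w.real * g.real
    b = w.imag * g.imag
    c = (w.real + w.imag) * (g.real + g.imag)
    assert complex(a - b, c - a - b) == w * g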
  2779. def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
  2780. output_padding=0, dilation=1, groups=1,
  2781. fn=None):
    # The derivative of `conv` is `conv_transpose`.
    # To verify the correctness of `conv_transpose`, we rely on the
    # `torch.nn.grad` implementation (which is tested in test_nn.py)
    # for floating dtypes.
  2786. assert fn is not None
  2787. grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
  2788. torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
  2789. torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
  2790. batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
  2791. torch.nn.functional.conv_transpose2d: 4,
  2792. torch.nn.functional.conv_transpose3d: 5}
  2793. # Input for `ref` is ndarray.
  2794. input, weight = torch.from_numpy(input), torch.from_numpy(weight)
  2795. is_batched = len(input.shape) == batched_dim_map[fn]
  2796. if not is_batched:
  2797. input = input.unsqueeze(0)
  2798. if bias is not None:
  2799. bias = torch.from_numpy(bias)
  2800. unsqueeze_dims = input.ndim - 2
  2801. for _ in range(unsqueeze_dims):
  2802. bias = bias.unsqueeze(1)
  2803. grad_output = input
  2804. # Get the input shape for grad_fn.
  2805. conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None,
  2806. stride=stride, padding=padding, output_padding=output_padding,
  2807. groups=groups, dilation=dilation)
  2808. input_size = conv_transpose_output.shape
  2809. grad_fn = grad_fn_map[fn]
  2810. if weight.dtype.is_complex:
  2811. out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups)
  2812. else: # Floating
  2813. out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups)
  2814. if bias is not None:
  2815. out = out + bias
  2816. return out.squeeze(0) if not is_batched else out
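# A small self-contained sketch (shapes chosen arbitrarily; stride=1,
# padding=0, dilation=1, groups=1, no bias assumed) of the identity
# `conv_transpose_ref` builds on: the transposed convolution equals the
# gradient of `conv1d` with respect to its input.
def _demo_conv_transpose_is_conv_input_grad():
    g = torch.randn(2, 4, 7)   # plays the role of grad_output: (N, C_out, L)
    w = torch.randn(4, 3, 5)   # conv1d-style weight: (C_out, C_in, K)
    via_transpose = torch.nn.functional.conv_transpose1d(g, w)
    via_grad = torch.nn.grad.conv1d_input((2, 3, 7 + 5 - 1), w, g)
    assert torch.allclose(via_transpose, via_grad)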
  2817. def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs):
  2818. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2819. # Ordered as shapes for input, weight, bias
  2820. # and a dict of values of (stride, padding, output_padding, groups, dilation)
  2821. cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment]
  2822. ((1, 3, 4), (3, 3, 3), (3,),
  2823. {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}),
  2824. ((2, 2, 4), (2, 2, 4), (4,),
  2825. {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}),
  2826. ((1, 1, 4), (1, 1, 4), (1,),
  2827. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}),
  2828. ((1, 1, 4), (1, 2, 3), None,
  2829. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
  2830. ((1, 4, 5), (4, 8, 3), None,
  2831. {})
  2832. )
  2833. for input_shape, weight, bias, kwargs in cases:
  2834. # Batched
  2835. yield SampleInput(make_arg(input_shape), args=(
  2836. make_arg(weight),
  2837. make_arg(bias) if bias is not None else bias
  2838. ), kwargs=kwargs)
  2839. # Unbatched
  2840. yield SampleInput(make_arg(input_shape[1:]), args=(
  2841. make_arg(weight),
  2842. make_arg(bias) if bias is not None else bias
  2843. ), kwargs=kwargs)
  2844. def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
  2845. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2846. # Ordered as shapes for input, weight, bias
  2847. # and a dict of values of (stride, padding, output_padding, groups, dilation)
  2848. cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment]
  2849. ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
  2850. {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
  2851. ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
  2852. {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
  2853. ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
  2854. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
  2855. ((1, 1, 4, 3), (1, 2, 3, 4), None,
  2856. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
  2857. ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}),
  2858. ((1, 2, 5, 5), (2, 4, 3, 3), None, {})
  2859. )
  2860. for input_shape, weight, bias, kwargs in cases:
  2861. # Batched
  2862. yield SampleInput(make_arg(input_shape), args=(
  2863. make_arg(weight),
  2864. make_arg(bias) if bias is not None else bias
  2865. ), kwargs=kwargs)
  2866. # Unbatched
  2867. yield SampleInput(make_arg(input_shape[1:]), args=(
  2868. make_arg(weight),
  2869. make_arg(bias) if bias is not None else bias
  2870. ), kwargs=kwargs)
  2871. def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs):
  2872. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2873. # Ordered as shapes for input, weight, bias
  2874. # and a dict of values of (stride, padding, output_padding, groups, dilation)
  2875. cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment]
  2876. ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,),
  2877. {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}),
  2878. ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,),
  2879. {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}),
  2880. ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,),
  2881. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}),
  2882. ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None,
  2883. {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
  2884. ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None,
  2885. {})
  2886. )
  2887. for input_shape, weight, bias, kwargs in cases:
  2888. # Batched
  2889. yield SampleInput(make_arg(input_shape), args=(
  2890. make_arg(weight),
  2891. make_arg(bias) if bias is not None else bias
  2892. ), kwargs=kwargs)
  2893. # Unbatched
  2894. yield SampleInput(make_arg(input_shape[1:]), args=(
  2895. make_arg(weight),
  2896. make_arg(bias) if bias is not None else bias
  2897. ), kwargs=kwargs)
  2898. def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
  2899. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2900. # Ordered as shapes for input, weight, bias,
  2901. # and a dict of values of (stride, padding, dilation, groups)
  2902. cases: Tuple = (
  2903. ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
  2904. ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
  2905. ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
  2906. ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}),
  2907. # With defaults
  2908. ((1, 4, 5), (3, 4, 3), None, {}),
  2909. )
  2910. # TODO: (@krshrimali), add error_inputs_func once https://github.com/pytorch/pytorch/pull/67354 is merged
  2911. # Should replace test_conv_modules_raise_error_on_incorrect_input_size and test_conv_shapecheck
  2912. # in test/test_nn.py
  2913. for input_shape, weight, bias, kwargs in cases:
  2914. # Batched
  2915. yield SampleInput(make_arg(input_shape), args=(
  2916. make_arg(weight),
  2917. make_arg(bias) if bias is not None else bias
  2918. ), kwargs=kwargs)
  2919. # Unbatched
  2920. yield SampleInput(make_arg(input_shape[1:]), args=(
  2921. make_arg(weight),
  2922. make_arg(bias) if bias is not None else bias
  2923. ), kwargs=kwargs)
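# A brief usage sketch (illustrative shapes) of what the batched/unbatched
# pairs above exercise: `conv1d` accepts both (N, C_in, L) and unbatched
# (C_in, L) inputs, and string padding such as 'same' keeps the length when
# stride is 1.
def _demo_conv1d_batched_and_unbatched():
    w = torch.randn(3, 4, 3)   # (C_out, C_in, K)
    batched = torch.nn.functional.conv1d(torch.randn(1, 4, 5), w, padding='same')
    unbatched = torch.nn.functional.conv1d(torch.randn(4, 5), w, padding='same')
    assert batched.shape == (1, 3, 5) and unbatched.shape == (3, 5)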
  2924. def error_inputs_conv1d(opinfo, device, **kwargs):
  2925. input = torch.randn(size=(33, 16, 30), device=device, dtype=torch.float64)
  2926. weight = torch.randn(size=(20, 16, 5), device=device, dtype=torch.float64)
  2927. groups = 0
  2928. yield ErrorInput(
  2929. SampleInput(input, kwargs={"weight": weight, "groups": groups}),
  2930. error_regex="non-positive groups is not supported"
  2931. )
  2932. def error_inputs_conv2d(opinfo, device, **kwargs):
  2933. weight = torch.randint(high=10, size=(3, 2, 3, 3), device=device)
  2934. input = torch.randint(high=10, size=(2, 4, 4), device=device)
  2935. bias = torch.rand((3,), dtype=torch.float32, device=device)
  2936. yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same")
  2937. weight = torch.rand(size=(3, 2, 3, 3), device=device, dtype=torch.float64)
  2938. input = torch.rand(size=(2, 4, 4), device=device, dtype=torch.float64)
  2939. bias = torch.rand((3,), dtype=torch.complex128, device=device)
  2940. yield ErrorInput(SampleInput(input, args=(weight, bias)), error_regex="should be the same")
  2941. input = torch.randn(size=(1, 4, 5, 5), device=device, dtype=torch.float64)
  2942. weight = torch.randn(size=(8, 4, 3, 3), device=device, dtype=torch.float64)
  2943. groups = 0
  2944. yield ErrorInput(
  2945. SampleInput(input, kwargs={"weight": weight, "groups": groups}),
  2946. error_regex="non-positive groups is not supported"
  2947. )
  2948. def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs):
  2949. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2950. # Ordered as shapes for input, weight, bias
  2951. # and a dict of values of (stride, padding, groups, dilation)
  2952. cases: Tuple = (
  2953. ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
  2954. {'stride': (2, 2), 'padding': 2, 'groups': 1}),
  2955. ((2, 4, 8, 8), (2, 2, 3, 3), (2,),
  2956. {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}),
  2957. ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
  2958. {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
  2959. ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
  2960. {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
  2961. ((1, 2, 4, 3), (4, 2, 3, 4), None,
  2962. {'stride': 2, 'padding': 1, 'groups': 1}),
  2963. ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
  2964. {'stride': 2, 'padding': "valid"}),
  2965. ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
  2966. {'stride': 1, 'padding': "same", 'dilation': 3}),
  2967. # Below are the group related samples from common_nn.py
  2968. ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}),
  2969. ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}),
  2970. ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}),
  2971. ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}),
  2972. ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}),
  2973. ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}),
  2974. ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}),
  2975. # With defaults
  2976. ((1, 4, 5, 5), (3, 4, 3, 3), None, {}),
  2977. )
  2978. for input_shape, weight, bias, kwargs in cases:
  2979. # Batched
  2980. yield SampleInput(make_arg(input_shape), args=(
  2981. make_arg(weight),
  2982. make_arg(bias) if bias is not None else bias
  2983. ), kwargs=kwargs)
  2984. # Unbatched
  2985. yield SampleInput(make_arg(input_shape[1:]), args=(
  2986. make_arg(weight),
  2987. make_arg(bias) if bias is not None else bias
  2988. ), kwargs=kwargs)
  2989. def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs):
  2990. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  2991. # Ordered as input shape, num groups, and kwargs for eps
  2992. cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment]
  2993. ((1, 6, 3), 2, {'eps' : 0.5}),
  2994. ((2, 6, 3), 2, {'eps' : -0.5}),
  2995. ((1, 3), 1, {'eps' : 1e-5}),
  2996. ((0, 2), 1, {'eps' : 1e-5}),
  2997. ((S, S, S), 1, {'eps' : 0.5}),
  2998. )
    # num_channels is inferred from input.shape[1]
  3000. for input_shape, num_groups, kwargs in cases:
  3001. # Shape of weight and bias should be the same as num_channels
  3002. channels = input_shape[1] if len(input_shape) > 1 else 0
  3003. weight_tensor = make_arg(channels)
  3004. bias_tensor = make_arg(channels)
  3005. # Checking for permutations of weights and biases as `None`
  3006. weights = [weight_tensor, None]
  3007. biases = [bias_tensor, None]
        for weight, bias in itertools.product(weights, biases):
            # Build a fresh dict per iteration so the case kwargs are not
            # polluted with the previous weight/bias values (which would
            # silently drop the `None` permutations).
            sample_kwargs = {**kwargs, 'weight': weight, 'bias': bias}
            yield SampleInput(make_arg(input_shape), num_groups, **sample_kwargs)
  3015. # Without any optional args
  3016. yield SampleInput(make_arg((1, 2)), args=(1,))
  3017. def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs):
  3018. yield from sample_inputs_group_norm(
  3019. op_info, device, dtype, requires_grad, **kwargs)
  3020. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3021. # Ordered as input shape, num groups, and kwargs for eps
  3022. cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment]
  3023. ((20, 6, 10, 10), 3, {'eps' : 1e-5}),
  3024. # equivalent with InstanceNorm
  3025. # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C)
  3026. ((20, 6, 10, 10), 6, {'eps' : 1e-5}),
  3027. # equivalent with LayerNorm
  3028. # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False)
  3029. ((20, 6, 10, 10), 1, {'eps' : 1e-5}),
  3030. )
    # num_channels is inferred from input.shape[1]
  3032. for input_shape, num_groups, kwargs in cases:
  3033. # Shape of weight and bias should be the same as num_channels
  3034. channels = input_shape[1] if len(input_shape) > 1 else 0
  3035. input_tensor = make_arg(input_shape)
  3036. weight_tensor = make_arg(channels)
  3037. bias_tensor = make_arg(channels)
  3038. # Checking for permutations of weights and biases as `None`
  3039. weights = [weight_tensor, None]
  3040. biases = [bias_tensor, None]
        for weight, bias in itertools.product(weights, biases):
            # As above, use a fresh dict so the `None` permutations are not
            # overwritten by the previous iteration's weight/bias.
            sample_kwargs = {**kwargs, 'weight': weight, 'bias': bias}
            yield SampleInput(input_tensor, num_groups, **sample_kwargs)
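# A quick standalone check (affine parameters left at None, default eps) of
# the equivalences the reference cases above point at: group_norm with
# num_groups == C matches instance_norm, and num_groups == 1 matches
# layer_norm over (C, H, W).
def _demo_group_norm_equivalences():
    x = torch.randn(4, 6, 5, 5)
    assert torch.allclose(torch.nn.functional.group_norm(x, 6),
                          torch.nn.functional.instance_norm(x), atol=1e-4, rtol=1e-4)
    assert torch.allclose(torch.nn.functional.group_norm(x, 1),
                          torch.nn.functional.layer_norm(x, x.shape[1:]), atol=1e-4, rtol=1e-4)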
  3048. def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
  3049. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3050. make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
  3051. # Ordered as: input shape, kwargs for momentum, eps
  3052. cases: Tuple[Tuple[int], dict] = ( # type: ignore[assignment]
  3053. ((S, S, S), {'momentum': 0.5, 'eps': 0.6}),
  3054. ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}),
  3055. ((3, 2, 4), {'momentum': -1.2}),
  3056. ((3, 2, 4), {'momentum': 0.0}),
  3057. ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
  3058. ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
  3059. )
  3060. for input_shape, kwargs in cases:
  3061. # args: running mean, running var, weight and bias should necessarily be of shape: (channels,)
  3062. channels = input_shape[1]
  3063. weight = make_arg(channels)
  3064. bias = make_arg(channels)
  3065. running_mean = make_arg_without_requires_grad(channels, low=0)
  3066. running_var = make_arg_without_requires_grad(channels, low=0)
  3067. new_kwargs = {
  3068. 'running_mean': running_mean,
  3069. 'running_var': running_var,
  3070. 'weight': weight,
  3071. 'bias': bias,
  3072. **kwargs
  3073. }
  3074. yield SampleInput(
  3075. make_arg(input_shape),
  3076. args=(),
  3077. kwargs=new_kwargs
  3078. )
  3079. # Checking for permutations of weights and biases as `None`
  3080. # instance_norm assumes that if there's a bias, there's a weight
  3081. weights = [channels, None]
  3082. biases = [None, None]
  3083. for weight_channels, bias_channels in zip(weights, biases):
  3084. running_mean = make_arg_without_requires_grad(channels, low=0)
  3085. running_var = make_arg_without_requires_grad(channels, low=0)
  3086. yield SampleInput(
  3087. make_arg(input_shape),
  3088. args=(),
  3089. kwargs={
  3090. 'running_mean': running_mean,
  3091. 'running_var': running_var,
  3092. 'weight': make_arg(weight_channels) if weight_channels is not None else None,
  3093. 'bias': make_arg(bias_channels) if bias_channels is not None else None
  3094. }
  3095. )
  3096. # Test case for no optional kwargs
  3097. yield SampleInput(make_arg((1, 2, 3)), kwargs={})
  3098. def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
  3099. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3100. # Ordered as input shape, normalized_shape and a kwarg dict for eps
  3101. cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment]
  3102. ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
  3103. ((2, 2, 3), (2, 3), {'eps': -0.5}),
  3104. ((1,), (1,), {}),
  3105. ((1, 2), (2,), {}),
  3106. ((0, 1), (1,), {}),
  3107. )
  3108. for input_shape, normalized_shape, kwargs in cases:
  3109. # Shape of weight and bias should be the same as normalized_shape
  3110. weight = make_arg(normalized_shape)
  3111. bias = make_arg(normalized_shape)
  3112. yield SampleInput(
  3113. make_arg(input_shape),
  3114. args=(normalized_shape, weight, bias),
  3115. kwargs=kwargs
  3116. )
  3117. # Without any optional args
  3118. yield SampleInput(make_arg((1, 2)), args=((2,),))
  3119. # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs,
  3120. # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400
  3121. # With weight and a `None` bias
  3122. # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None))
  3123. # With `None` weight and bias (tests failing for this, see the link above)
  3124. # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
  3125. def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
  3126. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3127. # Ordered as input shape, normalized_shape, eps
  3128. cases: Tuple[Tuple[int], Tuple[int], float] = ( # type: ignore[assignment]
  3129. ((1, 2, 3), (1, 2, 3), 0.5),
  3130. ((2, 2, 3), (2, 3), -0.5),
  3131. ((1,), (1,), 1e-5),
  3132. ((1, 2), (2,), 1e-5),
  3133. ((0, 1), (1,), 1e-5),
  3134. )
  3135. for input_shape, normalized_shape, eps in cases:
  3136. # Shape of weight and bias should be the same as normalized_shape
  3137. weight = make_arg(normalized_shape)
  3138. bias = make_arg(normalized_shape)
  3139. yield SampleInput(
  3140. make_arg(input_shape),
  3141. args=(normalized_shape, weight, bias, eps),
  3142. )
  3143. yield SampleInput(
  3144. make_arg(input_shape),
  3145. args=(normalized_shape, None, bias, eps),
  3146. )
  3147. yield SampleInput(
  3148. make_arg(input_shape),
  3149. args=(normalized_shape, weight, None, eps),
  3150. )
  3151. yield SampleInput(
  3152. make_arg(input_shape),
  3153. args=(normalized_shape, None, None, eps),
  3154. )
  3155. def error_inputs_group_norm(opinfo, device, **kwargs):
  3156. make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
  3157. # check that input has minimum number of dimensions
  3158. err_msg1 = "Expected at least 2 dimensions for input tensor but received"
  3159. s1 = SampleInput(make_arg((1)), args=(1,))
  3160. yield ErrorInput(s1, error_regex=err_msg1)
  3161. # check that the channels dimension is compatible with number of groups
  3162. err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape"
  3163. s2 = SampleInput(make_arg((2, 7, 4)), args=(2,))
  3164. yield ErrorInput(s2, error_regex=err_msg2)
  3165. def error_inputs_native_layer_norm(opinfo, device, **kwargs):
  3166. make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
  3167. input_shape = (1, 2, 3)
  3168. err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
  3169. s1 = SampleInput(
  3170. make_arg(input_shape), args=(tuple(), None, None, 1e-5)
  3171. )
  3172. yield ErrorInput(s1, error_regex=err_msg1)
  3173. normalized_shape = (1, 2, 3)
  3174. weight = make_arg((1, 2))
  3175. err_msg2 = "Expected weight to be of same shape as normalized_shape"
  3176. s2 = SampleInput(
  3177. make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5)
  3178. )
  3179. yield ErrorInput(s2, error_regex=err_msg2)
  3180. bias = make_arg((1, 2))
  3181. err_msg3 = "Expected bias to be of same shape as normalized_shape"
  3182. s3 = SampleInput(
  3183. make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5)
  3184. )
  3185. yield ErrorInput(s3, error_regex=err_msg3)
  3186. err_msg4 = "Given normalized_shape="
  3187. s4 = SampleInput(
  3188. make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5)
  3189. )
  3190. yield ErrorInput(s4, error_regex=err_msg4)
  3191. def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
  3192. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3193. # Ordered as input shape, size and a kwarg dict for alpha, beta, and k
  3194. cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment]
  3195. ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
  3196. ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}),
  3197. ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}),
  3198. ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}),
  3199. ((1, 6, 3), 2, {'alpha': 3e-05}),
  3200. ((1, 6, 3), 2, {'beta': 0.5}),
  3201. ((1, 6, 3), 2, {'k': 1.25}),
  3202. ((1, 6, 3), 2, {}),
  3203. ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
  3204. ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
  3205. ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
  3206. )
  3207. for input_shape, size, kwargs in cases:
  3208. yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
  3209. def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
  3210. N = 5
    # Make sure the -3 to 3 range is covered; the default is -10 to 10, so this may be unnecessary.
  3212. make_arg = partial(make_tensor, device=device, dtype=dtype,
  3213. requires_grad=requires_grad, low=-5, high=5)
  3214. return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N))
  3215. def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
  3216. features_options = [[3, 4], [8, 8]]
  3217. batch_options: List[List[int]] = [
  3218. [], # no batch
  3219. [0],
  3220. [8],
  3221. [2, 3],
  3222. ]
  3223. create_tensor = partial(make_tensor, device=device, dtype=dtype,
  3224. requires_grad=requires_grad, low=-2, high=2)
  3225. for has_bias, (in_feat, out_feat), batch_shape in \
  3226. itertools.product([True, False], features_options, batch_options):
  3227. input_tensor = create_tensor(batch_shape + [in_feat])
  3228. weight = create_tensor([out_feat, in_feat])
  3229. if not has_bias:
  3230. yield SampleInput(input_tensor, weight)
  3231. continue
  3232. bias = create_tensor([out_feat])
  3233. yield SampleInput(input_tensor, weight, bias)
  3234. def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
  3235. features_options = [[3, 4, 5], [8, 8, 8]]
  3236. batch_options: List[List[int]] = [
  3237. [], # no batch
  3238. [0],
  3239. [8],
  3240. [2, 3],
  3241. ]
  3242. create_tensor = partial(make_tensor, device=device, dtype=dtype,
  3243. requires_grad=requires_grad, low=-2, high=2)
  3244. for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \
  3245. itertools.product([True, False], features_options, batch_options):
  3246. input_tensor1 = create_tensor(batch_shape + [in_feat1])
  3247. input_tensor2 = create_tensor(batch_shape + [in_feat2])
  3248. weight = create_tensor([out_feat, in_feat1, in_feat2])
  3249. if not has_bias:
  3250. yield SampleInput(input_tensor1, input_tensor2, weight)
  3251. continue
  3252. bias = create_tensor([out_feat])
  3253. yield SampleInput(input_tensor1, input_tensor2, weight, bias)
  3254. def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs):
  3255. features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]]
  3256. batch_options: List[List[int]] = [
  3257. [], # no batch
  3258. [0],
  3259. [8],
  3260. [2, 3],
  3261. ]
  3262. create_tensor = partial(make_tensor, device=device, dtype=dtype,
  3263. requires_grad=requires_grad, low=-2, high=2)
  3264. for features, batch_shape in itertools.product(features_options, batch_options):
  3265. ndim = len(features) + len(batch_shape)
  3266. for dim in range(ndim):
  3267. input_tensor = create_tensor(batch_shape + features)
  3268. dim_size = input_tensor.size(dim)
  3269. if dim_size > 0 and dim_size % 2 == 0:
  3270. yield SampleInput(input_tensor, dim)
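# A tiny illustration (independent of the sampler above) of why only even,
# non-zero sizes are yielded for the chosen dim: glu splits that dimension in
# half and raises for odd sizes.
def _demo_glu_halves_even_dims():
    x = torch.randn(2, 8, 3)
    assert torch.nn.functional.glu(x, dim=1).shape == (2, 4, 3)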
  3271. def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs):
  3272. N, C = 2, 3
  3273. D = 4
  3274. S = 3
  3275. L = 5
  3276. align_corners_options: Tuple[Any, ...] = (None,)
  3277. if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
  3278. align_corners_options = (True, False, None)
  3279. ranks_for_mode = {
  3280. 'nearest': [1, 2, 3],
  3281. 'linear': [1],
  3282. 'bilinear': [2],
  3283. 'bicubic': [2],
  3284. 'trilinear': [3],
  3285. 'area': [1, 2, 3]
  3286. }
  3287. def shape(size, rank, with_batch_channel=True):
  3288. if with_batch_channel:
  3289. return tuple([N, C] + ([size] * rank))
  3290. return tuple([size] * rank)
  3291. make_arg = partial(make_tensor, device=device, dtype=dtype,
  3292. requires_grad=requires_grad, low=-1, high=1)
  3293. for align_corners in align_corners_options:
  3294. for rank in ranks_for_mode[mode]:
  3295. yield SampleInput(make_arg(shape(D, rank)),
  3296. shape(S, rank, False), None, mode, align_corners)
  3297. yield SampleInput(make_arg(shape(D, rank)),
  3298. shape(L, rank, False), None, mode, align_corners)
  3299. for recompute_scale_factor in [False, True]:
  3300. yield SampleInput(make_arg(shape(D, rank)),
  3301. None, 1.7, mode, align_corners,
  3302. recompute_scale_factor=recompute_scale_factor)
  3303. yield SampleInput(make_arg(shape(D, rank)),
  3304. None, 0.6, mode, align_corners,
  3305. recompute_scale_factor=recompute_scale_factor)
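# A short usage sketch (hypothetical shapes) of the two ways the sampler above
# drives `interpolate`: an explicit output `size`, or a `scale_factor`
# (optionally recomputed from the rounded output size).
def _demo_interpolate_size_vs_scale():
    x = torch.randn(2, 3, 4, 4)
    by_size = torch.nn.functional.interpolate(x, size=(3, 3), mode='bilinear', align_corners=False)
    by_scale = torch.nn.functional.interpolate(x, scale_factor=1.7, mode='bilinear',
                                               align_corners=False, recompute_scale_factor=True)
    assert by_size.shape == (2, 3, 3, 3) and by_scale.shape == (2, 3, 6, 6)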
  3306. def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
  3307. N, C = 2, 3
  3308. D = 4
  3309. S = 3
  3310. L = 5
  3311. ranks_for_mode = {
  3312. 'nearest': [1, 2, 3],
  3313. 'bilinear': [2],
  3314. }
  3315. def shape(size, rank, with_batch_channel=True):
  3316. if with_batch_channel:
  3317. return torch.Size([N, C] + ([size] * rank))
  3318. return torch.Size([size] * rank)
  3319. make_arg = partial(make_tensor, device=device, dtype=dtype,
  3320. requires_grad=requires_grad, low=-1, high=1)
  3321. for rank in ranks_for_mode[mode]:
  3322. yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False))
  3323. yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False))
  3324. yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7)
  3325. yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6)
  3326. def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
  3327. N = 5
  3328. for _ in range(1, N):
  3329. for approximate in ['none', 'tanh']:
  3330. yield SampleInput(
  3331. make_tensor((N * 2, N * 2), device=device, dtype=dtype,
  3332. requires_grad=requires_grad, low=-3, high=3),
  3333. approximate=approximate)
  3334. def error_inputs_gelu(op, device, **kwargs):
    # Tests that gelu errors out when passed an approximation it does not recognize.
  3336. yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}),
  3337. error_regex="approximate argument must be either")
def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
  3340. args_for_reduction_with_dim = (
  3341. ((S, S, S), (1,),),
  3342. ((S, S, S), (1, True, ),),
  3343. ((), (0,),),
  3344. ((), (0, True,),),
  3345. )
  3346. return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device,
  3347. low=None, high=None,
  3348. requires_grad=requires_grad),
  3349. *args))
  3350. for input_tensor, args in args_for_reduction_with_dim)
  3351. def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs):
  3352. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
  3353. yield SampleInput(make_arg((S, S, S)))
  3354. yield SampleInput(make_arg(()))
  3355. def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs):
  3356. yield from _generate_reduction_inputs(device, dtype, requires_grad)
  3357. # NaN only exists for floating point numbers
  3358. if dtype.is_complex or dtype.is_floating_point:
  3359. yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad)
  3360. yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad)
  3361. def sample_inputs_nan_reduction(supports_multiple_dims):
  3362. # Generates sample inputs for reduction ops that contain the input tensor
  3363. # and dim and keepdim kwargs. If a reduction op needs to test additional
  3364. # args/kwargs then create a separate sample_inputs function
  3365. def fn(op_info, device, dtype, requires_grad, **kwargs):
  3366. for t in _generate_nan_reduction_inputs(device, dtype, requires_grad):
  3367. # Add case without dim and keepdim kwargs
  3368. yield SampleInput(t.clone().requires_grad_(requires_grad))
  3369. for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
  3370. yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs)
  3371. return fn
  3372. def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs):
  3373. test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad))
  3374. test_interpolations = ['linear', 'midpoint']
  3375. for quantiles in test_quantiles:
  3376. for t in _generate_reduction_inputs(device, dtype, requires_grad):
  3377. # Add case without dim and keepdim kwargs
  3378. input = t.clone().requires_grad_(requires_grad)
  3379. yield SampleInput(input, quantiles)
  3380. for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False):
  3381. # Interpolation kwarg for now is only supported when providing both dim and keepdim
  3382. kwargs.setdefault('dim', 0)
  3383. kwargs.setdefault('keepdim', False)
  3384. for interpolation in test_interpolations:
  3385. kwargs['interpolation'] = interpolation
  3386. input = t.clone().requires_grad_(requires_grad)
  3387. yield SampleInput(input, quantiles, **kwargs)
  3388. def sample_inputs_reduction_count_nonzero(*args, **kwargs):
  3389. """Sample inputs for count_nonzero"""
  3390. # count_nonzero does not support keepdim yet
  3391. for sample in sample_inputs_reduction(*args, **kwargs):
  3392. sample.kwargs.pop('keepdim', None)
  3393. yield sample
  3394. def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs):
  3395. N = 10
  3396. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3397. return (SampleInput(make_arg((N, N))) for _ in range(1, N))
  3398. def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
  3399. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3400. # Order: input_shape, kernel_size
  3401. cases = (((1, 3, 9, 9), 3),
  3402. ((1, 3, 9, 9), (4, 4)),
  3403. ((1, 3, 9, 9), (6, 6)),
  3404. ((2, 3, 9, 9), (3, 3)),
  3405. ((1, 1, 4, 4), (2, 2)),
  3406. ((1, 2, 6, 6), (4, 4)))
  3407. for input_shape, kernel_size in cases:
  3408. for return_indices in [False, True]:
  3409. # test case passing a single output size
  3410. yield SampleInput(
  3411. make_arg(input_shape),
  3412. kernel_size,
  3413. output_size=2,
  3414. return_indices=return_indices,
  3415. )
  3416. # test case passing a tuple output size
  3417. yield SampleInput(
  3418. make_arg(input_shape),
  3419. kernel_size,
  3420. output_size=(2, 3),
  3421. return_indices=return_indices,
  3422. )
  3423. # test case passing an output ratio
  3424. yield SampleInput(
  3425. make_arg(input_shape),
  3426. kernel_size,
  3427. output_ratio=(0.5, 0.5),
  3428. return_indices=return_indices,
  3429. )
  3430. def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
  3431. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3432. # Order: input_shape, kernel_size
  3433. cases = (((2, 3, 5, 5, 5), (2, 2, 2)),
  3434. ((1, 2, 6, 5, 4), 2),
  3435. ((1, 2, 5, 6, 5), (2, 3, 2)),
  3436. ((1, 2, 6, 6, 6), (2, 3, 2)),
  3437. ((1, 1, 7, 6, 7), (2, 3, 4)),
  3438. ((1, 1, 4, 5, 4), (2, 2, 1)),
  3439. ((1, 1, 8, 7, 6), (4, 3, 2)),
  3440. ((0, 1, 4, 5, 4), (2, 2, 1)))
  3441. for input_shape, kernel_size in cases:
  3442. for return_indices in [False, True]:
  3443. # test case passing a single output size
  3444. yield SampleInput(
  3445. make_arg(input_shape),
  3446. kernel_size,
  3447. output_size=2,
  3448. return_indices=return_indices,
  3449. )
  3450. # test case passing a tuple output size
  3451. yield SampleInput(
  3452. make_arg(input_shape),
  3453. kernel_size,
  3454. output_size=(2, 3, 2),
  3455. return_indices=return_indices,
  3456. )
  3457. # test case passing an output ratio
  3458. yield SampleInput(
  3459. make_arg(input_shape),
  3460. kernel_size,
  3461. output_ratio=(0.5, 0.5, 0.5),
  3462. return_indices=return_indices,
  3463. )
  3464. def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs):
  3465. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3466. # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
  3467. cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2),
  3468. ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2),
  3469. ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2),
  3470. ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2),
  3471. ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2),
  3472. ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None))
  3473. for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases:
  3474. yield SampleInput(make_arg(input_shape),
  3475. args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override))
  3476. # Case with just input_shape and kernel_size
    yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3),))
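# A compact illustration (not used by the sampler) of the less common knob
# exercised above: `divisor_override` replaces the kernel-area divisor, so a
# 2x2 average pool with divisor_override=1 degenerates into a sliding-window
# sum.
def _demo_avg_pool2d_divisor_override():
    x = torch.ones(1, 1, 4, 4)
    summed = torch.nn.functional.avg_pool2d(x, kernel_size=2, divisor_override=1)
    assert torch.allclose(summed, torch.full((1, 1, 2, 2), 4.0))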
  3478. def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs):
  3479. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3480. # Order: input_shape, kernel_size, kwargs
  3481. cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [
  3482. ((2, 3, 9), (3,), {}),
  3483. ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)),
  3484. ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)),
  3485. ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)),
  3486. ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)),
  3487. ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)),
  3488. ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)),
  3489. ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)),
  3490. ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)),
  3491. ]
  3492. for input_shape, kernel_size, kwargs in cases:
  3493. yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
  3494. def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs):
  3495. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3496. # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
  3497. cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [
  3498. ((2, 3, 3, 4, 4), (2, 2, 2), {}),
  3499. ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True,
  3500. count_include_pad=False, divisor_override=2)),
  3501. ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True,
  3502. count_include_pad=True, divisor_override=2)),
  3503. ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)),
  3504. ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False,
  3505. count_include_pad=False, divisor_override=2)),
  3506. ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False,
  3507. count_include_pad=True, divisor_override=-2)),
  3508. ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True,
  3509. count_include_pad=True, divisor_override=None)),
  3510. ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False,
  3511. count_include_pad=True, divisor_override=None)),
  3512. ]
  3513. for input_shape, kernel_size, kwargs in cases:
  3514. yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
  3515. def error_inputs_avg_pool1d(op_info, device, **kwargs):
  3516. # error inputs when pad is negative
  3517. x = torch.rand([0, 1, 49], dtype=torch.float32)
  3518. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
  3519. error_regex='pad must be non-negative')
  3520. # error inputs when pad > kernel_size / 2
  3521. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
  3522. error_regex='pad should be at most half of kernel size')
  3523. def error_inputs_avg_pool2d(op_info, device, **kwargs):
  3524. # error inputs when pad is negative
  3525. x = torch.rand([0, 1, 49], dtype=torch.float32)
  3526. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
  3527. error_regex='pad must be non-negative')
  3528. # 2-dimensional kernel
  3529. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}),
  3530. error_regex='pad must be non-negative')
  3531. # error inputs when pad > kernel_size / 2
  3532. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
  3533. error_regex='pad should be at most half of kernel size')
  3534. # 2-dimensional kernel
  3535. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}),
  3536. error_regex='pad should be at most half of kernel size')
  3537. # error inputs for zero divisor
  3538. x = torch.zeros(3, 3, 3)
  3539. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}),
  3540. error_regex='divisor must be not zero')
  3541. def error_inputs_avg_pool3d(op_info, device, **kwargs):
  3542. # error inputs when pad is negative
  3543. x = torch.rand([0, 1, 49, 50], dtype=torch.float32)
  3544. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
  3545. error_regex='pad must be non-negative')
  3546. # 3-dimensional kernel
  3547. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}),
  3548. error_regex='pad must be non-negative')
  3549. # error inputs when pad > kernel_size / 2
  3550. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
  3551. error_regex='pad should be at most half of kernel size')
  3552. # 3-dimensional kernel
  3553. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}),
  3554. error_regex='pad should be at most half of kernel size')
  3555. # error inputs for zero divisor
  3556. x = torch.zeros(3, 3, 3, 3)
  3557. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}),
  3558. error_regex='divisor must be not zero')
  3559. # error inputs for invalid input dimension
  3560. x = torch.rand([0, 1, 49], dtype=torch.float32)
  3561. yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}),
  3562. error_regex='non-empty 4D or 5D')
  3563. def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs):
  3564. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3565. # test_multiple_devices_to_cuda would fail if we use a different device than given
  3566. devices = [device]
  3567. if torch.device(device).type == 'cpu':
  3568. devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices
  3569. memory_formats = [torch.preserve_format, torch.channels_last]
  3570. # TODO: can't switch `to.device` overload to use positional arguments
  3571. # https://github.com/pytorch/pytorch/issues/84265
  3572. # to.device overload
  3573. for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
  3574. kwargs = {
  3575. "memory_format": mem_f,
  3576. }
  3577. yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs)
  3578. # to.dtype overload
  3579. for nb, cp, mem_f in product([True, False], [True, False], memory_formats):
  3580. kwargs = {
  3581. "memory_format": mem_f,
  3582. }
  3583. yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs)
  3584. # to.other overload
  3585. for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
  3586. kwargs = {
  3587. "memory_format": mem_f,
  3588. }
  3589. other = make_arg((S, S, S, S), dtype=torch.float64, device=device)
  3590. yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs)
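# A brief sketch (arbitrary shapes) of the three `Tensor.to` overloads the
# sampler above targets: dtype-only, device (+ dtype), and tensor "other".
def _demo_tensor_to_overloads():
    t = torch.randn(2, 2)
    assert t.to(torch.float64).dtype == torch.float64
    assert t.to('cpu', torch.float16).dtype == torch.float16
    assert t.to(torch.zeros(1, dtype=torch.int32)).dtype == torch.int32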
  3591. def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs):
  3592. def get_tensor_input(size):
  3593. return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad)
  3594. yield SampleInput(get_tensor_input((S, M, S)), 3)
  3595. yield SampleInput(get_tensor_input((S, M, S)), 3, 1)
  3596. yield SampleInput(get_tensor_input((S, M, S)), 3, -2)
  3597. yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True)
  3598. yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True)
  3599. yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True)
  3600. yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True)
  3601. yield SampleInput(get_tensor_input(()), 1)
  3602. yield SampleInput(get_tensor_input(()), 1, 0)
  3603. yield SampleInput(get_tensor_input(()), 1, -1)
  3604. yield SampleInput(get_tensor_input(()), 1, 0, True)
  3605. yield SampleInput(get_tensor_input(()), 1, -1, True)
  3606. yield SampleInput(get_tensor_input(()), 1, 0, True, True)
  3607. yield SampleInput(get_tensor_input(()), 1, -1, True, True)
  3608. def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs):
  3609. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3610. yield SampleInput(make_arg(S), make_arg(M))
  3611. def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs):
  3612. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3613. sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S))
  3614. ps = (2, 4)
  3615. for size_x, size_y, p in product(sizes, sizes, ps):
  3616. yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p))
# Still missing: a test for the nondeterminism of the operation
# https://github.com/pytorch/pytorch/issues/53352
  3619. def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs):
  3620. # target.index_select(dim, idx)
  3621. select = "index_select" in op_info.name
  3622. # target.index_add(dim, idx, source, *, alpha=1)
  3623. add = "index_add" in op_info.name
  3624. # target.index_copy(dim, idx, source)
  3625. copy = "index_copy" in op_info.name
  3626. # target.index_fill(dim, idx, value)
  3627. fill = "index_fill" in op_info.name
    # Extended reference inputs. We generate inputs that exercise atomic adds
    # and writing several times to one location.
  3630. if reference:
  3631. make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad)
  3632. make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
  3633. else:
  3634. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
        # idx: the indices need to be distinct for copy and add to be deterministic
  3636. if copy or add:
  3637. make_idx = partial(torch.randperm, device=device, dtype=torch.int64)
  3638. else:
  3639. def make_idx(n):
  3640. return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n)
  3641. shapes = [(), (1,), (S, S)]
  3642. # extra parameter for add
  3643. if add:
  3644. if dtype == torch.bool:
  3645. alphas = (True, False)
  3646. else:
  3647. alphas = (-1, 0, 2)
  3648. else:
  3649. alphas = (None,)
  3650. for shape, alpha in product(shapes, alphas):
  3651. t = make_arg(shape)
  3652. args = []
  3653. # dim. We handle the scalar case
  3654. dim = 1 if t.ndim == 2 else 0
  3655. args.append(dim)
  3656. idx = make_idx(t.shape[dim] if t.ndim != 0 else 1)
  3657. args.append(idx)
  3658. # source
  3659. if copy or add:
  3660. args.append(make_arg(shape))
  3661. elif fill:
  3662. # A weird number to catch errors
  3663. args.append(make_arg((1,)).item())
  3664. args = tuple(args)
  3665. kwargs = {} if alpha is None else {"alpha": alpha}
  3666. yield SampleInput(t, args=args, kwargs=kwargs)
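# A small aside (separate from the sampler above) on why `index_copy` and
# `index_add` get a `randperm` index: with repeated indices the result of
# `index_copy` is unspecified and `index_add` hits the same location more than
# once, so distinct indices keep the samples deterministic. With a permutation
# every row is touched exactly once:
def _demo_index_add_with_distinct_indices():
    t = torch.zeros(3, 2)
    out = t.index_add(0, torch.randperm(3), torch.ones(3, 2), alpha=2)
    assert torch.allclose(out, torch.full((3, 2), 2.0))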
  3667. def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs):
  3668. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3669. def make_idx(n, m):
  3670. return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m)
  3671. shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))]
  3672. include_selfs = (True, False)
  3673. reduces = ('prod', 'mean', 'amin', 'amax')
  3674. for shape, include_self, reduce in product(shapes, include_selfs, reduces):
  3675. self_shape, src_shape = shape
  3676. # dim. We handle the scalar case
  3677. dim = 1 if len(self_shape) >= 2 else 0
  3678. idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1,
  3679. self_shape[dim] if len(self_shape) != 0 else 1)
  3680. args = (dim, idx, make_arg(src_shape), reduce)
  3681. yield SampleInput(make_arg(self_shape),
  3682. args=args,
  3683. kwargs={'include_self' : include_self})
  3684. # Sample inputs to test edge cases for backward
  3685. if requires_grad:
  3686. # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
  3687. # This sample tests gradients for the following cases
  3688. # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0]))
  3689. # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1])
  3690. # (c) no zeros reduced (self[2, 1], self[2, 2])
  3691. # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
  3692. # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
  3693. input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
  3694. src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad)
  3695. idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device)
  3696. yield SampleInput(input,
  3697. args=(0, idx, src, 'prod'),
  3698. kwargs={'include_self': True})
  3699. def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs):
  3700. args = (
  3701. ((S, S, S), (),),
  3702. ((S, S, S), (1, ),),
  3703. ((S, S, S), (1, True, ),),
  3704. ((), (),),
  3705. ((), (0,),),
  3706. ((), (0, True,),),
  3707. # Non-fused mode kernel on CUDA
  3708. ((3000,), ()),
  3709. )
  3710. make_arg = partial(make_tensor, dtype=dtype, device=device,
  3711. requires_grad=requires_grad, low=None, high=None)
  3712. return (SampleInput(make_arg(input_tensor), *args)
  3713. for input_tensor, args in args)
# Still missing: a test for the nondeterminism of the operation
# https://github.com/pytorch/pytorch/issues/53352
  3716. def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs):
  3717. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  3718. make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
  3719. S = 3
  3720. # Generic inputs
  3721. idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S]
  3722. idx_list = [idx, -idx - 1]
  3723. for idx, acc in product(idx_list, (True, False)):
  3724. yield SampleInput(input=make_arg((S, S)),
  3725. args=(idx.clone(),
  3726. make_arg((S,)),
  3727. acc))
  3728. # Scalar cases
  3729. scalar_sizes = [(), (1,)]
  3730. tgt_gen = (make_arg(size) for size in scalar_sizes)
  3731. idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
  3732. src_gen = (make_arg(size) for size in scalar_sizes)
  3733. for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)):
  3734. yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
  3735. args=(idx.clone(),
  3736. src.clone().requires_grad_(requires_grad),
  3737. acc))
  3738. # Empty cases
  3739. tgt_sizes = [(0,), (), (1,), (3, 2)]
  3740. tgt_gen = (make_arg(size) for size in tgt_sizes)
  3741. idx = make_idx((0,), high=1)
  3742. src = make_arg((0,))
    for tgt, acc in product(tgt_gen, (True, False)):
  3744. yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
  3745. args=(idx.clone(),
  3746. src.clone().requires_grad_(requires_grad),
  3747. acc))
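# A short illustration (gradients ignored) of the indexing convention the
# generic cases above rely on: `put_` indexes `self` as if it were flattened
# to 1D, and a negative index counts back from the end of that flattened view,
# which is why `idx` and `-idx - 1` address mirrored positions.
def _demo_put_flattened_and_negative_indices():
    t = torch.zeros(2, 3)
    t.put_(torch.tensor([0, -1]), torch.tensor([1.0, 2.0]))
    assert torch.equal(t, torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 2.0]]))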
  3748. def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs):
  3749. make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
  3750. make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
  3751. S = 3
  3752. # Generic inputs: take S elements out of S * S
  3753. index = make_idx((S,), high=(S * S))
  3754. for idx in (index, -index - 1):
  3755. yield SampleInput(input=make_arg((S, S)), args=(idx,))
  3756. # Scalar cases
  3757. scalar_sizes = [(), (1,)]
  3758. src_gen = (make_arg(size) for size in scalar_sizes)
  3759. idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
  3760. for src, idx in product(src_gen, idx_gen):
  3761. yield SampleInput(input=src.clone().requires_grad_(requires_grad),
  3762. args=(idx.clone(),))
  3763. # Empty cases
  3764. src_sizes = [(0,), (), (1,), (3, 2)]
  3765. src_gen = (make_arg(size) for size in src_sizes)
  3766. idx = make_idx((0,), high=1)
  3767. for src in src_gen:
  3768. yield SampleInput(input=src.clone().requires_grad_(requires_grad),
  3769. args=(idx.clone(),))
  3770. def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
  3771. make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
  3772. yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0])
  3773. yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0])
  3774. def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
  3775. yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs)
  3776. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  3777. # shape, source, destination
  3778. args = (
  3779. # empty inputs
  3780. ((), (), ()),
  3781. # int inputs, negative
  3782. ((3, 5, 7, 2), -2, 1),
  3783. # swap bounds
  3784. ((3, 5, 7, 2), (-1, 0), (0, -1)),
  3785. # non-sequential, negative
  3786. ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)),
  3787. # idempotence, negative
  3788. ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)),
  3789. # reverse, sequential, positive
  3790. ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)),
  3791. # reverse, non-sequential
  3792. ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)),
  3793. # reverse, sequential, negative
  3794. ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)),
  3795. )
  3796. for shape, source, destination in args:
  3797. yield SampleInput(make_arg(shape), args=(source, destination))
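# A minimal sketch of the semantics the reference cases above span: `movedim`
# relocates the listed source dims to the destination positions and keeps the
# remaining dims in their original order, so the "swap bounds" case is just a
# permutation.
def _demo_movedim_matches_permute():
    x = torch.randn(3, 5, 7, 2)
    assert torch.equal(torch.movedim(x, (-1, 0), (0, -1)), x.permute(3, 1, 2, 0))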

def error_movedim_moveaxis(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # source length < destination length
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))),
        error_regex=(r"movedim: Invalid source or destination dims: source "
                     r"\(\[3, -3\] dims\) should contain the same number of "
                     r"dims as destination \(\[1, 0, -1\] dims\)"),
    )

    # source length > destination length
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))),
        error_regex=(r"movedim: Invalid source or destination dims: source "
                     r"\(\[3, -3, 4\] dims\) should contain the same number of "
                     r"dims as destination \(\[1, 0\] dims\)"),
    )

    # repeated source dim, with negative indices
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))),
        error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)",
    )

    # repeated destination dim, with negative indices
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))),
        error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)",
    )

    # repeated dim (both), with negative indices
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))),
        error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)",
    )

    # out of bounds source inputs, with negative indices
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))),
        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
        error_type=IndexError,
    )

    # out of bounds destination inputs, with negative indices
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))),
        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
        error_type=IndexError,
    )

    # out of bounds source input, int
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)),
        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
        error_type=IndexError,
    )

    # out of bounds destination input, int
    yield ErrorInput(
        SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)),
        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
        error_type=IndexError,
    )

def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
    shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))

    if requires_grad:
        # Tests for variant_consistency_jit, grad, gradgrad
        # are slower. Use smaller bags of `rep_dims` and `shapes`
        # in this case.
        rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1))  # type: ignore[assignment]
        shapes = ((), (0,), (2,), (3, 2))  # type: ignore[assignment]

    is_repeat_op = op_info.name in ['repeat', '_refs.repeat']
    for rep_dim, shape in product(rep_dims, shapes):
        # `torch.repeat` errors for `len(rep_dims) < t.dim()`,
        # so we filter such combinations.
        if is_repeat_op and len(rep_dim) < len(shape):
            continue
        yield SampleInput(make_arg(shape), rep_dim)
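
# Illustrative sketch (not executed): the filter above reflects that Tensor.repeat
# needs at least as many repeat dims as tensor dims, while torch.tile left-pads the
# repeat spec with ones, e.g.:
#   >>> torch.zeros(3, 2).repeat(2)          # raises: fewer repeat dims than tensor dims
#   >>> torch.zeros(3, 2).tile((2,)).shape   # spec padded to (1, 2)
#   torch.Size([3, 4])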

def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
    shapes_and_args = (
        ((S, S, S), 1, 2, 2),
        ((S, S, S), -1, 2, 2),
        ((S, S, S), 1, 0, 0),
        ((S, S, S), -1, 0, 0),
        ((S, S, S), 2, 1, 2),
    )

    for shape, dim, start, length in shapes_and_args:
        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
                             requires_grad=requires_grad)
        yield SampleInput(tensor, dim, start, length)
        # narrow also accepts the start argument being a Tensor
        if is_narrow:
            yield SampleInput(tensor, dim, torch.tensor(start), length)


def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
    yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs)

    shapes_and_args = (
        # 1-dim
        ((M,), 0, 0, 0),    # 0 elems from the left
        ((M,), -1, -1, 0),  # 0 elems from the right
        ((M,), 0, 5, 3),    # 3 elems from the left
        ((M,), 0, -5, 2),   # 2 elems from the right
        ((M,), -1, 0, M),   # M elems from the left
        ((M,), 0, -M, M),   # M elems from the right

        # 2-dim
        ((M, S), 1, 0, 0),    # dim 1, 0 elems from the left
        ((S, M), -2, -1, 0),  # dim 0, 0 elems from the right
        ((L, S), 1, 2, 3),    # dim 1, 3 elems from the left
        ((L, S), -1, 3, 2),   # dim 1, 2 elems from the left
        ((M, L), 0, 0, M),    # dim 0, M elems from the left
        ((M, L), -1, -L, L),  # dim 1, L elems from the right

        # 3-dim
        ((L, M, S), 2, 0, 0),    # dim 2, 0 elems from the left
        ((M, S, L), -1, -1, 0),  # dim 2, 0 elems from the right
        ((S, L, M), 2, 0, M),    # dim 2, M elems from the left
        ((L, S, M), -1, -M, M),  # dim 2, M elems from the right
        ((S, L, M), 1, 0, 0),    # dim 1, 0 elems from the left
        ((S, L, M), 0, 2, 1),    # dim 0, 1 elem from the left
        ((M, S, M), -1, -5, 4),  # dim 2, 4 elems from the right
    )

    for shape, dim, start, length in shapes_and_args:
        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
                             requires_grad=requires_grad)
        yield SampleInput(tensor, dim, start, length)
        # narrow also accepts the start argument being a Tensor
        if is_narrow:
            yield SampleInput(tensor, dim, torch.tensor(start), length)

def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # 0-dim
    yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1),
                     error_type=RuntimeError,
                     error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.")

    # out of bounds dim
    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
        # narrow_copy_dense_cpu_out
        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
                         error_type=RuntimeError,
                         error_regex=r"Expected dim < static_cast<int64_t>\(self_sizes.size\(\)\) to be true, but got false\.")
    else:
        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
                         error_type=IndexError,
                         error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)")
    # out of bounds dim (negative)
    yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0),
                     error_type=IndexError,
                     error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")

    # out of bounds start
    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
        # narrow_copy_dense_cpu_out
        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
                         error_type=RuntimeError,
                         error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.")
    else:
        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
                         error_type=IndexError,
                         error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)")
    # out of bounds start (negative)
    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
                     error_type=IndexError,
                     error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)")

    # out of bounds length
    yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
                     error_type=RuntimeError,
                     error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.")
    # out of bounds length (negative)
    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
        # narrow_copy_dense_cpu_out
        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
                         error_type=RuntimeError,
                         error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.")
    else:
        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
                         error_type=RuntimeError,
                         error_regex=r"narrow\(\): length must be non-negative\.")

    # Test Tensor overload that was added for XLA. Start must be an 0-dim
    # integral Tensor. narrow_copy doesn't have this overload.
    # https://github.com/pytorch/pytorch/issues/31558
    if is_narrow:
        # *1-dim* integral Tensor
        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2),
                         error_type=RuntimeError,
                         error_regex=r"start must be an 0-dim integral Tensor\.")

        # 0-dim *bool* Tensor (bools are not allowed)
        yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3),
                         error_type=RuntimeError,
                         error_regex=r"start must be an 0-dim integral Tensor\.")

def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
    y_shape_x_shape_and_kwargs = [
        ((2, 3), (2, 3), {}),
        ((2, 3), (2, 3), {'dim': 1}),
        ((6,), (6,), {}),
        ((6,), None, {}),
        # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad
        # See Issue #61619
        # ((6, 0), (6, 0), {}),
        ((2, 3), (1, 3), {}),
        ((3, 3), (3, 3), {}),
        ((3, 3), (3, 3), {'dim': -2}),
        ((5,), None, {'dx': 2.0}),
        ((2, 2), None, {'dx': 3.0})
    ]
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None,
                       requires_grad=requires_grad)
    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
        y_tensor = make_arg(y_shape)
        if x_shape is not None:
            x_tensor = make_arg(x_shape)
            yield SampleInput(y_tensor, x_tensor, **kwarg)
        else:
            yield SampleInput(y_tensor, **kwarg)


def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
    y_shape_x_shape_and_kwargs = [
        ((2, 3), (2, 3), {}),
        ((2, 3), (2, 3), {'dim': 1}),
        ((6,), (6,), {}),
        ((6,), None, {}),
        # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad
        # See Issue #61619
        # ((6, 0), (6, 0), {}),
        ((2, 3), (1, 3), {}),
        ((3, 3), (3, 3), {}),
        ((3, 3), (3, 3), {'dim': -2}),
        ((5,), None, {'dx': 2.0}),
        ((2, 2), None, {'dx': 3.0})
    ]
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       requires_grad=requires_grad, low=None, high=None)
    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
        y_tensor = make_arg(y_shape)
        if x_shape is not None:
            x_tensor = make_arg(x_shape)
            yield SampleInput(y_tensor, x_tensor, **kwarg)
        else:
            yield SampleInput(y_tensor, **kwarg)

def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
    shapes_and_axes = [
        ((3, 4, 5), 0),
        ((3, 4, 5), 1),
        ((3, 4, 5), 3),
        ((3, 4, 5), -1),
        ((3, 4, 5), -3),
        ((), 0),
        ((), -1),
        ((1,), 0),
        ((1,), -1),
    ]

    for shape, axis in shapes_and_axes:
        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
                             requires_grad=requires_grad)
        yield SampleInput(tensor, axis)


def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs):
    shapes = ((0, 1, 5, 5), (1, 1, 5, 5), (2, 3, 5, 5))
    kernel_sizes = (2, (2, 2), (3, 3), (2, 3))
    dilations = (1, 2, (1, 2))
    paddings = (0, 1, (1, 1), (1, 2))
    strides = (1, 2, (1, 2))

    cases = product(shapes, kernel_sizes, dilations, paddings, strides)
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for shape, kernel_size, dilation, padding, stride in cases:
        tensor = make_arg(shape)
        yield SampleInput(tensor, kernel_size, dilation, padding, stride)

    # With default args
    yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3))


def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
    shapes_and_args = (
        ((S, 1, S, 1), ()),
        ((1, 1, 1, 1), ()),
        ((1, 1, 1, 1), (0,)),
        ((S, 1, S, 1), (1,)),
        ((S, 1, S, 1), (-1,)),
        ((S, 1, S, 1), (2,)),
        ((S, 1, S, 1), (-2,)),
        ((), (0, )),
    )

    for shape, args in shapes_and_args:
        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
                             requires_grad=requires_grad)
        yield SampleInput(tensor, args=args)


def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs):
    shapes_and_args = (
        ((1, 1, 1, 1), ()),
        ((S, 1, S, 1), (1,)),
        ((S, 1, S, 1), (-1,)),
        ((S, 1, S, 1), (1, 3)),
        ((S, 1, S, 1), (1, 2,)),
        ((), (0,)),
    )

    for shape, dims in shapes_and_args:
        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
                             requires_grad=requires_grad)
        yield SampleInput(tensor, dims)

def _squeeze_ref(x, axis=None):
    # NumPy doesn't allow squeezing scalars
    if x.ndim == 0:
        return x

    if isinstance(axis, Sequence):
        # NumPy doesn't allow specifying non-singular dimensions
        axis = tuple(a for a in axis if x.shape[a] == 1)

    if isinstance(axis, int) and x.shape[axis] != 1:
        return x

    return np.squeeze(x, axis)
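
# Illustrative sketch (not executed): torch.squeeze leaves non-singleton dims alone,
# whereas np.squeeze raises for them, which is what the filtering above papers over:
#   >>> torch.zeros(3, 1, 2).squeeze(2).shape
#   torch.Size([3, 1, 2])
#   >>> np.squeeze(np.zeros((3, 1, 2)), 2)   # ValueError in NumPy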

def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
    assert mode in ('constant', 'reflect', 'replicate', 'circular')
    if mode in ['reflect', 'replicate']:
        cases: tuple = (  # ignore
            ((1, 3), (1, 2)),
            ((1, 3), (0, 1)),
            ((0, 3, 3), (1, 2)),
            ((0, 3, 3), (0, 1)),
            ((1, 3, 3), (1, 2)),
            ((1, 3, 3), (0, 1)),
            ((1, 3, 3), (0, 2, 0, 1)),
            ((0, 3, 3, 3), (0, 2, 0, 1)),
            ((3, 3, 5, 5), (0, 2, 0, 1)),
            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 4, 4), (-1, 1, -2, 1)),
        )
    elif mode == 'constant':
        cases = (
            ((1, 3), (1, 2)),
            ((1, 3), (0, 1)),
            ((1, 3), (0, 2, 0, 1)),
            ((0, 3, 3), (1, 2)),
            ((0, 3, 3), (0, 1)),
            ((0, 3, 3), (0, 2, 0, 1)),
            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 3), (1, 2)),
            ((1, 3, 3), (0, 1)),
            ((1, 3, 3), (0, 2, 0, 1)),
            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((0, 3, 3, 3), (1, 2)),
            ((0, 3, 3, 3), (0, 1)),
            ((0, 3, 3, 3), (0, 2, 0, 1)),
            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((3, 3, 5, 5), (1, 2)),
            ((3, 3, 5, 5), (0, 1)),
            ((3, 3, 5, 5), (0, 2, 0, 1)),
            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 3, 3, 3), (1, 2)),
            ((1, 3, 3, 3, 3), (0, 1)),
            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 4, 4), (-1, 1, -2, 1)),
        )
    else:  # mode == 'circular'
        if dtype == torch.bool:
            # test_dtypes fails on ASAN for this case with the
            # runtime error: load of value 190, which is not a valid value for type 'bool'
            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
            cases = (
                ((2, 3, 3), (1, 2)),
                ((1, 3, 3), (1, 2)),
            )
        else:
            cases = (
                ((0, 3, 3), (1, 2)),
                ((0, 3, 3), (0, 1)),
                ((1, 3, 3), (1, 2)),
                ((1, 3, 3), (0, 1)),
                ((0, 3, 3, 3), (0, 2, 0, 1)),
                ((3, 3, 5, 5), (0, 2, 0, 1)),
                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
                ((1, 3, 4, 4), (-1, 1, -2, 1)),
            )

    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if mode == 'constant':
        # Default args
        yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))

    if mode in ['reflect', 'replicate', 'circular']:
        for shape, pad in cases:
            yield SampleInput(make_inp(shape), args=(pad, mode))
    else:  # mode == 'constant'
        for pad_value in (1., 2.):
            for shape, pad in cases:
                yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))

def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
    # Inherit sample inputs from nn.pad, but transform them to fit
    # constant_pad_nd's interface
    nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args,
                                      mode='constant', **kwargs)

    # NOTE: primTorch is more strict about the type of the fill value argument
    # So we must cast it to the correct dtype
    from torch._prims_common import dtype_to_type
    scalar_type = dtype_to_type(dtype)

    def drop_mode_argument(input, pad, mode=None, value=None):
        if value is None:
            return SampleInput(input, args=(pad,))
        else:
            return SampleInput(input, args=(pad, scalar_type(value)))

    for sample in nn_samples:
        yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs)


def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(make_input(()), repeats=2)
    yield SampleInput(make_input((2, 3, 4)), repeats=2)
    yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1)
    yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1)
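
# Note (illustrative): in the last sample above, repeats=torch.arange(3) == [0, 1, 2]
# along dim=1 drops slice 0, keeps slice 1 once, and duplicates slice 2, so the
# output still has shape (2, 3, 4) since 0 + 1 + 2 == 3.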

def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
    def mt(shape, **kwargs):
        return make_tensor(shape, device=device, dtype=dtype,
                           requires_grad=requires_grad, **kwargs)

    yield SampleInput(mt(100), n_fft=10, return_complex=True)
    yield SampleInput(mt(100), n_fft=10, return_complex=False)
    if dtype.is_complex:
        yield SampleInput(mt(100), n_fft=10)

    for center in [False, True]:
        yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True)
        yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4,
                          center=center, return_complex=True)

    window = mt(16, low=.5, high=2.0)
    yield SampleInput(
        mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
    yield SampleInput(
        mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
    if not dtype.is_complex:
        yield SampleInput(
            mt((10, 100)), n_fft=16, window=window, onesided=False,
            return_complex=True)


def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def mt(shape, **kwargs):
        real_shape = shape if dtype.is_complex else shape + (2,)
        return make_arg(real_shape, **kwargs)

    yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10))
    yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False))
    yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True))

    for center in [False, True]:
        yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center))
        yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center))

    window = make_arg(10, low=.5, high=2.0)
    yield SampleInput(mt((10, 10, 6)), kwargs=dict(
        n_fft=10, window=window, center=center, return_complex=dtype.is_complex))
    yield SampleInput(mt((10, 10, 10)), kwargs=dict(
        n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True))

    real_window = window if not dtype.is_complex else window.real
    yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center))

def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs):
    # create a helper function wrapping `make_tensor`
    make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1)

    batches = [(), (0, ), (2, ), (2, 1)]
    ns = [5, 2, 0]
    tf = [True, False]

    for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf):
        input = make_input((*batch, m, n))
        reflectors, tau = torch.geqrf(input)
        reflectors.requires_grad_(requires_grad)
        tau.requires_grad_(requires_grad)
        other_matrix_shape = (m, n) if left else (n, m)
        other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad)
        yield SampleInput(reflectors, tau, other, left=left, transpose=transpose)


def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs):
    cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse(
        op_info, device, dtype, requires_grad=False
    )

    for sample in cholesky_inverse_samples:
        psd_matrix = sample.input
        sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
        sample.args = (psd_matrix.requires_grad_(requires_grad),)
        yield sample


def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_fullrank_matrices_with_distinct_singular_values,
                       dtype=dtype, device=device, requires_grad=requires_grad)

    # not needed once OpInfo tests support Iterables
    batch_shapes = ((), (3,), (3, 3))
    for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
        shape = batch_shape + (S + size_delta, S)
        input = make_arg(*shape)
        yield SampleInput(input, args=(True, get_infos))


def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
    def out_fn(output):
        return output[1], output[2]

    for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs):
        lu_data, pivots = torch.linalg.lu_factor(lu_sample.input)
        lu_data.requires_grad_(requires_grad)
        yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn)


def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2)))

    for arg in args:
        yield SampleInput(make_arg((0, 0, 0)), args=arg)
        yield SampleInput(make_arg((S, S, S)), args=arg)


def error_inputs_roll(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    err_msg1 = "`shifts` required"
    s1 = SampleInput(make_arg((S,)), ())
    yield ErrorInput(s1, error_regex=err_msg1)

    err_msg2 = ("shifts and dimensions must align")
    s2 = SampleInput(make_arg((S, S)), (2, 1), 0)
    yield ErrorInput(s2, error_regex=err_msg2)

    err_msg3 = ("out of range")
    s3 = SampleInput(make_arg((S, )), 0, 2)
    yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError)


def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)])

    yield SampleInput(make_arg((S, S, S)))
    for arg in args:
        yield SampleInput(make_arg((S, S, S)), args=arg)


def error_inputs_rot90(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
    err_msg1 = "expected total rotation dims"
    s1 = SampleInput(make_arg((S, S)), dims=(0,))
    yield ErrorInput(s1, error_regex=err_msg1)

    err_msg2 = "expected total dims >= 2"
    s2 = SampleInput(make_arg((S,)))
    yield ErrorInput(s2, error_regex=err_msg2)

    err_msg3 = "expected rotation dims to be different"
    s3 = SampleInput(make_arg((S, S)), dims=(1, 1))
    yield ErrorInput(s3, error_regex=err_msg3)

def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
    tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype,
                        requires_grad=requires_grad)
    tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype,
                        requires_grad=requires_grad)

    yield SampleInput(tensor_nd())
    yield SampleInput(tensor_nd(), dim=1)
    yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True)
    yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True)
    yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False)

    yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2)
    yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True)
    yield SampleInput(tensor_nd(), dim=None, correction=None)
    yield SampleInput(tensor_nd(), correction=0, keepdim=True)


def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       requires_grad=requires_grad)

    # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
    yield SampleInput(make_arg((S, S)), True)
    yield SampleInput(make_arg((S,)), False)


def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs):
    shapes = [(2,), (1, 2), (3, 2), (2, 3)]
    for shape in shapes:
        yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)


def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs):
    return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad))


def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
    for t in _generate_correlation_inputs(device, dtype, requires_grad):
        yield SampleInput(t)
        num_observations = t.numel() if t.ndimension() < 2 else t.size(1)
        fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10)
        aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad)
        for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]):
            yield SampleInput(t.clone().requires_grad_(requires_grad),
                              correction=correction, fweights=fw, aweights=aw)


def error_inputs_cov(op_info, device, **kwargs):
    a = torch.rand(S, device=device)
    yield ErrorInput(
        SampleInput(torch.rand(S, S, S, device=device)),
        error_regex="expected input to have two or fewer dimensions")
    yield ErrorInput(
        SampleInput(a, fweights=torch.rand(S, S, device=device)),
        error_regex="expected fweights to have one or fewer dimensions")
    yield ErrorInput(
        SampleInput(a, aweights=torch.rand(S, S, device=device)),
        error_regex="expected aweights to have one or fewer dimensions")
    yield ErrorInput(
        SampleInput(a, fweights=torch.rand(S, device=device)),
        error_regex="expected fweights to have integral dtype")
    yield ErrorInput(
        SampleInput(a, aweights=torch.tensor([1, 1], device=device)),
        error_regex="expected aweights to have floating point dtype")
    yield ErrorInput(
        SampleInput(a, fweights=torch.tensor([1], device=device)),
        error_regex="expected fweights to have the same numel")
    yield ErrorInput(
        SampleInput(a, aweights=torch.rand(1, device=device)),
        error_regex="expected aweights to have the same numel")
    yield ErrorInput(
        SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4, -5], device=device)),
        error_regex="fweights cannot be negative")
    yield ErrorInput(
        SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)),
        error_regex="aweights cannot be negative")

def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = [((1, 2, 3, 4), (0, 2, 3, 1)),
             ((1, 2, 3, 4), (0, -2, -1, 1)),
             ((), ()),
             ((1, 2, 3, 4), (2, 1, 3, 0))]

    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=(args,))


def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((), ()),
        ((1,), (0,)),
        ((2, 2), (1, 0)),
        ((2, 2), (0, 1)),
        ((2, 0, 1), (0, 2, 1)),
        ((3, 4, 2), (2, 1, 0)),
        ((3, 4, 2), (1, 0, 2)),
        ((3, 4, 2), (0, 1, 2)),
    )

    # Adds tricky permutations and permutations with noncontiguity
    for shape, permutation in cases:
        for p in itertools.permutations(permutation):
            a = make_arg(shape).permute(p)
            yield SampleInput(a, args=(permutation,))

            a = make_arg(shape, noncontiguous=True).permute(p)
            yield SampleInput(a, args=(permutation,))


def error_inputs_softshrink(op, device, **kwargs):
    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}),
                     error_regex="lambda must be greater or equal to 0, but found to be -0.5")


def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # The additional sample is to check additional values of lambd beyond the default
    # value (what is already checked by sample_inputs_elementwise_unary)
    for lbda in (0., 0.5):
        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})

    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)


def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # The additional sample is to check additional values of lambd beyond the default
    # value (what is already checked by sample_inputs_elementwise_unary)
    # Note that unlike softshrink, lambd is allowed to be negative for hardshrink
    for lbda in (-0.5, 0., 0.5):
        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})

    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)


def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # The additional sample is to check additional values of min_val and max_val beyond the default
    # value (what is already checked by sample_inputs_elementwise_unary)
    for max_val, min_val in ((-0.5, 0.5), (0.5, -0.5), (0., 0.)):
        yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val})

    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)


def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs):
    def c(t):
        return t.clone().requires_grad_(requires_grad)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    x = make_arg((3,))
    y = make_arg((4,))
    A = make_arg((2, 3,))
    B = make_arg((1, 3,))
    C = make_arg((1, 2, 3,))
    D = make_arg((1, 3, 4,))
    E = make_arg((4, 4,))
    H = make_arg((3, 3,))
    I = make_arg((1, 3, 1,))

    # Vector operations
    yield SampleInput([c(x)], 'i->')  # sum
    yield SampleInput([c(x), c(y)], 'i,j->ij')  # outer

    # Matrix operations
    yield SampleInput([c(A)], "ij->i")  # col sum
    yield SampleInput([c(A), c(B)], "ij,kj->ik")  # matmul
    yield SampleInput([c(A), c(E)], "ij,Ab->ijAb")  # matrix outer product

    # Tensor operations
    yield SampleInput([c(C), c(D)], "aij,ajk->aik")  # batch matmul
    yield SampleInput([c(D), c(E)], "aij,jk->aik")  # tensor matrix contraction
    yield SampleInput([c(C), c(B)], "ijk,ik->j")  # non contiguous

    # Test diagonals
    yield SampleInput([c(I)], 'iji->j')  # non-contiguous trace

    # Test ellipsis
    yield SampleInput([c(H)], "i...->...")
    yield SampleInput([c(C), c(x)], '...ik, ...j -> ij')


def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    sizes = ((S, M, S), (S, 0, M))
    all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ())

    for size, dims in product(sizes, all_dims):
        yield SampleInput(make_arg(size), kwargs={"dims": dims})


def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs):
    shapes = [
        (S, M, S),
        (S, 0, M),
    ]
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes)


def error_inputs_fliplr(op, device, **kwargs):
    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)),
                     error_regex="Input must be >= 2-d.")


def error_inputs_flipud(op, device, **kwargs):
    yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)),
                     error_regex="Input must be >= 1-d.")


def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    shape = (S, M, S)

    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape)))
    yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:])))
    yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),))
    yield SampleInput(make_arg(shape), args=(None, make_arg(shape)))
    yield SampleInput(make_arg(shape), args=(make_arg(shape), None))

def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs):
    yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad)
    supported_dtypes = op.supported_dtypes(device)

    # broadcasting and noncontiguous cases
    cases = (
        ((4, 4), (4, 4), (4, 4)),
        ((4, 4), (1, 4, 4), (4, 4)),
        ((4, 4), (1, 4, 4), (4, 1, 4)),
        ((4, 4, 1), (1, 4, 4), (4, 4)),
        ((4, 1), (1, 4, 4), (1, 4)),
        ((4, 4), (), (4, 4)),
        ((4, 4), (), ()),
        ((), (4, 4), (1, 4, 4)),
    )

    for a, b, c in cases:
        yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c)))
        yield SampleInput(make_arg(a, noncontiguous=True),
                          args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1)))

    # scalar cases
    if supports_scalars:
        cases = [
            ((), 1, 2,),
            ((), 1., 2),
            ((4, 4), 1., 2,),
            ((3, 4), make_scalar_tensor(), make_scalar_tensor()),
        ]

        if torch.complex64 in supported_dtypes:
            cases.extend([
                ((3, 1, 4), complex(1, 2), 3.),
            ])

        for a, b, c in cases:
            yield SampleInput(make_arg(a), args=(b, c))

    # type promotion cases
    # int x float
    if torch.float in supported_dtypes and torch.long in supported_dtypes:
        a = make_arg((), dtype=torch.long)
        b = make_arg((1, 4), dtype=torch.float)
        c = make_arg((3, 4))

        cases = (
            (a, b, c),
            (c, a, b),
        )

        for a, b, c in cases:
            yield SampleInput(a, args=(b, c))

    # NaN propagation
    if dtype.is_floating_point or dtype.is_complex:
        nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan'))

        a = make_arg((12,))
        a[4] = nan
        a[7] = nan
        b = make_arg((12,))
        b[1] = nan
        b[7] = nan
        c = make_arg((12,))
        c[9] = nan

        yield SampleInput(a, args=(b, c))

def _clamp_min_numpy(a, min=None):
    return np.maximum(a, min)


def _clamp_max_numpy(a, max=None):
    return np.minimum(a, max)


def _clamp_numpy(a, min=None, max=None):
    if min is None:
        return np.minimum(a, max)
    if max is None:
        return np.maximum(a, min)
    return np.minimum(max, np.maximum(a, min))
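
# Illustrative sketch (not executed): these references mirror torch.clamp's one-sided
# behavior, e.g. _clamp_numpy(np.array([-2., 0., 5.]), min=-1., max=3.) gives
# array([-1., 0., 3.]), and passing only one bound clips only that side.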

def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs):
    def make_arg(shape):
        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)

    def prod_zeros(dim_select):
        assert len(dim_select) == 2
        result = make_arg(3 * (S,))
        result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
        result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
        result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
        return result

    for dim in range(3):
        yield SampleInput(make_arg((S, S, S)), args=(dim,))
    # Scalar tensors and empty tensor
    for size in [(), (1,), (0,)]:
        yield SampleInput(make_arg(size), args=(0,))

    yield SampleInput(prod_zeros([0, 1]), args=(1,))
    yield SampleInput(prod_zeros([0, 2]), args=(1,))
    yield SampleInput(prod_zeros([1, 2]), args=(1,))

    # test dtype kwarg
    yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype})


def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs):
    yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad))


def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    sizes = ((S, S), ())
    return (SampleInput(make_arg(size)) for size in sizes)


def error_inputs_complex(op_info, device, is_ref=False, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)

    if is_ref:
        error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32"
        error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument"
        error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead"
    else:
        error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int"
        error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument"
        error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'"

    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)),
                     error_type=RuntimeError, error_regex=error_float)

    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)),
                     error_type=RuntimeError, error_regex=error_dtype)

    yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64),
                                 out=make_arg(M, S, dtype=torch.complex64)),
                     error_type=RuntimeError, error_regex=error_out)


def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shape = (S, S)
    yield SampleInput(make_arg(shape), make_arg(shape))


def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs):
    def make_arg(shape):
        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)

    def prod_single_zero():
        result = make_arg(2 * (S,))
        result[0, 1] = 0
        return result

    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
        # only Tensor, ignore other inputs
        yield SampleInput(sample.input.clone().requires_grad_(requires_grad))
        yield sample

    # Generates samples with keepdim = True
    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
        sample.kwargs['keepdim'] = True
        yield sample

    yield SampleInput(prod_single_zero())
    yield SampleInput(make_arg((3, 3, 3)), args=(1,))
    yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True})
    yield SampleInput(make_arg((3, 0)), args=(1,))
    yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True})

    # test zero scalar tensor
    zero = make_arg(())
    zero.zero_()
    yield SampleInput(zero.clone().requires_grad_(requires_grad))
    yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,))
    yield SampleInput(zero.clone().requires_grad_(requires_grad),
                      args=(0,),
                      kwargs={'keepdim': True})


def error_inputs_neg(op_info, device, **kwargs):
    si = SampleInput(torch.tensor((False, True), device=device))
    msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
           " If you are trying to invert a mask, use the `\\~` or"
           " `logical_not\\(\\)` operator instead.")
    yield ErrorInput(si, error_regex=msg)


def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    yield SampleInput(make_arg(M))

    tensors = (
        make_arg((M, M)),
        make_arg((3, 5)),
        make_arg((5, 3)),
    )

    args = ((), (2,), (-2,), (1,), (2,))

    for tensor, arg in product(tensors, args):
        yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg)

def sample_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # Shapes for 2D Tensors
    shapes_2d = ((S, S), (3, 5), (5, 3))

    # Shapes for 3D Tensors
    shapes_3d = ((S, S, S),)

    kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1))
    kwargs_3d = (dict(offset=1, dim1=1, dim2=2),
                 dict(offset=2, dim1=0, dim2=1),
                 dict(offset=-2, dim1=0, dim2=1))

    for shape, kwarg in chain(product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)):
        yield SampleInput(make_arg(shape), kwargs=kwarg)


def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_diagonal_diag_embed(
        op_info, device, dtype, requires_grad, **kwargs)

    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes1d = ((0,), (1,))
    shapes2d = ((L, M),)
    shapes3d = ((L, M, S),)

    kwargs1d = {}

    kwargs2d = (
        # dim1 > dim2 is allowed
        dict(dim1=1, dim2=0),
        # negative dims are allowed
        dict(dim1=-2, dim2=-1),
        # out of bounds offset should return an empty tensor in diagonal and
        # offset the diagonal in diag_embed
        dict(offset=100),
    )

    kwargs3d = kwargs2d + (
        # make sure we can use non-sequential dims
        dict(offset=-1, dim1=0, dim2=2),
    )

    samples1d = product(shapes1d, kwargs1d)
    samples2d = product(shapes2d, kwargs2d)
    samples3d = product(shapes3d, kwargs3d)

    for shape, kwargs in chain(samples1d, samples2d, samples3d):
        if 'diagonal' in op_info.name:
            # these are error inputs for diagonal
            if shape in ((0,), (1,)):
                continue
        yield SampleInput(input=make_arg(shape), kwargs=kwargs)


def error_inputs_diagonal_diag_embed(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    shapes1d = (0, 1, (0,), (1,))
    shapes2d = ((M, L),)
    shapes3d = ((M, S, L),)

    kwargs1d = {}

    kwargs2d = (
        # dim1 == dim2 is not allowed
        dict(dim1=1, dim2=1),
        # out of bounds dims are not allowed
        dict(dim1=10000),
        dict(dim2=10000),
    )

    kwargs3d = kwargs2d

    samples1d = product(shapes1d, kwargs1d)
    samples2d = product(shapes2d, kwargs2d)
    samples3d = product(shapes3d, kwargs3d)

    for shape, kwargs in chain(samples1d, samples2d, samples3d):
        arg = make_arg(shape)
        sample = SampleInput(input=arg, kwargs=kwargs)

        dim1 = kwargs.get('dim1')
        dim2 = kwargs.get('dim2')

        if 'diagonal' in op_info.name:
            num_dim = arg.dim()
        elif op_info.name in ('diag_embed', '_refs.diag_embed'):
            # these are valid inputs for diag_embed
            if shape in ((0,), (1,)):
                continue
            num_dim = arg.dim() + 1
        else:
            raise RuntimeError("should be unreachable")

        bound1 = -num_dim
        bound2 = num_dim - 1
        dim_range = range(bound1, bound2 + 1)
        dim1_cond = dim1 and dim1 not in dim_range
        dim2_cond = dim2 and dim2 not in dim_range

        if dim1 == dim2:
            err = f"diagonal dimensions cannot be identical {dim1}, {dim2}"
            yield ErrorInput(sample, error_regex=err, error_type=RuntimeError)
        elif dim1_cond or dim2_cond:
            err_dim = dim1 if dim1_cond else dim2
            err = (r"Dimension out of range \(expected to be in range of "
                   rf"\[{bound1}, {bound2}\], but got {err_dim}\)")
            yield ErrorInput(sample, error_regex=err, error_type=IndexError)
        else:
            raise RuntimeError("should be unreachable")

def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # Shapes for 2D Tensors
    shapes_2d = ((M, M), (3, 5), (5, 3))

    # Shapes for 3D Tensors
    shapes_3d = ((M, M, M),)

    args_2d = ((), (2,), (-2,), (1,))
    args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))

    for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
        input_ = make_arg(input_shape)
        # We can programmatically figure out the right shape for src:
        # It should be the same size as input.diagonal(other_args...)
        if not isinstance(arg, tuple):
            arg_tuple = (arg,)
        else:
            arg_tuple = arg
        src_shape = input_.diagonal(*arg_tuple).size()
        src = make_arg(src_shape)
        yield SampleInput(input_, args=(src, *arg_tuple))

def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense())
    yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense())


def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs):
    batch_size, num_classes = shape = (2, 3)
    reductions = ("mean", "sum", "none")

    input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [
        (shape, {}),
        ((*shape, 1), {}),
        ((*shape, 1, 2), {}),
        ((*shape, 1, 2, 3), {}),
        *[(shape, dict(reduction=reduction)) for reduction in reductions],
        *[
            (
                shape,
                dict(
                    weight=make_tensor((num_classes,), device=device, dtype=dtype),
                    reduction=reduction,
                ),
            )
            for reduction in reductions
        ],
        (shape, dict(ignore_index=1)),
    ]

    for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)):
        input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)

        if probabilities_target:
            # ignore_index is not supported for probabilities target
            if "ignore_index" in kwargs:
                continue

            target = make_tensor(
                input_shape,
                low=0,
                high=1,
                device=device,
                dtype=dtype,
                requires_grad=requires_grad,
            )
        else:
            target = make_tensor(
                (batch_size, *input_shape[2:]),
                low=0,
                high=num_classes,
                device=device,
                dtype=torch.long,
            )

            if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]):
                # make sure at least one item in target is not ignored
                target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0]

        yield SampleInput(input, target, **kwargs)

def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
    low, high = op_info.domain

    # Note: Operator is very sensitive at points near the
    # start and end of domain and leads to NaN for float16
    # if domain_eps is 1e-5.
    domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2

    low = low + domain_eps
    high = high - domain_eps

    make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    yield SampleInput(make_arg((S, S, S)))
    yield SampleInput(make_arg((S, S, S)), 0.2)
    yield SampleInput(make_arg(()))
    yield SampleInput(make_arg(()), 0.2)

def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # isin has two code paths, selected by comparing
    # elements.numel() against 10 * pow(test_elements.numel(), 0.145),
    # so generate one sample on each side of that threshold.
    yield SampleInput(make_arg((L,)), args=(make_arg((S,)),))
    yield SampleInput(make_arg((S,)), args=(make_arg((L,)),))
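
# Rough arithmetic behind the two samples above (assuming S == 5 and L == 20 as
# defined earlier in this file): 10 * 5 ** 0.145 is about 12.6 and 10 * 20 ** 0.145
# is about 15.4, so the 20-element `elements` with a 5-element `test_elements`
# lands on one side of the threshold and the 5-vs-20 sample on the other,
# covering both dispatch paths.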
  4819. def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs):
  4820. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  4821. yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))))
  4822. yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S))))
  4823. yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S))))
  4824. yield SampleInput(make_arg((S,)),
  4825. args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))),
  4826. broadcasts_input=True)
  4827. def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
  4828. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  4829. yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10))
  4830. yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(())))
  4831. yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
  4832. yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
  4833. yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(())))
  4834. yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
  4835. yield SampleInput(make_arg((S,)),
  4836. args=(torch.randn(S, S, device=device) > 0, make_arg(())),
  4837. broadcasts_input=True)
  4838. yield SampleInput(make_arg((S,)),
  4839. args=(torch.randn(S, S, device=device) > 0, 10),
  4840. broadcasts_input=True)
  4841. if torch.device(device).type == 'cuda':
  4842. # `self` and `mask` on CUDA but `value` is a CPU scalar tensor.
  4843. yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, torch.randn(())))
  4844. def error_inputs_masked_fill(op_info, device, **kwargs):
  4845. make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
  4846. # `value` is not a 0-D tensor.
  4847. yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))),
  4848. error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension")
  4849. # downcasting complex value (scalar overload)
  4850. yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)),
  4851. error_regex=r"value cannot be converted to type .* without overflow")
  4852. # downcasting complex value (tensor overload)
  4853. yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device),
  4854. args=(make_arg(()) > 0, torch.tensor(1j, device=device))),
  4855. error_regex=r"value cannot be converted to type .* without overflow")
  4856. if torch.device(device).type == 'cuda':
  4857. # `self` and `mask` on CPU but `value` is a CUDA scalar tensor.
  4858. yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'),
  4859. args=(torch.randn(S, S, device='cpu') > 0,
  4860. torch.randn((), device='cuda'))),
  4861. error_regex=r"to be on same device")
  4862. def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs):
  4863. make_arg = partial(
  4864. make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
  4865. yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0)
  4866. yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0)
  4867. yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0)
  4868. yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0)
  4869. yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool))
  4870. yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool))
  4871. yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0)
  4872. def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs):
  4873. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  4874. yield SampleInput(make_arg((S, S)))
  4875. yield SampleInput(make_arg((S, S, S)))
  4876. def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
  4877. make_arg = partial(make_tensor, dtype=dtype, device=device, low=None,
  4878. high=None, requires_grad=requires_grad)
  4879. test_cases = (((L,), (L,)),
  4880. ((S, M), (M,)),
  4881. ((M,), (M, S)),
  4882. ((S, M), (M, S)),
  4883. ((S, 0), (0, M)),
  4884. ((S, S, M), (M,)),
  4885. ((S, S, M), (M, S)),
  4886. ((S, S, 0), (0, S)),
  4887. ((M,), (S, M, S)),
  4888. ((S, M), (S, M, S)),
  4889. ((0, 0), (S, 0, 0)),
  4890. ((S, S, M, M), (S, S, M, S)),
  4891. ((S, S, M, M), (M,)),
  4892. ((M,), (S, S, M, S)))
  4893. for lhs_shape, rhs_shape in test_cases:
  4894. lhs = make_arg(lhs_shape)
  4895. rhs = make_arg(rhs_shape)
  4896. if not is_rmatmul:
  4897. yield SampleInput(lhs, rhs)
  4898. else:
  4899. yield SampleInput(rhs, lhs)

def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
                           requires_grad: bool,
                           *, variant: str, **kwargs) -> List[SampleInput]:
    if variant == 'variadic':
        def make_inputs(
                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
                                                            List[torch.Tensor]],
                                                      Tuple[torch.Tensor, ...]]:
            return tensors
    elif variant == 'list':
        def make_inputs(
                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
                                                            List[torch.Tensor]],
                                                      Tuple[torch.Tensor, ...]]:
            return [tensors]
    else:
        raise ValueError(
            'Unsupported variant, must be one of {"variadic", "list"}. '
            f'Got "{variant}".')

    SCALAR = torch.Size([])
    VECTOR = torch.Size([3])
    test_cases: List[List[torch.Size]] = [
        [SCALAR],
        [VECTOR],
        [VECTOR, SCALAR],
        [VECTOR, SCALAR, VECTOR],
        [VECTOR, SCALAR, VECTOR, SCALAR],
    ]

    for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}):
        args = make_inputs(
            [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
             for shape in shapes])
        yield SampleInput(*args, indexing=indexing)
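
# For reference, the 'variadic' and 'list' variants above exercise the two call forms
# that torch.meshgrid accepts, i.e. torch.meshgrid(x, y, indexing='ij') and
# torch.meshgrid([x, y], indexing='ij').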

def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    tensor_shapes = ((S, S), ())
    ns = (1, 2, 3, 4, 5)

    # Since the accepted lower bound for input
    # to mvlgamma depends on `p` argument,
    # the following function computes the lower bound
    # which we pass to `make_tensor`.
    def compute_min_val(p):
        return (p - 1.) / 2

    for shape, n in product(tensor_shapes, ns):
        min_val = compute_min_val(n)
        if not dtype.is_floating_point:
            # Round-up minimum value for integral dtypes
            min_val += 1
        else:
            min_val += 2 * torch.finfo(dtype).eps
        yield SampleInput(make_arg(shape, low=min_val), args=(n,))
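
# For reference: mvlgamma(x, p) requires x > (p - 1) / 2, which is what compute_min_val
# above encodes; e.g. p=1 needs inputs above 0 and p=3 needs inputs above 1. The +1 and
# epsilon adjustments keep the generated samples strictly inside that domain.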


# Since `mvlgamma` has multiple entries,
# there are multiple common skips for the additional
# entries. The following function is a helper to that end.
def skips_mvlgamma(skip_redundant=False):
    skips = (
        # outside domain values are hard error for mvlgamma op.
        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'),
    )
    if skip_redundant:
        # Redundant tests
        skips = skips + (  # type: ignore[assignment]
            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
        )
    return skips


# To test reference numerics against multiple values of argument `p`,
# we make multiple OpInfo entries with each entry corresponding to a different value of p.
# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing.
def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs):
    return UnaryUfuncInfo('mvlgamma',
                          ref=reference_mvlgamma if TEST_SCIPY else None,
                          aliases=('special.multigammaln',),
                          variant_test_name=variant_test_name,
                          domain=domain,
                          decorators=(precisionOverride({torch.float16: 5e-2}),),
                          dtypes=all_types_and(torch.bfloat16),
                          dtypesIfCUDA=all_types_and(torch.float16),
                          sample_inputs_func=sample_inputs_mvlgamma,
                          supports_forward_ad=True,
                          supports_fwgrad_bwgrad=True,
                          skips=skips,
                          sample_kwargs=sample_kwargs)

def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
    def _make_tensor_helper(shape, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    yield SampleInput(_make_tensor_helper((S, S, S)), 0)
    yield SampleInput(_make_tensor_helper((S, S, S)), 1)
    yield SampleInput(_make_tensor_helper(()), 0)

    if supports_dtype_kwargs:
        # NOTE: if `dtype` is not same as input, then inplace variants fail with
        # `provided dtype must match the dtype of self tensor in cumsum`
        yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype)

def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = (
        ((), (0, 1, 1)),
        ((S, S, S, S), (0, 3, 1)),
        ((S, S, S, S), (1, 3, 1)),
        ((S, S, S, S), (2, 3, 1)),
        ((S, S, S, S), (3, 3, 1)),
        ((S, S, S, S), (0, 3, 2)),
        ((S, S, S, S), (1, 3, 2)),
        ((S, S, S, S), (2, 3, 2)),
        ((S, S, S, S), (3, 3, 2)),
        ((S, S, S, S), (0, 4, 1)),
        ((S, S, S, S), (1, 4, 1)),
        ((S, S, S, S), (2, 4, 1)),
        ((S, S, S, S), (3, 4, 1)),
        ((M,), (0, 3, 1)),
        ((M,), (0, 3, 2)),
        ((M,), (0, 3, 3)),
        ((1000,), (0, 3, 11)),
        ((1000,), (0, 2, 27)),
        ((10, 10), (0, 1, 2)),
        ((10, 10), (1, 2, 3)),
        ((10, 10), (1, 2, 2)),
        ((S, S, S), (2, 3, 2)),
    )
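    # Each case is (shape, (dimension, size, step)); Tensor.unfold then produces
    # floor((shape[dimension] - size) / step) + 1 windows of length `size` along that
    # dimension, e.g. (0, 3, 11) applied to the 1000-element case yields 91 windows.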
    for shape, arguments in test_cases:
        yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
                                      low=None, high=None,
                                      requires_grad=requires_grad),
                          *arguments)

def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if list_args:
        cases = (
            ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),),
            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),)
        )
    else:
        cases = (  # type: ignore[assignment]
            ((S, S, S), (2,)),
            ((S, S, S), (S, 1)),
        )

    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)


def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)),
             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)),
             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)),
             )

    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)


def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs):
    def apply_grad(t):
        if dtype in floating_types_and(torch.float16, torch.bfloat16):
            t.requires_grad_(requires_grad)

    def large_1d_unique(dtype, device):
        res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
        res = res.to(dtype)
        apply_grad(res)
        return res

    # Test case for large tensor.
    yield SampleInput(large_1d_unique(dtype, device))

    yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device,
                                  low=None, high=None,
                                  requires_grad=requires_grad))

def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # no broadcast
    yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4)
    # broadcast rhs
    yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4)
    # scalar tensor
    yield SampleInput(make_arg(()), make_arg(()), 0.4)
    # broadcast rhs scalar-tensor
    yield SampleInput(make_arg((S, S)), make_arg(()), 0.4)
    # broadcast rhs with weight tensor
    yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S)))
    # broadcast rhs and weight tensor
    yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,)))
    # broadcast lhs
    yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
    # scalar broadcast_lhs
    yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
    # broadcast all
    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
    # tensor broadcast all
    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata(
        broadcasts_input=True)
    # no broadcast with weight tensor
    yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S)))
    # broadcast lhs with weight tensor
    yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata(
        broadcasts_input=True)
    # broadcast lhs and weight tensor
    yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata(
        broadcasts_input=True)
    # broadcast lhs and weight tensor variant
    yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata(
        broadcasts_input=True)

    if dtype.is_complex:
        # no broadcast
        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j)
        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j)
        # broadcast rhs
        yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j)
        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j)
        # scalar tensor
        yield SampleInput(make_arg(()), make_arg(()), 0.4j)
        yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j)
        # broadcast rhs scalar-tensor
        yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j)
        yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j)

def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
    cases = (
        ((2, 2, 2), (2, 2, 2), (2)),
        ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
    )
    for first_shape, second_shape, dims in cases:
        yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
                                      requires_grad=requires_grad),
                          make_tensor(second_shape, dtype=dtype, device=device,
                                      requires_grad=requires_grad),
                          dims=dims)


def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
    test_cases = (
        ((S, S), (M, L)),
    )

    for input_shape, other_shape in test_cases:
        input = make_arg(input_shape)
        other = make_arg(other_shape)
        yield SampleInput(input, other)


def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(make_arg(S), make_arg(S))
    yield SampleInput(make_arg(), make_arg(S, S))

def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
    def _tensor(shape, dtype=dtype, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    def _gather(shape, index_dim, max_indices):
        return gather_variable(shape, index_dim, max_indices, device=device)

    zero = torch.tensor(0, dtype=torch.long, device=device)
    test_cases = (
        (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
        (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
        (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
        (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),
        (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
        (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
        (_tensor(()), (0, zero.clone().detach(), _tensor(()))),
        (_tensor(()), (0, zero.clone().detach(), 2.5)),
    )

    for tensor, args in test_cases:
        yield SampleInput(tensor, *args)

        if not requires_grad:
            yield SampleInput(tensor.clone().detach(), *args, reduce='add')

            if dtype.is_floating_point:
                yield SampleInput(tensor.clone().detach(), *args, reduce='multiply')


def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs):
    def _tensor(shape, dtype=dtype, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    def _gather(shape, index_dim, max_indices):
        return gather_variable(shape, index_dim, max_indices, device=device)

    zero = torch.tensor(0, dtype=torch.long, device=device)

    yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S)))
    yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S)))
    yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S)))
    yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))
    yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
    yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
    yield SampleInput(_tensor(()), 0, zero.clone().detach(), _tensor(()))

def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    gather = partial(gather_variable, device=device)

    zero = torch.tensor(0, dtype=torch.long, device=device)
    test_cases = (
        ((M, S), 0, gather((S, S), 1, M), (S, S)),
        ((M, S), 1, gather((S, S), 0, S), (S, S)),
        ((M, S), -1, gather((S, S), 0, S), (S, S)),
        ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)),
        ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)),
        ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)),
        ((), 0, zero.clone().detach(), ()),
    )

    reduce = op_info.variant_test_name
    for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]):
        yield SampleInput(make_arg(inp_shape),
                          args=(dim, index, make_arg(src_shape), reduce),
                          kwargs={'include_self': include_self})

    # Sample inputs to test edge cases for backward
    # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
    if requires_grad and reduce == 'prod':
        # This sample tests gradients for the following cases
        # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0]))
        # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]))
        # (c) no zeros reduced (self[2, 1])
        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
        input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
        src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad)
        idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device)

        yield SampleInput(input,
                          args=(1, idx, src, reduce),
                          kwargs={'include_self': True})

def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
    def _tensor(shape, dtype=dtype, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    zero = torch.tensor(0, dtype=torch.long, device=device)
    test_cases = (
        # inp_shape, dim, lengths, unsafe
        ((S,), 0, [0, 1, 2, 2], False),
        ((S,), 0, [0, 1, 2, 2], True),
        ((S,), 0, [2, 0, 3, 0], False),
        ((S, S), 0, [0, 1, 2, 2], False),
        # test when lengths do not sum to dim size
        ((M, S, S), 0, [1, 2, 0, 6, 0], True),
        # test for higher dimensions
        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
    )

    reductions = ["max", "mean", "min", "sum", "prod"]
    for args, reduce, initial in product(test_cases, reductions, [1, 2]):
        inp_shape, dim, lengths, unsafe = args
        lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
        sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
        if mode == 'lengths':
            sample_input_kwargs['lengths'] = lengths_t
        elif mode == 'offsets':
            zeros_shape = list(lengths_t.shape)
            zeros_shape[dim] = 1
            offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
            sample_input_kwargs['offsets'] = offsets_t
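            # For illustration: lengths [0, 1, 2, 2] turn into offsets [0, 0, 1, 3, 5],
            # i.e. a leading zero followed by the running sum of the segment lengths.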
        else:
            raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.")

        yield SampleInput(_tensor(inp_shape),
                          args=(reduce,),
                          kwargs=sample_input_kwargs)

def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg((S, S, S)))
    yield SampleInput(make_arg(()))
    yield SampleInput(make_arg((S, S, S), noncontiguous=True))


def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    cases = (((M, M), ()),
             ((M, M), (2,),),
             ((M, S), ()),
             ((M, S), (-1,)),
             ((M, M), (2,),),
             ((S, M, S), ()),
             ((S, M, S), (2,)),
             ((3, 3, S, S), ()),)

    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)


def error_inputs_tril_triu(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for input.ndim <= 2
    yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions")


def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs):
    # (row, col, offset)
    args_list = ((0, 0),
                 (20, 0),
                 (0, 20),
                 (20, 21, 0),
                 (20, 21, 7),
                 (20, 21, -7),
                 # Large test cases below are deliberately commented out to speed up CI
                 # tests and to avoid OOM error. When modifying implementations of
                 # tril_indices and triu_indices, please enable these tests and make sure
                 # they pass.
                 # (2, 68435455, 3),
                 # (5000, 5000),
                 # (5000, 5000, 1234),
                 # (5000, 5000, -1233),
                 )
    for args in args_list:
        yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device})

def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    yield SampleInput(make_arg((S, M, S)))
    yield SampleInput(make_arg(()))


def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs):
    # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format
    # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous
    yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs)
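
    # Put differently (illustrative): cloning a channels_last tensor keeps its channels_last
    # strides (preserve_format), while calling .contiguous() on it returns a row-major
    # contiguous_format tensor unless a different memory_format is requested explicitly.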

    shapes = (
        (3, 5, 6),
        (1, 1, 3, 5, 6),
        (1, 1, 3, 5, 6, 1, 1),
        (1, 0, 3, 5, 0, 2),
        (1, 0, 3, 5, 0, 0, 1, 1, 2),
        (),
    )

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape in shapes:
        yield SampleInput(make_arg(shape))
        yield SampleInput(make_arg(shape).transpose(0, -1))
        yield SampleInput(make_arg(shape, noncontiguous=True))
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))

        yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format})
        yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})
        yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format})
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})

    # shape, strides, offset
    strided_cases = (
        ((5, 6, 2), (1, 1, 7), 2),
        ((5, 5, 4), (1, 1, 7), 2),
        ((5, 5, 2), (4, 5, 7), 3),
        ((5, 5, 2), (5, 5, 7), 3),
        ((5, 5, 2), (5, 5, 5), 3),
        ((9, 5, 2), (0, 1, 7), 3),
    )

    for shape, strides, offset in strided_cases:
        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset))
        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format})

    # channels last 2D
    yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
    a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2)
    yield SampleInput(a, kwargs={'memory_format': torch.channels_last})

    # channels last 3D
    yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d})
    a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3)
    yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d})

def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # list of tuples (shape, shape) defining the shapes of the input and output tensors
    sample_shapes = [
        ((), ()),
        ((S,), (1,)),
        ((S, S), (1, 1)),
        ((S, S), (1, S)),
        ((S, S), (S, S)),
        ((S, S, S), (S, 1, S)),
    ]

    for input_shape, output_shape in sample_shapes:
        yield SampleInput(make_arg(input_shape), args=(output_shape,))
        if output_shape == ():
            continue
        yield SampleInput(make_arg(input_shape), args=(list(output_shape),))
        yield SampleInput(make_arg(input_shape), args=(*output_shape,))


def error_inputs_sum_to_size(op_info, device, **kwargs):
    shape = (M, S, M)
    err_msg = "is not expandable to size"
    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M))
    yield ErrorInput(si, error_regex=err_msg)

    shape = (M + 1, S, S, M)
    err_msg = "is not expandable to size"
    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1))
    yield ErrorInput(si, error_regex=err_msg)


def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device)
    cases = (((S, S, S), (S * S, S)),
             ((), ()),
             ((), (1, 1, 1)),
             )

    for shape, args_or_shape in cases:
        # Update `args` based on operator
        if op_info.name == 'resize_':
            # resize_ takes shape/tuple of ints,
            args = (args_or_shape, )
        elif op_info.name == 'resize_as_':
            # resize_as_ takes another tensor
            args = (make_arg(shape, requires_grad=False), )  # type:ignore[assignment]
        else:
            raise ValueError("sample_inputs_resize_ops is being used with incorrect operator")

        yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)

def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (
        # a, b, is_tensor_supported
        ((S, S, S), (S * S, S), True),
        ((S * S, S), (S, S, S), True),
        ((S * S, S), (S, -1, S), False),  # neg index
        ((S * S * 2, S), (S, -1), False),  # neg index
        ((S,), (S,), True),
        ((), (), False),  # empty
        ((), (1,), True),
    )

    for a, b, is_tensor_supported in cases:
        # skip unsupported cases
        if kwargs.get("tensor_arg") and not is_tensor_supported:
            continue

        # convert to tensor
        if kwargs.get("tensor_arg"):
            b = make_arg(b, requires_grad=False)

        yield SampleInput(make_arg(a), args=(b,))


def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs)

    cases = (
        # a, b, is_tensor_supported
        ((125,), (25, 5), True),
        ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True),
        ((16, 32), (2, 4, 1, 4, 4, 1, 4), True),
        ((16, 12), (12, 16), True),
        ((1, 16, 12), (12, 16), True),
        ((1, 5, 1, 5), (25, 1), True),
        ((2, 4, 2), (4, 4), True),
        ((1, 4), (1, 1, 2, 1, 2), True),
        ((3, 5, 7), (7, 5, 3), True),
        ((1,), (), False),  # empty
        ((5, 0, 2, 3), (5, 0, 2, 3), True),
        ((2, 1, 0, 3, 1), (5, 0), True),
        ((1,), (), False),  # empty
        ((4, 5, 6), (4, 5, 6, 1, 1, 1), True),
        ((), (1, 1, 1, 1), False),  # empty
    )

    irreversible_cases = (
        ((), (-1,), False),  # neg index, empty
        ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False),  # neg index
    )

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for a, b, is_tensor_supported in cases:
        # skip unsupported cases
        if kwargs.get("tensor_arg") and not is_tensor_supported:
            continue

        if kwargs.get("tensor_arg"):
            # convert to tensor
            yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),))
            yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),))
        else:
            yield SampleInput(make_arg(a), args=(b,))
            yield SampleInput(make_arg(b), args=(a,))

    for a, b, is_tensor_supported in irreversible_cases:
        # skip unsupported cases
        if kwargs.get("tensor_arg") and not is_tensor_supported:
            continue

        # convert to tensor
        if kwargs.get("tensor_arg"):
            b = make_arg(b, requires_grad=False)

        yield SampleInput(make_arg(a), args=(b,))

def error_inputs_view_reshape(op, device, **kwargs):
    cases = (
        # a, b, is_tensor_supported
        # Reshape to different numel
        ((2,), (), False),  # empty
        ((1, 3, 0), (), False),  # empty
        ((4, 3), (4, 2), True),
        ((1, 3, 5), (5, 2, 2), True),
        # No valid inference
        ((1, 3, 5), (5, -1, 2), False),  # neg index
        # Two inferred shapes
        ((1, 3, 5), (5, -1, -1), False),  # neg index
        ((1), (0, -1), False),  # neg index
        ((0, 5), (0, -1), False),  # neg index
    )

    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
    for a, b, is_tensor_supported in cases:
        # skip unsupported cases
        if kwargs.get("tensor_arg") and not is_tensor_supported:
            continue

        if b == (5, -1, -1):
            error_regex = "only one dimension can be inferred"
        elif a == (0, 5):
            error_regex = (r"cannot reshape tensor of 0 elements into shape "
                           r"\[0, -1\] because the unspecified dimension size "
                           r"-1 can be any value and is ambiguous")
        else:
            # to avoid having issues with a regex
            shape = ', '.join(map(str, b))
            size = a if type(a) is int else functools.reduce(operator.mul, a, 1)
            error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}"
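            # e.g. for a == (4, 3) and b == (4, 2) above this regex matches messages like:
            #   shape '[4, 2]' is invalid for input of size 12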

        # convert to tensor
        if kwargs.get("tensor_arg"):
            b = make_arg(b, requires_grad=False)

        yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception,
                         error_regex=error_regex)

def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
    input_list = []
    shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape in shapes:
        yield SampleInput(make_tensor_partial(shape))
    yield SampleInput([make_tensor_partial(shape) for shape in shapes])


def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs):
    cases: Tuple[tuple, tuple] = (  # type: ignore[assignment]
        ((S, 2, 1), (S, 3, 1)),
        ((S), (S, 5)), ((), (1, S))
    )
    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape1, shape2 in cases:
        yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)])


def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs):
    shapes = ((S, S, S), (S, S), (S, ), (),)
    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape in shapes:
        yield SampleInput(make_tensor_partial(shape))
        if len(shape) > 1:
            yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1)


def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs)

    # shape x start_dim x end_dim
    cases = (
        ((5, 4, 0, 1, 3, 7), 1, 3),
        ((5, 4, 0, 1, 3, 7), 4, 5),
        ((5, 4, 1, 1, 3, 7), 2, 3),
        ((), 0, -1),
        ((1,), 0, -1),
        ((3, 7, 5), 1, 2),
        ((4, 5), 1, 1),
        ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2),
        ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1),
        ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1),
        ((2, 4, 2), 0, 1),
        ((4, 2, 2), 1, 2),
        ((0, 3, 4, 5), 1, 3),
    )

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape, start, end in cases:
        yield SampleInput(make_arg(shape), args=(start, end,))
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,))
        yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,))


def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
    # in_shape, dim, sizes
    args = (((8,), 0, (8,)),
            ((8,), 0, (4, 2)),
            ((8,), -1, (2, 2, 2)),
            ((8,), -1, (-1, 2)),
            ((3, 6, 2), 1, (2, 3)),
            ((3, 6, 2), -2, (2, 3)),
            ((3, 6, 2), -2, (-1, 3)),
            ((3, 2, 12), 2, (3, 2, 2)),
            ((4, 0), 0, (2, 2)),
            ((4, 0), 1, (2, 0, 0, 0)),
            )
    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for in_shape, dim, sizes in args:
        yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes))

def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((S, S, S), (1, 2)),
             ((S, S, S), (-1, 2)),
             ((S, S, S), (-1, -1)),
             ((S, S, S), (1, -1)),
             ((S,), (0, 2))
             )

    for shape, args in cases:
        yield SampleInput(make_arg(shape), args=args)


def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((S, S, S), (S, S), (1, 2)),
             ((S, S, S), (S, S), (-1, 2)),
             ((S, S, S), (S, S), (-1, -1)),
             ((S, S, S), (S, S), (1, -1)),
             ((S,), (), (0, 2))
             )

    for input_shape, src_shape, args in cases:
        input_ = make_arg(input_shape)
        src = make_arg(src_shape)
        yield SampleInput(input_, args=(src, *args))


def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)),
             ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)),
             ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)),
             ((L, L, L), (L, L, L,), (1, 0, L, 1)),
             ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)),
             ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)),
             ((L, L, L), (L, L, L,), (2, 0, L, 1)),
             ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)),
             ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)),
             )

    for input_shape, src_shape, args in cases:
        input_ = make_arg(input_shape)
        src = make_arg(src_shape)
        yield SampleInput(input_, args=(src, *args))


def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((S, 1, 1), (S, S, S)),
             ((S, 1, S), (S, S, S)),
             ((S, 1, S), (-1, S, -1)),
             ((S, 1, S), (-1, S, S)),
             ((S, 1), (S, S, S)),
             ((1,), (S, S, S)),
             ((1, S), (1, 1, S)),
             ((), ()),
             ((), (1, 3, 2)),
             )
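    # Note: a -1 in the target size (e.g. (-1, S, -1) above) keeps the corresponding
    # dimension of the input unchanged, per torch.Tensor.expand semantics.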

    for case in cases:
        shape, args = case
        yield SampleInput(make_arg(shape), args=(args,))


def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    shapes = ((),
              (2, 3))
    memory_format_options = [None, torch.contiguous_format]

    for shape, memory_format in itertools.product(shapes, memory_format_options):
        yield SampleInput(make_arg(shape),
                          kwargs={'memory_format': memory_format} if memory_format else {})
    yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last})


def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device)

    cases = (((S, 1, 1), (S, S, S)),
             ((), ()),
             ((), (1, 1)),
             )

    for shape, shape_other in cases:
        yield SampleInput(make_arg(shape, requires_grad=requires_grad),
                          args=(make_arg(shape_other, requires_grad=False),))

def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    def make_bool_mask(shape):
        # Make sure at least one element is nonzero,
        # except for empty tensor
        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)

        if mask_t.numel() == 0:
            return mask_t
        elif mask_t.numel() == 1:
            mask_t.fill_(True)
            return mask_t

        if mask_t.sum() == 0:
            def random_index(shape):
                return tuple(map(lambda max_idx: random.randrange(0, max_idx), shape))

            mask_t[random_index(mask_t.shape)] = True
            return mask_t

        return mask_t

    cases = (((M, M), (M, M), (M, M), False),
             ((M, 1, M), (M, M), (M, M, 1), True),
             ((), (), (), False),
             ((M, 1, M), (), (M, M, 1), True),
             ((), (M, M), (), True),
             ((), (2), (1, 1), True),
             )

    for shape, mask_shape, other_shape, broadcasts_input in cases:
        yield SampleInput(make_arg(shape),
                          args=(make_bool_mask(mask_shape), make_arg(other_shape)),
                          broadcasts_input=broadcasts_input)

# TODO: add reference inputs for where(condition) signature
def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs)

    make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad)
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # noncontiguous
    c = make_cond((10, 3), noncontiguous=True)
    a = make_arg((10, 1), noncontiguous=True)
    b = make_arg((3, 10, 3)).transpose(0, -1)

    # NOTE that the OpInfo for where takes samples of the form a, cond, b
    yield SampleInput(a, args=(c, b))

    # type promoting
    other_dtype = torch.double if dtype is not torch.double else torch.long
    c = make_cond((10, 3), noncontiguous=True)
    a = make_arg((10, 1), dtype=torch.long)
    b = make_arg((10, 1))
    yield SampleInput(a, args=(c, b))

    # two python scalars
    c = make_cond((10, 3), noncontiguous=True)
    a = make_arg((1,)).item()
    b = make_arg((1,)).item()
    yield SampleInput(a, args=(c, b))

    # NaN propagation
    if dtype.is_floating_point or dtype.is_complex:
        if dtype.is_floating_point:
            nan = float('nan')
        else:
            # dtype.is_complex
            nan = complex(float('nan'), float('nan'))
        c = make_cond((1, 10, 3))
        a = make_arg((10, 3), noncontiguous=True)
        a[2, 1] = nan
        b = make_arg((1, 3))
        b[0, 2] = nan
        yield SampleInput(a, args=(c, b))

    # Python scalars type promotion
    for scalar in (0, 0.0, 2j, False):
        yield SampleInput(scalar, args=(c, b))
        yield SampleInput(a, args=(c, scalar))


def error_inputs_where(op_info, device, **kwargs):
    shape = (S,)
    err_msg = "Expected all tensors to be on the same device"
    for devices in product(('cpu', device), repeat=3):
        if len(set(devices)) == 2:
            si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
                             args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
                                   make_tensor(shape, device=devices[2], dtype=torch.float32)))
            yield ErrorInput(si, error_regex=err_msg)

def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))

    inputs = []
    for shape in sizes:
        # construct input without any non-zero elements
        zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
        inputs.append(zeros)

        # construct input with mixed zero and non-zero elements
        mixed = make_arg(shape).requires_grad_(False)
        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
        mixed[mask_t] = 0
        inputs.append(mixed)

    for input_t, as_tuple in product(inputs, [False, True]):
        yield SampleInput(input_t.clone().requires_grad_(requires_grad),
                          kwargs=dict(as_tuple=as_tuple))


def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    cases = (((S, S, S), (2,)),
             ((S, S, S), (S, 1)),
             ((S, S, S), (S, -1)))

    for case in cases:
        shape, args = case
        yield SampleInput(make_arg(shape), args=args)


def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs)

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # shape x chunks x dim
    cases = (
        ((13, 9, 11), 17, -1),
        ((13, 9, 11), 11, -1),
        ((13,), 12, -1),
        ((15,), 12, -1),
        ((15,), 7, 0),
        ((15,), 9, 0),
        ((3, 7), 9, 1),
        ((3, 7), 9, 0),
        ((3, 7), 2, 0),
        ((3, 7), 3, 0),
        ((3, 7), 1, 0),
        ((3, 7), 1, 1),
        ((4, 4), 2, 0),
    )
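    # Note: several of these deliberately request more chunks than the dimension holds
    # (e.g. 17 chunks of a length-13 dim); torch.chunk then returns fewer chunks than requested.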

    for shape, chunks, dim in cases:
        yield SampleInput(make_arg(shape), args=(chunks, dim))

def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs):
    def _tensor(shape, dtype=dtype, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    test_cases = [
        ((S, S, S), (2,)),
        ((S, S, S), (2, 1,)),
        ((S, S, S), (2, -1,)),
        ((S, S, S), (2, 1, True,)),
        ((S, S, S), (2, -1, True,)),
        ((S,), (2, 0,)),
        ((S,), (2, 0, True,)),
        ((), (1,)),
        ((), (1, 0,)),
        ((), (1, 0, True)),
    ]

    yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases)


def error_inputs_kthvalue(op_info, device, **kwargs):
    # tests overlapping output fails
    t = make_tensor(10, dtype=torch.float32, device=device)
    indices = torch.empty((), device=device, dtype=torch.long)
    yield ErrorInput(SampleInput(t, 5, out=(t, indices)),
                     error_regex="unsupported operation")

    k_out_of_range_err = "selected number k out of range for dimension"
    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0),
                     error_regex=k_out_of_range_err)
    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3),
                     error_regex=k_out_of_range_err)
    yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3),
                     error_regex=k_out_of_range_err)

def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
                          train=None, valid_input_dim=None, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if valid_input_dim:
        cases = ((S,) * i for i in valid_input_dim)
    else:
        cases = ((S, S), (S,), ())
    p_vals = [0.0, 0.5, 1.0]
    # This is to handle special case for feature_alpha_dropout which has different
    # supported dtypes depending on `train` parameter
    training_vals = [train] if train is not None else [True, False]

    for case, p, training in product(cases, p_vals, training_vals):
        yield SampleInput(make_arg(case), p=p, training=training)
    yield SampleInput(make_arg(case))


def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False)

    cases = ((S, S, S, S), (S,), ())
    scale_vals = [0.0, 1.0, 2.0]

    for case, scale in product(cases, scale_vals):
        yield SampleInput(make_arg(case), make_mask(case), scale)

def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape):
        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_long_input(shape, *, low, high, noncontiguous=False):
        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high,
                           noncontiguous=noncontiguous)

    def make_per_sample_weight(flag, idx):
        # a tensor of float / double weights, or None
        # to indicate all weights should be taken to be 1
        if flag:
            return make_input(idx.shape)
        return None

    offsets = torch.tensor([0, 3], device=device, dtype=torch.long)
    for generate_per_sample_weight in (True, False):
        for mode in ('sum', 'mean', 'max'):
            # per_sample_weights is only supported for mode='sum' (got mode='****')
            if generate_per_sample_weight and mode in ('mean', 'max'):
                continue

            # 1-D index tensor
            idx = make_long_input((S,), low=0, high=M)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((M, S)), args=(idx,),
                              kwargs={'offsets': offsets, 'mode': mode,
                                      'per_sample_weights': per_sample_weights})

            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((M, S)), args=(idx,),
                              kwargs={'offsets': offsets, 'mode': mode,
                                      'per_sample_weights': per_sample_weights})

            # bag with zero length
            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((M, S)), args=(idx,),
                              kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long),
                                      'mode': mode,
                                      'per_sample_weights': per_sample_weights})

            # 2-D index tensor
            idx = make_long_input((S, S), low=0, high=M)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((M, S)), args=(idx,),
                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})

            idx = make_long_input((S, S), low=0, high=M, noncontiguous=True)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((M, S)), args=(idx,),
                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})

            # The gradient vector at `padding_idx` is not updated.
            # Negative padding_idx
            idx = make_long_input((6,), low=0, high=S)
            idx[0] = 4
            idx[4] = 4
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((S, S)), args=(idx,),
                              kwargs={'padding_idx': -1, 'offsets': offsets,
                                      'mode': mode, 'per_sample_weights': per_sample_weights},)

            idx = make_long_input((3, 3), low=0, high=S)
            # Positive padding_idx
            idx[0, 0] = 2
            idx[1, 1] = 2
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((S, S)), args=(idx,),
                              kwargs={'padding_idx': 2, 'mode': mode,
                                      'per_sample_weights': per_sample_weights},)

            idx = make_long_input((6, ), low=0, high=S)
            weights = make_input((S, S))
            offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(weights, args=(idx,),
                              kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},)

            if not requires_grad:
                # Following inputs return different gradient from the numerical gradient.
                # This is expected and relevant tests are present in `test_nn.py`.

                # Due to inplace renorming of weight, the numerical gradient doesn't match the
                # analytical gradient.
                idx = make_long_input((2, 2), low=0, high=S)
                weights = make_input((S, S)) * 2
                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                yield SampleInput(weights, args=(idx,),
                                  kwargs={'max_norm': 1., 'mode': mode,
                                          'per_sample_weights': per_sample_weights},)

                idx = make_long_input((6, ), low=0, high=S)
                weights = make_input((S, S)) * 2
                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                yield SampleInput(weights, args=(idx,),
                                  kwargs={'max_norm': 1., 'norm_type': 1.0,
                                          'mode': mode, 'offsets': offsets,
                                          'per_sample_weights': per_sample_weights},)

                if mode != 'max':
                    # Scale the gradient based on the inverse frequency of a particular index.
                    # Note : max mode does not support sparse weights
                    idx = make_long_input((2, 2), low=0, high=S)
                    idx[0, 0] = 1
                    idx[0, 1] = 1
                    weights = make_input((S, S))
                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                    yield SampleInput(weights, args=(idx,),
                                      kwargs={'scale_grad_by_freq': True, 'mode': mode,
                                              'per_sample_weights': per_sample_weights},)

                    # gradcheck not implemented for sparse tensors.
                    # Note : max mode does not support sparse weights
                    idx = make_long_input((6, ), low=0, high=S)
                    weights = make_input((S, S))
                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                    yield SampleInput(weights, args=(idx,),
                                      kwargs={'sparse': True, 'offsets': offsets,
                                              'mode': mode, 'per_sample_weights': per_sample_weights})

                    idx = make_long_input((6, ), low=0, high=S)
                    idx[0] = 1  # freq more than 1
                    idx[1] = 1  # freq more than 1
                    idx[3] = 0  # padding_idx
                    weights = make_input((S, S)) * 2
                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                    yield SampleInput(weights, args=(idx,),
                                      kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0,
                                              'max_norm': 1., 'offsets': offsets,
                                              'mode': mode, 'per_sample_weights': per_sample_weights})

def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape):
        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_long_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)

    # 0-D index tensor
    idx = make_long_input((), low=0, high=M)
    yield SampleInput(make_input((M, S)), args=(idx,),)

    # 1-D index tensor
    idx = make_long_input((S,), low=0, high=M)
    yield SampleInput(make_input((M, S)), args=(idx,),)

    # 2-D index tensor
    idx = make_long_input((S, S), low=0, high=M)
    yield SampleInput(make_input((M, S)), args=(idx,),)

    if not requires_grad:
        # Following inputs return different gradient from the numerical gradient.
        # This is expected and relevant tests are present in `test_nn.py`.

        # The gradient vector at `padding_idx` is not updated.
        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 2
        idx[1, 1] = 2
        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)

        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 4
        idx[1, 1] = 4
        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)

        # Due to inplace renorming of weight, the numerical gradient doesn't match the
        # analytical gradient.
        idx = make_long_input((2, 2), low=0, high=S)
        weights = make_input((S, S)) * 2
        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},)

        idx = make_long_input((2, 2), low=0, high=S)
        weights = make_input((S, S)) * 2
        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},)

        # Scale the gradient based on the inverse frequency of a particular index.
        idx = make_long_input((2, 2), low=0, high=S)
        idx[0, 0] = 1
        idx[0, 1] = 1
        weights = make_input((S, S))
        yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)

        # gradcheck not implemented for sparse tensors.
        idx = make_long_input((2, 2), low=0, high=S)
        weights = make_input((S, S))
        yield SampleInput(weights, args=(idx,), kwargs={'sparse': True})

        idx = make_long_input((3, 3), low=0, high=S)
        idx[0, 0] = 1  # freq more than 1
        idx[0, 1] = 1  # freq more than 1
        idx[1, 0] = 0  # padding_idx
        weights = make_input((S, S)) * 2
        yield SampleInput(weights, args=(idx,),
                          kwargs={'sparse': True, 'scale_grad_by_freq': True,
                                  'padding_idx': 0, 'max_norm': 1.})

def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs):
    def make_input(shape, *, low, high):
        return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad)

    shapes = ((), (S,), (L, M, S))
    num_classess = (-1, 10)

    return (
        SampleInput(
            make_input(
                shape,
                low=0,
                high=10 if num_classes == -1 else num_classes // 2,
            ),
            kwargs=dict(num_classes=num_classes),
        )
        for shape, num_classes in itertools.product(shapes, num_classess)
    )

def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs):
    rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad)
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Although most losses also support the reduce and size_average combination instead of reduction,
    # the former is deprecated since 0.4.1 and thus is not tested
    shapes_and_kwargs = (
        ((), None),
        ((S,), dict(reduction="mean")),
        ((S,), dict(reduction="sum")),
        ((S,), dict(reduction="none")),
        ((S, S), None),
        ((S, S, S), None),
    )

    for shape, kwargs in shapes_and_kwargs:
        yield SampleInput(_make_tensor(shape),
                          args=(_make_tensor(shape, requires_grad=rhs_requires_grad),),
                          kwargs=kwargs)

def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
    # We get better tests if we change the range of the values to something like [-2,2]
    # because for grid (second tensor argument) the "useful" range is [-1,1] and this way
    # you get a better combination of out-of-range and in-range test cases
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
                           low=-2, high=2)

    batch_size = 2
    num_channels = 3
    modes = ("bilinear", "nearest")
    align_cornerss = (False, True)
    padding_modes = ("zeros", "border", "reflection")

    for dim in (2, 3):
        modes_ = (*modes, "bicubic") if dim == 2 else modes

        for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss):
            yield SampleInput(
                _make_tensor((batch_size, num_channels, *[S] * dim)),
                _make_tensor((batch_size, *[S] * dim, dim)),
                mode=mode,
                padding_mode=padding_mode,
                align_corners=align_corners,
            )


def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs):
    # We get better tests if we change the range of the values to something like [-2,2]
    # because for grid (second tensor argument) the "useful" range is [-1,1] and this way
    # you get a better combination of out-of-range and in-range test cases
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
                           low=-2, high=2)

    batch_size = 2
    num_channels = 3
    modes = (0, 1, 2)
    align_cornerss = (False, True)
    padding_modes = (0, 1, 2)

    for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss):
        yield SampleInput(
            _make_tensor((batch_size, num_channels, S, L)),
            _make_tensor((batch_size, num_channels, M, 2)),
            mode,
            padding_mode,
            align_corners,
        )

def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_target(shape):
        shape = () if len(shape) == 1 else (shape[0], )
        t = torch.randint(0, 2, shape, device=device, dtype=torch.long)
        # Label with -1 or 1
        t = t * 2 - 1
        target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad)
        return target

    shapes = ((S, S), (S,))
    reductions = ('none', 'mean', 'sum')
    for s, r in product(shapes, reductions):
        yield SampleInput(
            make_input(s),
            args=(make_input(s), make_target(s)),
            kwargs=dict(reduction=r, margin=random.uniform(-1, 1))
        )

def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs):
    input_length = 50
    batch = 16
    num_char = 20
    target_length = 30

    def make_log_probs(s):
        t = make_tensor(s, device=device, dtype=dtype)
        log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad)
        return log_probs

    reductions = ('none', 'mean', 'sum')
    zero_inf = (True, False)
    for r, z in product(reductions, zero_inf):
        log_probs = make_log_probs((input_length, batch, num_char))
        targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device)
        input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device)
        target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device)

        yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), kwargs=dict(reduction=r, zero_infinity=z))
  6045. def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
  6046. shape = (2, 3)
  6047. num_classes = shape[1]
  6048. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6049. # FIXME: Derivative wrt. weight not implemented
  6050. make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False)
  6051. def make_target(shape, zeros=False):
  6052. s = (shape[0], *shape[2:]) if len(shape) > 1 else ()
  6053. if zeros:
  6054. return torch.zeros(s, device=device, dtype=torch.long)
  6055. else:
  6056. return make_tensor(s,
  6057. low=0,
  6058. high=shape[1] if len(shape) > 1 else shape[0],
  6059. device=device,
  6060. dtype=torch.long)
  6061. def gen_shape_kwargs():
  6062. # Batched, non-batched and 2d
  6063. shapes = (shape, (num_classes,), shape + (2, 2))
  6064. reductions = ('none', 'mean', 'sum')
  6065. for reduction, s in product(reductions, shapes):
  6066. yield make_input(s), make_target(s), dict(reduction=reduction)
  6067. yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction)
  6068. yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction)
  6069. yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction)
  6070. t = make_target(s)
  6071. ignore = num_classes // 2
  6072. # If "mean", nll returns NaN, so it's not differentiable at those points
  6073. if t.eq(ignore).all() and reduction == "mean":
  6074. t.fill_(0)
  6075. yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction)
  6076. yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight())
  6077. # Test ignoring all the targets
  6078. # If "mean", nll returns NaN, so it's not differentiable at those points
  6079. if reduction != "mean":
  6080. yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction)
  6081. for input, target, kwargs in gen_shape_kwargs():
  6082. yield SampleInput(input, args=(target,), kwargs=kwargs)
  6083. def sample_inputs_binary_cross_entropy_with_logits(
  6084. op_info, device, dtype, requires_grad, **kwargs
  6085. ):
  6086. make = partial(make_tensor, device=device, dtype=dtype)
  6087. make_prob = partial(make, low=0, high=1)
  6088. reductions = ("mean", "sum", "none")
  6089. def make_weight_shape_kwargs():
  6090. kwargs = []
6091. for shape in ((1,), (1, S), (S,), (S, S)):
  6092. kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions])
  6093. return kwargs
  6094. shapes_and_kwargs = [
  6095. *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
  6096. *[((S, S), dict(reduction=reduction)) for reduction in reductions],
  6097. *make_weight_shape_kwargs(),
  6098. *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions],
  6099. *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions],
  6100. ]
  6101. for shape, kwargs in shapes_and_kwargs:
  6102. yield SampleInput(
  6103. make(shape, requires_grad=requires_grad),
  6104. args=(make_prob(shape, requires_grad=requires_grad),),
  6105. kwargs=kwargs,
  6106. )
  6107. def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs):
  6108. yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad))
  6109. mask = torch.tensor([[0, 1, 0, 1, 0],
  6110. [1, 1, 1, 1, 0],
  6111. [0, 0, 0, 1, 0],
  6112. [1, 0, 1, 1, 0],
  6113. [1, 0, 0, 1, 0]], dtype=torch.bool, device=device)
  6114. t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad)
  6115. t[mask] = 0
  6116. yield SampleInput(t)
  6117. t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
  6118. t[mask] = 0
  6119. yield SampleInput(t)
  6120. t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad)
  6121. yield SampleInput(t)
  6122. yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad))
  6123. yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad))
  6124. def _generate_sample_shape_reduction():
  6125. shapes = ((S,), (S, S), (S, S, S))
  6126. reductions = ('none', 'mean', 'sum')
  6127. yield from product(shapes, reductions)
  6128. def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
  6129. _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6130. # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0
  6131. make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad)
  6132. def gen_shape(shape):
  6133. yield shape
  6134. # Broadcast
  6135. yield (*shape[:-1], 1)
  6136. yield shape[:-1]
  6137. def gen_shape_kwargs():
  6138. for s, r in _generate_sample_shape_reduction():
  6139. for t_s, v_s in product(gen_shape(s), gen_shape(s)):
  6140. yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r)
  6141. yield (
  6142. _make_tensor(s), _make_tensor(t_s), make_var(v_s),
  6143. dict(full=True, reduction=r)
  6144. )
  6145. yield (
  6146. _make_tensor(s), _make_tensor(t_s), make_var(v_s),
  6147. dict(eps=random.uniform(1e-6, 1e-3), reduction=r)
  6148. )
  6149. yield (
  6150. _make_tensor(s), _make_tensor(t_s), make_var(v_s),
  6151. dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r)
  6152. )
  6153. for input, target, var, kwargs in gen_shape_kwargs():
  6154. yield SampleInput(input, args=(target, var, ), kwargs=kwargs)
  6155. def error_inputs_gaussian_nll_loss(op_info, device, **kwargs):
  6156. _make = partial(make_tensor, device=device, dtype=torch.float32)
  6157. # invalid reduction value
  6158. yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"),
  6159. error_type=ValueError, error_regex="abc is not valid")
  6160. # var is of incorrect shape
  6161. yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)),
  6162. error_type=ValueError, error_regex="var is of incorrect size")
  6163. # target is of incorrect shape
  6164. yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)),
  6165. error_type=RuntimeError,
  6166. error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) "
  6167. r"at non-singleton dimension 2"))
  6168. def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
  6169. _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6170. for s, r in _generate_sample_shape_reduction():
  6171. yield _make_tensor(s), _make_tensor(s), dict(reduction=r)
  6172. def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
  6173. for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
  6174. # target should contain either 1 or -1 as per docs
  6175. mask = torch.rand_like(target) > 0.5
  6176. target[mask] = 1
  6177. target[~mask] = -1
  6178. d['margin'] = random.uniform(-9, 9)
  6179. yield SampleInput(input, args=(target, ), kwargs=d)
  6180. # scalar input and target.
  6181. _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6182. yield SampleInput(_make_tensor(()), args=(_make_tensor(()), ))
  6183. def error_inputs_hinge_embedding_loss(op, device, **kwargs):
  6184. make_input = partial(make_tensor, device=device, dtype=torch.float32)
  6185. # invalid reduction value
  6186. yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
  6187. error_type=ValueError, error_regex='is not a valid value')
  6188. def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs):
  6189. yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs)
  6190. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6191. for reduction in ('sum', 'mean', 'none'):
6192. if dtype.is_floating_point:  # NaN/Inf values only exist for floating point dtypes
  6193. # NaN propagation
  6194. inp = make_input((10, ))
  6195. inp[2] = float('nan')
  6196. target = make_input((10, ))
  6197. # target should contain either 1 or -1 as per docs
  6198. mask = torch.rand_like(target) > 0.5
  6199. target[mask] = -1
  6200. target[~mask] = 1
  6201. yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
  6202. # Inf Handling
  6203. inp = make_input((10, ))
  6204. inp[4] = float('inf')
  6205. target = make_input((10, ))
  6206. mask = torch.rand_like(target) > 0.5
  6207. target[mask] = -1
  6208. target[~mask] = 1
  6209. yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
  6210. # Broadcasting
  6211. inp = make_input((5, 5))
  6212. target = make_input((1, 5))
  6213. mask = torch.rand_like(target) > 0.5
  6214. target[mask] = -1
  6215. target[~mask] = 1
  6216. yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
  6217. def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs):
  6218. for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
  6219. d['delta'] = random.uniform(1e-3, 9)
  6220. yield SampleInput(input, args=(target, ), kwargs=d)
  6221. def error_inputs_huber_loss(op, device, **kwargs):
  6222. make_input = partial(make_tensor, device=device, dtype=torch.float32)
  6223. # invalid reduction value
  6224. err = 'is not a valid value for reduction'
  6225. yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
  6226. error_type=ValueError, error_regex=err)
  6227. # delta <= 0
  6228. for delta in (0, -1):
  6229. err = 'huber_loss does not support non-positive values for delta.'
  6230. yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}),
  6231. error_type=RuntimeError, error_regex=err)
  6232. def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
  6233. _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6234. def gen_shape_kwargs():
  6235. for s, r in _generate_sample_shape_reduction():
  6236. for li in (True, False):
  6237. for f in (True, False):
  6238. i1 = _make_tensor(s)
  6239. i2 = _make_tensor(s)
6240. # For Poisson NLL Loss, the target is
6241. # assumed to be drawn from a Poisson
6242. # distribution, whose samples are
6243. # always non-negative
  6244. t1 = _make_tensor(s, low=0)
  6245. t2 = _make_tensor(s, low=0)
  6246. if not li:
  6247. i1.abs_()
  6248. i2.abs_()
  6249. t1.abs_()
  6250. t2.abs_()
  6251. yield (
  6252. i1, t1,
  6253. dict(log_input=li, full=f, reduction=r)
  6254. )
  6255. yield (
  6256. i2, t2,
  6257. dict(log_input=li, full=f,
  6258. eps=random.uniform(1e-8, 1e-3),
  6259. reduction=r)
  6260. )
  6261. for input, target, kwargs in gen_shape_kwargs():
  6262. yield SampleInput(input, args=(target, ), kwargs=kwargs)
  6263. # test INT_TO_FLOAT promotion
  6264. if dtype.is_complex:
  6265. for d in (torch.bool, torch.int64):
  6266. yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),))
  6267. yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),))
  6268. def error_inputs_poisson_nll_loss(op_info, device, **kwargs):
  6269. make = partial(make_tensor, device=device, dtype=torch.float32)
  6270. # invalid reduction value
  6271. yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
  6272. kwargs={'reduction': 'abc'}),
  6273. error_type=ValueError,
  6274. error_regex='abc is not a valid value for reduction')
  6275. # invalid input shapes
  6276. yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
  6277. error_regex=(r'(Attempting to broadcast a dimension of length|'
  6278. r'The size of tensor a \(5\) must match the '
  6279. r'size of tensor b \(4\) at non-singleton '
  6280. r'dimension 1)'))
  6281. def error_inputs_soft_margin_loss(op_info, device, **kwargs):
  6282. make = partial(make_tensor, device=device, dtype=torch.float32)
  6283. # invalid reduction value
  6284. yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
  6285. kwargs={'reduction': 'abc'}),
  6286. error_type=ValueError,
  6287. error_regex='abc is not a valid value for reduction')
  6288. # invalid input shapes
  6289. yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
  6290. error_regex=(r'(Attempting to broadcast a dimension of length|'
  6291. r'The size of tensor a \(4\) must match the '
  6292. r'size of tensor b \(5\) at non-singleton '
  6293. r'dimension 1)'))
  6294. def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
  6295. make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
  6296. kwargss = (
  6297. *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
  6298. dict(swap=True),
  6299. *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
  6300. )
  6301. for kwargs in kwargss:
  6302. input = make()
  6303. args = (make(), make())
  6304. if with_distance:
  6305. kwargs["distance_function"] = torch.nn.PairwiseDistance()
  6306. yield SampleInput(input, args=args, kwargs=kwargs)
  6307. def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
  6308. make_input = partial(make_tensor, device=device, dtype=torch.float32)
  6309. samples = (
  6310. # input, args, kwargs, error_type, error_regex
  6311. # invalid reduction
  6312. (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
  6313. dict(reduction="abc"),
  6314. ValueError, "abc is not a valid value for reduction"),
  6315. # shape mismatch
  6316. (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
  6317. dict(),
  6318. RuntimeError,
  6319. (r'(Attempting to broadcast a dimension of length|'
  6320. r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
  6321. r"at non-singleton dimension 1)")),
  6322. (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
  6323. dict(),
  6324. RuntimeError,
  6325. (r'(Attempting to broadcast a dimension of length|'
  6326. r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
  6327. r"at non-singleton dimension 1)")),
  6328. (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
  6329. dict(),
  6330. RuntimeError,
  6331. (r'(Attempting to broadcast a dimension of length|'
  6332. r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
  6333. r"at non-singleton dimension 1)")),
  6334. # different dimensions
  6335. (make_input(3,), (make_input(3, 4), make_input(3, 4)),
  6336. dict(),
  6337. RuntimeError,
  6338. (r"The anchor, positive, and negative tensors are expected to have "
  6339. r"the same number of dimensions, but got: anchor 1D, positive 2D, "
  6340. r"and negative 2D inputs")),
  6341. (make_input(3, 4), (make_input(3,), make_input(3, 4)),
  6342. dict(),
  6343. RuntimeError,
  6344. (r"The anchor, positive, and negative tensors are expected to have "
  6345. r"the same number of dimensions, but got: anchor 2D, positive 1D, "
  6346. r"and negative 2D inputs")),
  6347. (make_input(3, 4), (make_input(3, 4), make_input(3,)),
  6348. dict(),
  6349. RuntimeError,
  6350. (r"The anchor, positive, and negative tensors are expected to have "
  6351. r"the same number of dimensions, but got: anchor 2D, positive 2D, "
  6352. r"and negative 1D inputs")),
  6353. )
  6354. for input, args, kwargs, error_type, error_regex in samples:
  6355. yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
  6356. error_type=error_type, error_regex=error_regex)
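# Samples cover 3-dim and 4-dim query/key/value layouts plus a broadcasting case where the query has no
# batch dimension, crossed with both is_causal settings and two dropout probabilities; a final sample uses
# a value tensor whose head_dim differs from that of the query/key.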
  6357. def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
  6358. make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6359. batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
  6360. dim_3_q_shape = (batch, seq_q, head_dim)
  6361. dim_3_kv_shape = (batch, seq_kv, head_dim)
  6362. dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
  6363. dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
  6364. broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
  6365. qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
  6366. samples = []
6367. for qkv_shape, is_causal, dropout_p in product(
6368. qkv_shapes, [True, False], [0.0, 0.5]):
6369. shape_q, shape_kv = qkv_shape
  6370. samples.append(SampleInput(
  6371. make(shape_q),
  6372. make(shape_kv),
  6373. make(shape_kv),
  6374. is_causal=is_causal,
  6375. dropout_p=dropout_p
  6376. ))
6377. # Add a non-standard shape where the value head_dim differs from the query/key head_dim
  6378. diff_v_head_dim = SampleInput(
  6379. make((batch, num_heads, seq_q, head_dim)),
  6380. make((batch, num_heads, seq_kv, head_dim)),
  6381. make((batch, num_heads, seq_kv, head_dim + 8)),
  6382. is_causal=is_causal,
  6383. dropout_p=dropout_p
  6384. )
  6385. samples.append(diff_v_head_dim)
  6386. yield from samples
  6387. def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
  6388. make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6389. shape = (3,)
  6390. batched_shape = (2, *shape)
  6391. shapes_and_kwargs = [
  6392. (shape, None),
  6393. (batched_shape, None),
  6394. (shape, dict(keepdim=True)),
  6395. (batched_shape, dict(keepdim=True)),
  6396. (shape, dict(p=5.0)),
  6397. (shape, dict(p=-1.0)),
  6398. (shape, dict(eps=1.0)),
  6399. ]
  6400. return (
  6401. SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs
  6402. )
  6403. def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
  6404. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6405. yield from (
  6406. SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor)
  6407. for upscale_factor in (1, 3)
  6408. )
  6409. yield from (
  6410. SampleInput(make_arg(shape), upscale_factor=1)
  6411. for shape in [
  6412. (1, 0, 1, 1),
  6413. (1, 1, 0, 1),
  6414. (1, 1, 1, 0),
  6415. ]
  6416. )
  6417. def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs):
  6418. make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6419. yield from (
  6420. SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor)
  6421. for downscale_factor in (1, 3)
  6422. )
  6423. yield from (
  6424. SampleInput(make_arg(shape), downscale_factor=1)
  6425. for shape in [
  6426. (1, 0, 1, 1),
  6427. (1, 1, 0, 1),
  6428. (1, 1, 1, 0),
  6429. ]
  6430. )
  6431. def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
  6432. make = partial(make_tensor, device=device, dtype=dtype)
  6433. make_prob = partial(make, low=0, high=1)
  6434. reductions = ("mean", "sum", "none")
  6435. shapes_and_kwargs = [
  6436. *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
  6437. *[((S, S), dict(reduction=reduction)) for reduction in reductions],
  6438. *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
  6439. ]
  6440. if logits:
  6441. shapes_and_kwargs.extend(
  6442. [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
  6443. )
  6444. for shape, kwargs in shapes_and_kwargs:
  6445. yield SampleInput(
  6446. (make if logits else make_prob)(shape, requires_grad=requires_grad),
  6447. args=(make_prob(shape, requires_grad=requires_grad),),
  6448. kwargs=kwargs,
  6449. )
  6450. def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):
6451. sample_shapes = [(), (S,), (S, S, S)]
  6452. atols = [1e-2, 1e-16]
  6453. rtols = [1e-1, 0.5]
  6454. eps = 1e-8
  6455. for s, rtol, atol in product(sample_shapes, rtols, atols):
  6456. # close sample
  6457. t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
  6458. close = (t + atol).detach().requires_grad_(requires_grad)
  6459. yield SampleInput(t, close, rtol=rtol, atol=atol)
  6460. # random sample
  6461. a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
  6462. b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
  6463. yield SampleInput(a, b, rtol=rtol, atol=atol)
  6464. def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
  6465. yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
  6466. # test COMPLEX_TO_FLOAT promotion
  6467. if dtype.is_complex:
  6468. make = partial(make_tensor, (), device=device, requires_grad=requires_grad)
  6469. yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),))
  6470. yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),))
  6471. def error_inputs_l1_loss(op_info, device, **kwargs):
  6472. make = partial(make_tensor, device=device, dtype=torch.float32)
  6473. # invalid reduction value
  6474. yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
  6475. kwargs={'reduction': 'abc'}),
  6476. error_type=ValueError,
  6477. error_regex='abc is not a valid value for reduction')
  6478. # invalid input shapes
  6479. yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
  6480. error_regex=(r'(Attempting to broadcast a dimension of length|'
  6481. r'The size of tensor a \(4\) must match the '
  6482. r'size of tensor b \(5\) at non-singleton '
  6483. r'dimension 1)')
  6484. )
  6485. def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
  6486. yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
  6487. make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad)
6488. # This test case always triggers the smooth (quadratic) branch, since the absolute difference
6489. # between input and target is smaller than beta
  6490. yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5))
  6491. yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0))
  6492. def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs):
6493. # kl_div works with inputs in [0, 1] (i.e. values of a probability distribution).
6494. # The input is expected in log-space: log([0, 1]) = (-inf, 0].
  6495. make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad)
  6496. def make_log(shape):
  6497. out = torch.nn.functional.log_softmax(make_arg(shape), -1)
  6498. out.requires_grad_(requires_grad)
  6499. return out
  6500. def make_prob(shape):
  6501. out = torch.nn.functional.softmax(make_arg(shape), -1)
  6502. out.requires_grad_(requires_grad)
  6503. return out
  6504. shapes = ((2,), (2, 3))
  6505. reductions = ("none", "mean", "batchmean", "sum")
  6506. for shape, reduction, log_target in product(shapes, reductions, (True, False)):
  6507. input = make_log(shape)
  6508. target = make_log(shape) if log_target else make_prob(shape)
  6509. yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target))
  6510. def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
  6511. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6512. yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2))
  6513. yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf")))
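# NumPy/SciPy reference for torch.pdist: p=0 maps to the Hamming distance scaled by the number of columns,
# p=inf to the pairwise max-abs (Chebyshev) distance, and any other p to the Minkowski distance.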
  6514. def reference_pdist(input, p=2):
  6515. pdist = scipy.spatial.distance.pdist
  6516. if p == 0:
  6517. output = pdist(input, "hamming") * input.shape[1]
  6518. elif p == float("inf"):
  6519. output = pdist(input, lambda x, y: np.abs(x - y).max())
  6520. else:
  6521. output = pdist(input, "minkowski", p=p)
  6522. return output.astype(input.dtype)
  6523. def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
  6524. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
  6525. yield SampleInput(make_input(()))
  6526. yield SampleInput(make_input((2,)))
  6527. yield SampleInput(make_input((2, 2)))
  6528. yield SampleInput(make_input((2,)), offset=1)
  6529. yield SampleInput(make_input((2,)), offset=-1)
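# Builds max_unpool samples by running the matching max_pool (with return_indices=True) over max_pool's
# own sample inputs and unpooling the resulting (values, indices) pairs.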
  6530. def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
  6531. unpool_name_to_pool_method_dict = {
  6532. 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
  6533. 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
  6534. 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
  6535. }
  6536. unpool_name_to_dim = {
  6537. 'nn.functional.max_unpool1d': 1,
  6538. 'nn.functional.max_unpool2d': 2,
  6539. 'nn.functional.max_unpool3d': 3
  6540. }
  6541. unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
  6542. pool_dim = unpool_name_to_dim[op_info.name]
  6543. pool_method = unpool_name_to_pool_method_dict[op_info.name]
  6544. pool_op_info = copy.copy(op_info)
  6545. pool_op_info.name = unpool_to_pool_name_dict[op_info.name]
  6546. for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
  6547. # shapes (C, ...) do not work as of now,
  6548. # see https://github.com/pytorch/pytorch/issues/68337
  6549. # TODO: remove once the issue is resolved
  6550. if sample.input.dim() != pool_dim + 2:
  6551. continue
  6552. # No dilation > 1 for max_unpool,
  6553. # see https://github.com/pytorch/pytorch/issues/68420
  6554. if sample.kwargs['dilation'] != 1:
  6555. continue
  6556. # Can't unpool without indices
  6557. if sample.kwargs['return_indices']:
  6558. pool, indices = pool_method(sample.input, **sample.kwargs)
  6559. # arg has to be a leaf
  6560. arg = pool.detach().requires_grad_(requires_grad)
  6561. sample_kwargs = {
  6562. 'kernel_size': sample.kwargs['kernel_size'],
  6563. 'stride': sample.kwargs['stride'],
  6564. 'padding': sample.kwargs['padding'],
  6565. # output_size could be None but we specify it explicitly
6566. # to compensate for the information loss in the pooling op due
6567. # to the floor/ceil operation used to compute the output shapes
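# (e.g. max_pool1d with kernel_size=2, stride=2 maps inputs of length 4 and 5 to the same output
# length 2, so the unpooled size is ambiguous without output_size)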
  6568. 'output_size': sample.input.size()
  6569. }
  6570. yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)
  6571. def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
  6572. for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
  6573. indices = sample.args[0]
  6574. # The samples for max_unpool are generated with max_pool.
  6575. # It could be that a single element from the max_pool's
  6576. # input is mapped to several locations in its output.
  6577. # This situation leads to failed gradchecks because
6578. # the finite difference algorithm perturbs the elements
6579. # of the output one by one, and not in equivalence
6580. # classes determined by whether two elements
6581. # in the output come from the same location in the
6582. # input (simply put, they have the same corresponding index).
6583. # So, there are two ways to resolve this issue:
6584. # 1. Extract a perturbation for one element and apply it to all
6585. # the elements from the same equivalence class, or
6586. # 2. Make sure that the equivalence classes are all singletons,
6587. # i.e. the index tensor must contain only unique
6588. # indices.
6589. # Here we go with solution 2, the easiest of all.
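# For example, max_pool1d over [[[1., 5., 3.]]] with kernel_size=2 and stride=1 returns values [5., 5.]
# with indices [1, 1]; both outputs alias the same input element, so such samples are filtered out below.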
  6590. if indices.unique().numel() == indices.numel():
  6591. yield sample
  6592. # Includes some values such that N * N won't be a multiple of 4,
  6593. # which should ensure we test the vectorized and non-vectorized
  6594. # kernel code paths.
  6595. foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300]
  6596. class ForeachRightmostArgType(enum.Enum):
  6597. TensorList = 1
  6598. ScalarList = 2
  6599. Scalar = 3
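# One representative Scalar per supported kind: int, float, bool, and complex.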
  6600. foreach_scalars = (
  6601. random.randint(1, 10),
  6602. 1.0 - random.random(),
  6603. True,
  6604. complex(1.0 - random.random(), 1.0 - random.random()),
  6605. )
  6606. _foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None}
  6607. # TODO(crcrpar): Update to return `n_expected_cudaLaunchKernels` as well
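# Generates SampleInputs for _foreach_* ops: `arity` is the total number of operands; the rightmost
# operand defaults to a tensor list but may instead be a single Scalar or a ScalarList when the
# corresponding flag is set.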
  6608. class foreach_inputs_sample_func:
  6609. def __init__(
  6610. self,
  6611. arity: int,
  6612. rightmost_supports_scalar: bool,
  6613. rightmost_supports_scalarlist: bool,
  6614. ) -> None:
  6615. self.arity = arity
  6616. self._set_rightmost_arg_types(rightmost_supports_scalar, rightmost_supports_scalarlist)
  6617. def _set_rightmost_arg_types(
  6618. self,
  6619. rightmost_supports_scalar: bool,
  6620. rightmost_supports_scalarlist: bool,
  6621. ) -> None:
  6622. self._rightmost_arg_types = [ForeachRightmostArgType.TensorList]
  6623. if self.arity > 1:
  6624. if rightmost_supports_scalar:
  6625. self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar)
  6626. if rightmost_supports_scalarlist:
  6627. self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList)
  6628. def _sample_rightmost_arg(self, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
  6629. if rightmost_arg_type == ForeachRightmostArgType.TensorList:
  6630. return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)]
  6631. if rightmost_arg_type == ForeachRightmostArgType.ScalarList:
  6632. return [
  6633. [random.randint(0, 9) + 1 for _ in range(num_tensors)],
  6634. [1.0 - random.random() for _ in range(num_tensors)],
  6635. [complex(1.0 - random.random(), 1.0 - random.random()) for _ in range(num_tensors)],
  6636. [True for _ in range(num_tensors)],
  6637. [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)],
  6638. [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)],
  6639. ]
  6640. if rightmost_arg_type == ForeachRightmostArgType.Scalar:
  6641. return foreach_scalars
  6642. raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
  6643. def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
  6644. if self.arity < 2:
  6645. return None
  6646. if rightmost_arg_type == ForeachRightmostArgType.TensorList:
  6647. disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
  6648. if "foreach_add" in opinfo.name and dtype == torch.bool:
  6649. disable_fastpath = True
  6650. return disable_fastpath
  6651. elif rightmost_arg_type == ForeachRightmostArgType.Scalar:
  6652. disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
  6653. if isinstance(rightmost_arg, bool):
  6654. disable_fastpath |= dtype == torch.bool
  6655. if opinfo.ref in (torch.add, torch.mul):
  6656. disable_fastpath = False
  6657. elif isinstance(rightmost_arg, int):
  6658. disable_fastpath |= dtype == torch.bool
  6659. elif isinstance(rightmost_arg, float):
  6660. disable_fastpath |= dtype in integral_types_and(torch.bool)
  6661. elif isinstance(rightmost_arg, complex):
  6662. disable_fastpath |= dtype not in complex_types()
  6663. else:
  6664. raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}")
  6665. return disable_fastpath
  6666. elif rightmost_arg_type == ForeachRightmostArgType.ScalarList:
  6667. disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool)
  6668. elmt_t = type(rightmost_arg[0])
  6669. has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg)
  6670. if not has_same_type:
  6671. return dtype not in complex_types()
  6672. if isinstance(rightmost_arg[0], bool):
  6673. if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool:
  6674. disable_fastpath = False
  6675. elif isinstance(rightmost_arg[0], int):
  6676. disable_fastpath |= dtype == torch.bool
  6677. elif isinstance(rightmost_arg[0], float):
  6678. disable_fastpath |= dtype in integral_types_and(torch.bool)
  6679. elif isinstance(rightmost_arg[0], complex):
  6680. disable_fastpath |= dtype not in complex_types()
  6681. else:
  6682. raise AssertionError(f"Invalid scalarlist of {rightmost_arg}")
  6683. return disable_fastpath
  6684. else:
  6685. raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
  6686. def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
  6687. kwargs = {}
  6688. if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param:
  6689. if dtype in integral_types_and(torch.bool):
  6690. kwargs["alpha"] = 3
  6691. elif dtype.is_complex:
  6692. kwargs["alpha"] = complex(3, 3)
  6693. else:
  6694. kwargs["alpha"] = 3.14
  6695. if self.arity > 1:
  6696. kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype)
  6697. return kwargs
  6698. def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
  6699. num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
  6700. assert isinstance(num_input_tensors, list)
  6701. _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
  6702. for num_tensors in num_input_tensors:
  6703. for rightmost_arg_type in self._rightmost_arg_types:
  6704. input = sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6705. args = []
  6706. kwargs = {}
  6707. if self.arity > 1:
  6708. args = [
  6709. sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6710. for _ in range(self.arity - 2)
  6711. ]
  6712. rightmost_arg_list = self._sample_rightmost_arg(
  6713. rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6714. for rightmost_arg in rightmost_arg_list:
  6715. args.append(rightmost_arg)
  6716. kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)
  6717. yield SampleInput(input, *args, **kwargs)
  6718. args.pop()
  6719. else:
  6720. if opinfo.ref in (torch.abs, torch.neg):
  6721. kwargs["disable_fastpath"] = False
  6722. else:
  6723. kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool)
  6724. yield SampleInput(input, *args, **kwargs)
  6725. class foreach_norm_sample_func(foreach_inputs_sample_func):
  6726. def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
  6727. num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
  6728. assert isinstance(num_input_tensors, list)
  6729. _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
  6730. for num_tensors, ord in product(num_input_tensors, (0, 1, 2, -1, -2)):
  6731. input = sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6732. disable_fastpath = True
  6733. if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
  6734. disable_fastpath = False
  6735. yield SampleInput(input, **{"ord": ord, "disable_fastpath": disable_fastpath})
  6736. class foreach_lerp_sample_func(foreach_inputs_sample_func):
  6737. def _sample_rightmost_arg(self, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs):
  6738. if rightmost_arg_type == ForeachRightmostArgType.TensorList:
  6739. return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)]
  6740. if rightmost_arg_type == ForeachRightmostArgType.ScalarList:
  6741. return [
  6742. [random.randint(0, 9) + 1 for _ in range(num_tensors)],
  6743. [1.0 - random.random() for _ in range(num_tensors)],
  6744. [complex(1.0 - random.random(), 1.0 - random.random()) for _ in range(num_tensors)],
  6745. [True for _ in range(num_tensors)],
  6746. [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)],
  6747. [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)],
  6748. ]
  6749. if rightmost_arg_type == ForeachRightmostArgType.Scalar:
  6750. return [random.random()]
  6751. raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
  6752. class foreach_pointwise_sample_func(foreach_inputs_sample_func):
  6753. def __init__(
  6754. self,
  6755. arity: int = 3,
  6756. rightmost_supports_scalar: bool = False,
  6757. rightmost_supports_scalarlist: bool = False,
  6758. ):
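# Note: the constructor arguments above are ignored; arity is fixed at 3 + 1 (presumably the extra slot
# accounts for the trailing scalar/scalarlist `values` argument) with Scalar and ScalarList both enabled.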
  6759. super().__init__(3 + 1, True, True)
  6760. def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
  6761. return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
  6762. def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
  6763. num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
  6764. assert isinstance(num_input_tensors, list)
  6765. _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
  6766. for num_tensors in num_input_tensors:
  6767. for rightmost_arg_type in self._rightmost_arg_types:
  6768. input = sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6769. args = [
  6770. sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6771. for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList))
  6772. ]
  6773. rightmost_arg_list = self._sample_rightmost_arg(
  6774. rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs)
  6775. for rightmost_arg in rightmost_arg_list:
  6776. kwargs = {}
  6777. if rightmost_arg_type == ForeachRightmostArgType.TensorList:
  6778. args.append(rightmost_arg)
  6779. kwargs["values"] = None
  6780. else:
  6781. kwargs["values"] = rightmost_arg
  6782. kwargs.update(
  6783. self._sample_kwargs(
  6784. opinfo, kwargs["values"] or rightmost_arg, rightmost_arg_type, dtype)
  6785. )
  6786. assert hasattr(kwargs, "values")
  6787. assert len(args) == 2
  6788. yield SampleInput(input, *args, **kwargs)
  6789. foreach_unary_op_db: List[OpInfo] = [
  6790. ForeachFuncInfo('exp', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6791. ForeachFuncInfo('acos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6792. ForeachFuncInfo('asin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6793. ForeachFuncInfo('atan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6794. ForeachFuncInfo('cos', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6795. ForeachFuncInfo('cosh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6796. ForeachFuncInfo('log', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6797. ForeachFuncInfo('log10', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6798. ForeachFuncInfo('log2', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6799. ForeachFuncInfo('tan', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6800. ForeachFuncInfo('tanh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6801. ForeachFuncInfo('sin', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6802. ForeachFuncInfo('sinh', sample_inputs_func=foreach_inputs_sample_func(1, False, False), supports_autograd=True),
  6803. ForeachFuncInfo(
  6804. 'neg',
  6805. dtypes=all_types_and_complex(),
  6806. dtypesIfCUDA=all_types_and_complex(),
  6807. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6808. supports_autograd=True,
  6809. ),
  6810. ForeachFuncInfo(
  6811. 'sqrt',
  6812. dtypes=floating_and_complex_types_and(torch.bfloat16),
  6813. dtypesIfCUDA=floating_and_complex_types_and(torch.half),
  6814. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6815. supports_autograd=True,
  6816. ),
  6817. ForeachFuncInfo(
  6818. 'ceil',
  6819. dtypes=all_types_and(torch.bfloat16),
  6820. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  6821. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6822. supports_autograd=True,
  6823. ),
  6824. ForeachFuncInfo(
  6825. 'erf',
  6826. dtypes=floating_types_and(torch.bfloat16),
  6827. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  6828. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6829. supports_autograd=True,
  6830. ),
  6831. ForeachFuncInfo(
  6832. 'erfc',
  6833. dtypes=floating_types_and(torch.bfloat16),
  6834. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  6835. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6836. supports_autograd=True,
  6837. ),
  6838. ForeachFuncInfo(
  6839. 'expm1',
  6840. dtypes=floating_types_and(torch.bfloat16),
  6841. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  6842. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6843. supports_autograd=True,
  6844. ),
  6845. ForeachFuncInfo(
  6846. 'floor',
  6847. dtypes=all_types_and(torch.bfloat16),
  6848. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  6849. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6850. supports_autograd=True,
  6851. ),
  6852. ForeachFuncInfo(
  6853. 'log1p',
  6854. dtypes=floating_and_complex_types_and(torch.bfloat16),
  6855. dtypesIfCUDA=floating_and_complex_types_and(torch.half),
  6856. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6857. supports_autograd=True,
  6858. ),
  6859. ForeachFuncInfo(
  6860. 'round',
  6861. dtypes=all_types_and(torch.bfloat16),
  6862. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  6863. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6864. supports_autograd=True,
  6865. ),
  6866. ForeachFuncInfo(
  6867. 'frac',
  6868. dtypes=floating_types_and(torch.bfloat16),
  6869. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  6870. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6871. supports_autograd=True,
  6872. ),
  6873. ForeachFuncInfo(
  6874. 'reciprocal',
  6875. dtypes=floating_types_and(torch.bfloat16),
  6876. dtypesIfCUDA=floating_types_and(torch.half),
  6877. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6878. supports_autograd=True,
  6879. ),
  6880. ForeachFuncInfo(
  6881. 'sigmoid',
  6882. dtypes=floating_types_and(torch.bfloat16),
  6883. dtypesIfCUDA=floating_types_and(torch.half),
  6884. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6885. supports_autograd=True,
  6886. ),
  6887. ForeachFuncInfo(
  6888. 'trunc',
  6889. dtypes=all_types_and(torch.bfloat16),
  6890. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  6891. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6892. supports_autograd=True,
  6893. ),
  6894. ForeachFuncInfo(
  6895. 'abs',
  6896. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
  6897. dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
  6898. supports_forward_ad=True,
  6899. supports_fwgrad_bwgrad=True,
  6900. sample_inputs_func=foreach_inputs_sample_func(1, False, False),
  6901. supports_autograd=True,
  6902. ),
  6903. ]
  6904. foreach_binary_op_db: List[OpInfo] = [
  6905. ForeachFuncInfo(
  6906. "add",
  6907. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6908. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6909. supports_alpha_param=True,
  6910. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6911. ),
  6912. ForeachFuncInfo(
  6913. "sub",
  6914. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6915. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6916. supports_alpha_param=True,
  6917. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6918. ),
  6919. ForeachFuncInfo(
  6920. "mul",
  6921. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6922. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6923. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6924. ),
  6925. ForeachFuncInfo(
  6926. "div",
  6927. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6928. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  6929. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6930. ),
  6931. ForeachFuncInfo(
  6932. "clamp_min",
  6933. dtypes=all_types_and(torch.bfloat16),
  6934. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  6935. supports_alpha_param=False,
  6936. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6937. ),
  6938. ForeachFuncInfo(
  6939. "clamp_max",
  6940. dtypes=all_types_and(torch.bfloat16),
  6941. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  6942. supports_alpha_param=False,
  6943. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6944. ),
  6945. ForeachFuncInfo(
  6946. "minimum",
  6947. dtypes=all_types_and(torch.bfloat16),
  6948. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  6949. supports_alpha_param=False,
  6950. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6951. ),
  6952. ForeachFuncInfo(
  6953. "maximum",
  6954. dtypes=all_types_and(torch.bfloat16),
  6955. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  6956. supports_alpha_param=False,
  6957. sample_inputs_func=foreach_inputs_sample_func(2, True, True),
  6958. ),
  6959. ]
  6960. foreach_pointwise_op_db: List[ForeachFuncInfo] = [
  6961. ForeachFuncInfo(
  6962. "addcmul",
  6963. dtypes=all_types_and_complex(),
  6964. dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
  6965. sample_inputs_func=foreach_pointwise_sample_func(3, False, False),
  6966. ),
  6967. ForeachFuncInfo(
  6968. "addcdiv",
  6969. dtypes=all_types_and_complex(),
  6970. dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
  6971. sample_inputs_func=foreach_pointwise_sample_func(3, False, False),
  6972. ),
  6973. ]
  6974. foreach_reduce_op_db: List[ForeachFuncInfo] = [
  6975. ForeachFuncInfo(
  6976. "norm",
  6977. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  6978. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  6979. sample_inputs_func=foreach_norm_sample_func(1, False, False),
  6980. ),
  6981. ]
  6982. foreach_lerp_op_db: List[ForeachFuncInfo] = [
  6983. ForeachFuncInfo(
  6984. "lerp",
  6985. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  6986. dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
  6987. sample_inputs_func=foreach_lerp_sample_func(3, True, False),
  6988. ),
  6989. ]
  6990. def reference_sign(x):
  6991. if x.dtype == np.bool_:
  6992. # `np.sign` doesn't support `bool`.
  6993. # >>> np.sign(True)
  6994. # ufunc 'sign' did not contain a loop
  6995. # with signature matching types dtype('bool') -> dtype('bool')
  6996. return np.sign(x, dtype=np.uint8).astype(np.bool_)
  6997. return np.sign(x)
  6998. def reference_sgn(x):
  6999. # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex.
7000. # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j,
7001. # while `torch.sgn` returns 0 if abs(input) == 0 else input / abs(input).
  7002. if x.dtype not in [np.complex64, np.complex128]:
  7003. return reference_sign(x)
  7004. out = (x / np.abs(x))
  7005. if out.ndim == 0:
  7006. # Handle x == 0 case
  7007. if (x == 0):
7008. # Can't index-assign into a NumPy complex scalar,
7009. # so build a new array instead.
  7010. return np.array(complex(0, 0), dtype=x.dtype)
  7011. return out
  7012. # Handle x == 0 case
  7013. mask = (x == 0)
  7014. out[mask] = complex(0, 0)
  7015. return out
  7016. def reference_sigmoid(x):
7017. # `scipy.special.expit` doesn't support complex inputs
  7018. if x.dtype in [np.complex64, np.complex128]:
  7019. return (1 / (1 + np.exp(-x)))
  7020. return scipy.special.expit(x)
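# Reference for logsigmoid using the piecewise identity
# log(sigmoid(x)) = x - log1p(exp(x)) for x < 0 and -log1p(exp(-x)) otherwise.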
  7021. def reference_logsigmoid(x):
  7022. return np.where(
  7023. x < 0,
  7024. x - np.log1p(np.exp(x)),
  7025. -np.log1p(np.exp(-x)))
  7026. def reference_hardsigmoid(x):
  7027. intermediate = x / 6 + 0.5
  7028. y = np.clip(intermediate, 0, None)
  7029. return np.where(y > 1, 1, y).astype(x.dtype)
  7030. def reference_lgamma(x):
7031. # scipy.special.gammaln returns `-inf` when the input is `-inf`,
7032. # while PyTorch (like C and C++) returns `inf` when the input is `-inf`.
  7033. # Reference:
  7034. # https://en.cppreference.com/w/cpp/numeric/math/lgamma
  7035. # https://en.cppreference.com/w/c/numeric/math/lgamma
  7036. # To handle the above discrepancy,
  7037. # we replace -inf with inf so values
  7038. # that were originally -inf map to inf as expected
  7039. if x.dtype.kind == 'f':
  7040. x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)
  7041. out = scipy.special.gammaln(x)
  7042. if x.dtype == np.float16:
7043. # `scipy.special.gammaln` returns float32 output when the input is float16,
7044. # while `torch.lgamma` preserves `float16`. Due to the smaller range of float16,
7045. # the PyTorch version outputs `inf` where SciPy returns finite values.
  7046. out = out.astype(np.float16)
  7047. return out
  7048. def reference_mvlgamma(x, d):
  7049. if x.dtype == np.float16:
  7050. return scipy.special.multigammaln(x, d).astype(np.float16)
  7051. return scipy.special.multigammaln(x, d)
  7052. def reference_softplus(input, beta=1, threshold=20):
  7053. non_linear = input * beta <= threshold
  7054. output = input.copy()
  7055. output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta
  7056. return output
  7057. def reference_gelu(X, *, approximate='none'):
  7058. def _gelu_ref(X):
  7059. return X * stats.norm.cdf(X)
  7060. def _tanh_gelu_ref(X):
  7061. M_SQRT_2_PI = math.sqrt(2 / math.pi)
  7062. Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
  7063. return 0.5 * X * (1.0 + np.tanh(Z))
  7064. if approximate == 'tanh':
  7065. return _tanh_gelu_ref(X)
  7066. else:
  7067. return _gelu_ref(X)
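# NumPy reference for one_hot: flattens `a`, scatters ones at the flat class indices, and reshapes the
# result back to (*a.shape, num_classes).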
  7068. def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray:
  7069. if num_classes == -1:
  7070. num_classes = int(np.amax(a) + 1)
  7071. idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes
  7072. one_hot = np.zeros((a.size, num_classes), dtype=a.dtype)
  7073. np.put(one_hot, idcs, 1)
  7074. return one_hot.reshape(*a.shape, -1)
  7075. def reference_mse_loss(input, target, reduction="mean"):
  7076. se = (input - target) ** 2
  7077. if reduction == "mean":
  7078. return np.mean(se)
  7079. elif reduction == "sum":
  7080. return np.sum(se)
  7081. else: # reduction == "none"
  7082. return se
  7083. def wrapper_set_seed(op, *args, **kwargs):
  7084. """Wrapper to set seed manually for some functions like dropout
  7085. See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
  7086. """
  7087. with freeze_rng_state():
  7088. torch.manual_seed(42)
  7089. return op(*args, **kwargs)
  7090. def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
  7091. return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]
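# Returns (normalized output, mean, rstd) like native_layer_norm, with mean and rstd reshaped to the
# input's leading (non-normalized) dimensions followed by singleton dimensions.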
  7092. def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight, bias, eps):
  7093. feature_size = np.prod(normalized_shape)
  7094. inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload]
  7095. mean = inp_view.mean(axis=-1, keepdims=True)
  7096. var = inp_view.var(axis=-1, ddof=0, keepdims=True)
  7097. Y = (inp_view - mean) / np.sqrt(var + eps)
  7098. if weight is None and bias is not None:
  7099. Y = Y + bias.reshape(-1)
  7100. elif weight is not None and bias is None:
  7101. Y = Y * weight.reshape(-1)
  7102. elif weight is not None and bias is not None:
  7103. Y = Y * weight.reshape(-1) + bias.reshape(-1)
  7104. axis = inp.ndim - len(normalized_shape)
  7105. stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape)
  7106. return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)
  7107. def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5):
  7108. inp_view = inp
  7109. if np.prod(inp.shape) != 0:
  7110. inp_view = inp.reshape((inp.shape[0], num_groups, -1))
  7111. mean = inp_view.mean(axis=-1, keepdims=True)
  7112. var = inp_view.var(axis=-1, ddof=0, keepdims=True)
  7113. Y = (inp_view - mean) / np.sqrt(var + eps)
  7114. Y = Y.reshape(inp.shape)
  7115. if weight is not None:
7116. # weight is a vector whose length equals the number of channels
  7117. if len(Y.shape) > 2:
  7118. weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
  7119. Y = Y * weight
  7120. if bias is not None:
7121. # bias is a vector whose length equals the number of channels
  7122. if len(Y.shape) > 2:
  7123. bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
  7124. Y = Y + bias
  7125. return Y
7126. # Using a custom reference function since numpy's searchsorted only takes a string `side` arg (instead of both
7127. # `right` and `side`) and has no `out_int32` arg. Additionally, numpy doesn't support searchsorted on ND arrays,
7128. # so this splits those into stacked 1D cases.
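# e.g. a (2, 3, 5) sorted_sequence with a (2, 3, 4) boundary is reshaped to (6, 5) and (6, 4), each row is
# searched independently, and the stacked result is reshaped back to (2, 3, 4).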
  7129. def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None):
  7130. side = 'right' if (right or side == 'right') else 'left'
7131. if len(sorted_sequence.shape) == 1:
  7132. ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter)
  7133. return ret.astype(np.int32) if out_int32 else ret
  7134. elif sorted_sequence.shape[0] == 0:
  7135. if sorter is not None:
  7136. sorter = sorter.flatten()
  7137. ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter)
  7138. ret = ret.astype(np.int32) if out_int32 else ret
  7139. return ret.reshape(boundary.shape)
  7140. else:
  7141. # numpy searchsorted only supports 1D inputs so we split up ND inputs
  7142. orig_shape = boundary.shape
  7143. num_splits = np.prod(sorted_sequence.shape[:-1])
  7144. splits = range(0, num_splits)
  7145. sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1)
  7146. if sorter is not None:
  7147. sorter = sorter.reshape(num_splits, -1)
  7148. split_sequence = [sorted_sequence[i] for i in splits]
  7149. split_boundary = [boundary[i] for i in splits]
  7150. split_sorter = [sorter[i] if (sorter is not None) else None for i in splits]
  7151. split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort)
  7152. for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter)]
  7153. split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret
  7154. return np.stack(split_ret).reshape(orig_shape)
  7155. def loss_reference_reduction_wrapper(fn):
  7156. def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs):
  7157. if size_average is not None or reduce is not None:
  7158. raise RuntimeError(
  7159. "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper"
  7160. )
  7161. output = fn(input, target, **other_kwargs)
  7162. if reduction == "mean":
  7163. return np.mean(output)
  7164. elif reduction == "sum":
  7165. return np.sum(output)
  7166. else: # reduction == "none"
  7167. return output
  7168. return wrapper
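# Hedged usage sketch with a hypothetical elementwise L1 loss: the wrapper above
# only layers the requested reduction on top of whatever the wrapped fn returns.
def _example_loss_reduction_wrapper_usage():
    l1 = loss_reference_reduction_wrapper(lambda inp, tgt: np.abs(inp - tgt))
    inp, tgt = np.array([1., 2., 4.]), np.array([1., 1., 1.])
    assert l1(inp, tgt, reduction="sum") == 4.0
    assert np.isclose(l1(inp, tgt, reduction="mean"), 4.0 / 3)
    assert (l1(inp, tgt, reduction="none") == np.array([0., 1., 3.])).all()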
  7169. @loss_reference_reduction_wrapper
  7170. def reference_smooth_l1_loss(input, target, beta=1.0):
  7171. diff = input - target
  7172. abs_diff = np.abs(diff)
  7173. above_threshold = abs_diff >= beta
  7174. loss = np.empty_like(input)
  7175. loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta
  7176. loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta)
  7177. return loss
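# Hedged usage sketch: for inputs on both sides of beta, the reference above
# should track torch.nn.functional.smooth_l1_loss elementwise.
def _example_reference_smooth_l1_loss_usage():
    inp, tgt = np.array([0.0, 0.5, 3.0]), np.zeros(3)
    ref = reference_smooth_l1_loss(inp, tgt, reduction="none")
    out = torch.nn.functional.smooth_l1_loss(
        torch.from_numpy(inp), torch.from_numpy(tgt), reduction="none")
    assert np.allclose(ref, out.numpy())  # expected [0.0, 0.125, 2.5]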
  7178. def reference_std_var(f):
  7179. """Forwards unbiased/correction kwargs as NumPy's equivalent ddof"""
  7180. g = reference_reduction_numpy(f)
  7181. @wraps(g)
  7182. def wrapper(x: np.ndarray, *args, **kwargs):
  7183. assert not ('unbiased' in kwargs and 'correction' in kwargs)
  7184. if 'unbiased' in kwargs:
  7185. kwargs['ddof'] = int(kwargs.pop('unbiased'))
  7186. elif 'correction' in kwargs:
  7187. kwargs['ddof'] = kwargs.pop('correction')
  7188. return g(x, *args, **kwargs)
  7189. return wrapper
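# Hedged illustration of the mapping the wrapper above relies on: torch's
# unbiased/correction kwargs correspond to NumPy's ddof. The input is a
# hypothetical sample, not drawn from the op database.
def _example_std_var_ddof_mapping():
    t = torch.arange(6, dtype=torch.float64).reshape(2, 3)
    a = t.numpy()
    assert np.allclose(t.var(dim=1, unbiased=True).numpy(), a.var(axis=1, ddof=1))
    assert np.allclose(t.var(dim=1, correction=0).numpy(), a.var(axis=1, ddof=0))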
  7190. def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
  7191. """Generates unbiased/correction kwargs for std/var operators"""
  7192. yield ((), {'unbiased': True})
  7193. yield ((), {'unbiased': False})
  7194. # Currently, calling std with correction is only enabled when
  7195. # both dim and keepdim are provided.
  7196. if 'dim' in kwargs and 'keepdim' in kwargs:
  7197. yield ((), {'correction': 0})
  7198. yield ((), {'correction': 1})
  7199. numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
  7200. yield ((), {'correction': numel // 2})
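# Hedged usage sketch: with both dim and keepdim supplied, the generator above
# yields the two unbiased variants plus three correction values (0, 1, numel // 2).
def _example_generate_std_var_kwargs():
    variants = list(generate_std_var_kwargs(torch.ones(2, 3), dim=1, keepdim=False))
    assert len(variants) == 5
    assert variants[0] == ((), {'unbiased': True})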
  7201. def error_inputs_mean(op_info, device, is_ref=False, **kwargs):
  7202. if is_ref:
  7203. err_msg1 = (r"mean\(\): could not infer output dtype. "
  7204. r"Input dtype must be either a floating point or complex dtype. "
  7205. r"Got: torch.int64")
  7206. else:
  7207. err_msg1 = (r"mean\(\): could not infer output dtype. "
  7208. r"Input dtype must be either a floating point or complex dtype. "
  7209. r"Got: Long")
  7210. yield ErrorInput(
  7211. SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []),
  7212. error_regex=err_msg1,
  7213. )
  7214. if is_ref:
  7215. err_msg2 = (r"mean\(\): could not infer output dtype. "
  7216. r"Optional dtype must be either a floating point or complex dtype. "
  7217. r"Got: torch.int64")
  7218. else:
  7219. err_msg2 = (r"mean\(\): could not infer output dtype. "
  7220. r"Optional dtype must be either a floating point or complex dtype. "
  7221. r"Got: Long")
  7222. yield ErrorInput(
  7223. SampleInput(
  7224. make_tensor((3, 4, 5), dtype=torch.float32, device=device),
  7225. [],
  7226. dtype=torch.int64),
  7227. error_regex=err_msg2
  7228. )
  7229. if is_ref:
  7230. err_msg3 = "Expected out tensor to have dtype torch.float64, but got torch.float32 instead"
  7231. else:
  7232. err_msg3 = "Expected out tensor to have dtype double, but got float instead"
  7233. yield ErrorInput(
  7234. SampleInput(
  7235. make_tensor((3, 4, 5), dtype=torch.int64, device=device),
  7236. [],
  7237. dtype=torch.float64,
  7238. out=make_tensor([], dtype=torch.float32, device=device),
  7239. ),
  7240. error_regex=err_msg3
  7241. )
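# Hedged illustration of the first error case above (hypothetical helper, not
# wired into any OpInfo): mean() on an integer tensor raises because no
# floating point or complex output dtype can be inferred.
def _example_mean_dtype_error():
    try:
        torch.mean(torch.ones(3, dtype=torch.int64))
        raise AssertionError("expected mean() to reject an int64 input")
    except RuntimeError as err:
        assert "could not infer output dtype" in str(err)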
  7242. # numpy implementation of torch.flatten
  7243. # unfortunately there's no np.flatten. we figure out the desired shape and call np.reshape
  7244. def reference_flatten(input, start_dim=0, end_dim=-1):
  7245. in_shape = input.shape
  7246. in_rank = len(in_shape)
  7247. for d in start_dim, end_dim:
  7248. if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank):
7249. raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank-1}], but got {d})")
  7250. end_dim = end_dim if end_dim >= 0 else in_rank + end_dim
  7251. start_dim = start_dim if start_dim >= 0 else in_rank + start_dim
  7252. if in_rank == 0:
  7253. end_dim = start_dim
  7254. if end_dim < start_dim:
  7255. raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim")
  7256. flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1)
  7257. out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:]
  7258. return np.reshape(input, out_shape)
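# Hedged usage sketch: the reshape-based reference above should agree with
# torch.flatten on a small hypothetical input.
def _example_reference_flatten_usage():
    x = np.arange(24).reshape(2, 3, 4)
    assert reference_flatten(x, start_dim=1).shape == (2, 12)
    assert (reference_flatten(x, 1, 2) ==
            torch.flatten(torch.from_numpy(x), 1, 2).numpy()).all()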
  7259. # Operator database (sorted alphabetically)
  7260. op_db: List[OpInfo] = [
  7261. UnaryUfuncInfo('abs',
  7262. aliases=('absolute', ),
  7263. ref=np.abs,
  7264. dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
  7265. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  7266. skips=(
  7267. DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
  7268. 'test_inplace_grad', dtypes=(torch.cdouble,)),
  7269. DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
  7270. 'test_inplace_gradgrad', dtypes=(torch.cdouble,)),
  7271. DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients',
  7272. 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)),
  7273. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7274. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7275. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7276. device_type='cpu', dtypes=[torch.cfloat]),
  7277. DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs",
  7278. "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
  7279. # Reference: https://github.com/pytorch/pytorch/issues/49224
  7280. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7281. dtypes=[torch.int8], active_if=TEST_WITH_ASAN),
7282. # TODO: Fix test_out_arg_all_dtypes: using torch.empty_like(expected_output), where expected_output=op(input),
7283. # can break the logic of the loop over all possible types, but that is acceptable for now.
  7284. # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449
  7285. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes',
  7286. dtypes=[torch.cfloat, torch.cdouble]),
  7287. ),
  7288. supports_fwgrad_bwgrad=True,
  7289. assert_autodiffed=True,
  7290. supports_sparse=True,
  7291. supports_sparse_csr=True,
  7292. supports_sparse_csc=True,
  7293. supports_sparse_bsr=True,
  7294. supports_sparse_bsc=True,
  7295. supports_forward_ad=True),
  7296. # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
  7297. UnaryUfuncInfo('acos',
  7298. aliases=('arccos', ),
  7299. ref=np.arccos,
  7300. domain=(-1, 1),
  7301. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7302. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7303. assert_autodiffed=True,
  7304. supports_forward_ad=True,
  7305. supports_fwgrad_bwgrad=True,
  7306. decorators=(precisionOverride({torch.float16: 1e-2,
  7307. torch.bfloat16: 1e-1,
  7308. torch.complex64: 1e-2}),),
  7309. skips=(
  7310. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  7311. device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7312. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7313. device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7314. # Failing with wrong imaginary sign on at least some Windows jobs
  7315. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7316. device_type='cuda', dtypes=[torch.cdouble],
  7317. active_if=IS_WINDOWS),
  7318. # Failing with wrong imaginary sign on at least some Windows jobs
  7319. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7320. device_type='cuda', dtypes=[torch.cdouble],
  7321. active_if=IS_WINDOWS),
  7322. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7323. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7324. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7325. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7326. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad',
  7327. dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7328. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad',
  7329. dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7330. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad',
  7331. dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7332. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
  7333. dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7334. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD',
  7335. dtypes=[torch.cdouble], active_if=IS_WINDOWS),)),
  7336. # NOTE: the derivative for inplace acosh is not implemented
  7337. UnaryUfuncInfo('acosh',
  7338. aliases=('arccosh', ),
  7339. ref=np.arccosh,
  7340. domain=(1, None),
  7341. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7342. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7343. decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
  7344. supports_inplace_autograd=False,
  7345. supports_forward_ad=True,
  7346. supports_fwgrad_bwgrad=True,
  7347. skips=(
  7348. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  7349. device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7350. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7351. device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
  7352. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7353. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7354. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7355. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7356. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7357. device_type='cuda', dtypes=[torch.cdouble],
  7358. active_if=IS_WINDOWS),
  7359. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7360. device_type='cuda', dtypes=[torch.cdouble],
  7361. active_if=IS_WINDOWS),
  7362. # Failing with wrong imaginary sign on at least some Windows jobs
  7363. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7364. device_type='cuda', dtypes=[torch.cdouble],
  7365. active_if=IS_WINDOWS),
  7366. ),
  7367. # acosh is not defined at x < 1 (real)
  7368. reference_numerics_filter=NumericsFilter(
  7369. condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)),
  7370. safe_val=2)),
  7371. BinaryUfuncInfo('add',
  7372. # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
  7373. ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
  7374. else np.add(input, np.multiply(alpha, other)),
  7375. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
  7376. torch.float16, torch.chalf),
  7377. assert_autodiffed=True,
  7378. sample_inputs_func=sample_inputs_add_sub,
  7379. supports_fwgrad_bwgrad=True,
  7380. supports_forward_ad=True,
  7381. supports_two_python_scalars=True,
  7382. decorators=(
  7383. DecorateInfo(
  7384. toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
  7385. 'TestBinaryUfuncs', 'test_reference_numerics'),
  7386. ),
  7387. skips=(
  7388. # boolean alpha not handled properly
  7389. DecorateInfo(unittest.expectedFailure,
  7390. 'TestCudaFuserOpInfo',
  7391. 'test_nvfuser_correctness',
  7392. dtypes=(torch.bool,)),
  7393. # boolean alpha not handled properly
  7394. DecorateInfo(unittest.expectedFailure,
  7395. 'TestNNCOpInfo',
  7396. 'test_nnc_correctness',
  7397. dtypes=(torch.bool,)),
  7398. DecorateInfo(unittest.skip("Skipped!"),
  7399. 'TestCommon',
  7400. 'test_numpy_refs',
  7401. dtypes=(torch.complex128,)),
  7402. DecorateInfo(unittest.skip("Skipped!"),
  7403. 'TestBinaryUfuncs',
  7404. 'test_reference_numerics_extremal_values',
  7405. dtypes=(torch.complex64, torch.complex128)),
  7406. )),
  7407. OpInfo('arange',
  7408. dtypes=all_types_and(torch.bfloat16, torch.float16),
  7409. supports_out=True,
  7410. supports_autograd=False,
  7411. is_factory_function=True,
  7412. error_inputs_func=error_inputs_arange,
  7413. sample_inputs_func=sample_inputs_arange,
  7414. skips=(
  7415. # https://github.com/pytorch/pytorch/issues/81774
  7416. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  7417. # Tests that assume input is a tensor or sequence of tensors
  7418. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7419. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7420. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  7421. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  7422. # Lazy tensor failures
  7423. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
  7424. DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
  7425. DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
  7426. # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608
  7427. # We don't have an op for aten::arange but it isn't a special case.
7428. # Argument types: bool, bool, bool, int, int, Device, bool
  7429. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  7430. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  7431. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
  7432. # Captured graph does not contain aten::arange (succeeds on complex!)
  7433. # g: graph():
  7434. # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]()
  7435. # return (%25)
  7436. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  7437. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  7438. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  7439. )),
  7440. OpInfo('cauchy',
  7441. op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs),
  7442. inplace_variant=torch.Tensor.cauchy_,
  7443. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  7444. supports_out=False,
  7445. supports_autograd=False,
  7446. sample_inputs_func=sample_inputs_cauchy,
  7447. error_inputs_func=error_inputs_cauchy,
  7448. skips=(
  7449. # Tests that assume input tensor has a meaningful effect on output tensor
  7450. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7451. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7452. # AssertionError: JIT Test does not execute any logic
  7453. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  7454. # AssertionError: Tensor-likes are not close!
  7455. DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
  7456. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  7457. # FX failed to normalize op - add the op to the op_skip list.
  7458. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  7459. # vmap: calling random operator not supported
  7460. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  7461. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
  7462. DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
  7463. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  7464. )),
  7465. OpInfo('exponential',
  7466. op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs),
  7467. inplace_variant=torch.Tensor.exponential_,
  7468. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  7469. supports_out=False,
  7470. supports_autograd=False,
  7471. sample_inputs_func=sample_inputs_exponential,
  7472. error_inputs_func=error_inputs_exponential,
  7473. skips=(
  7474. # Tests that assume input tensor has a meaningful effect on output tensor
  7475. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7476. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7477. # AssertionError: JIT Test does not execute any logic
  7478. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  7479. # AssertionError: Tensor-likes are not close!
  7480. DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
  7481. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  7482. # FX failed to normalize op - add the op to the op_skip list.
  7483. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  7484. # vmap: calling random operator not supported
  7485. DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  7486. DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
  7487. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  7489. )),
  7490. OpInfo('geometric',
  7491. op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs),
  7492. inplace_variant=torch.Tensor.geometric_,
  7493. dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8),
  7494. supports_out=False,
  7495. supports_autograd=False,
  7496. sample_inputs_func=sample_inputs_geometric,
  7497. error_inputs_func=error_inputs_geometric,
  7498. skips=(
  7499. # Tests that assume input tensor has a meaningful effect on output tensor
  7500. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7501. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7502. # AssertionError: JIT Test does not execute any logic
  7503. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  7504. # AssertionError: Tensor-likes are not close!
  7505. DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
  7506. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  7507. # FX failed to normalize op - add the op to the op_skip list.
  7508. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  7509. # vmap: calling random operator not supported
  7510. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  7511. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
  7512. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  7513. )),
  7514. OpInfo('log_normal',
  7515. op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs),
  7516. inplace_variant=torch.Tensor.log_normal_,
  7517. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  7518. supports_out=False,
  7519. supports_autograd=False,
  7520. sample_inputs_func=sample_inputs_log_normal,
  7521. error_inputs_func=error_inputs_log_normal,
  7522. skips=(
  7523. # Tests that assume input tensor has a meaningful effect on output tensor
  7524. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7525. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7526. # AssertionError: JIT Test does not execute any logic
  7527. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  7528. # AssertionError: Tensor-likes are not close!
  7529. DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace'),
  7530. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  7531. # FX failed to normalize op - add the op to the op_skip list.
  7532. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  7533. # vmap: calling random operator not supported
  7534. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  7535. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
  7536. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  7537. )),
  7538. OpInfo('uniform',
  7539. op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs),
  7540. method_variant=None,
  7541. inplace_variant=torch.Tensor.uniform_,
  7542. dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
  7543. supports_out=False,
  7544. supports_autograd=False,
  7545. is_factory_function=False,
  7546. sample_inputs_func=sample_inputs_uniform,
  7547. error_inputs_func=error_inputs_uniform,
  7548. skips=(
  7549. # FX failed to normalize op - add the op to the op_skip list.
  7550. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
7551. # Tests that assume input tensor has a meaningful effect on output tensor
  7552. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7553. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  7554. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  7555. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  7556. # AssertionError: JIT Test does not execute any logic
  7557. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  7558. # aten.uniform was not decomposed
  7559. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  7560. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  7561. )),
  7562. BinaryUfuncInfo('clamp_max',
  7563. ref=_clamp_max_numpy,
  7564. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  7565. supports_forward_ad=True,
  7566. supports_rhs_python_scalar=False,
  7567. supports_fwgrad_bwgrad=True,
  7568. rhs_make_tensor_kwargs=dict(exclude_zero=False),
  7569. skips=(
  7570. # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
  7571. DecorateInfo(unittest.expectedFailure,
  7572. 'TestBinaryUfuncs',
  7573. 'test_type_promotion',
  7574. device_type='cuda'),
  7575. # dispatch to lazy test failed
  7576. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
  7577. # test error disabled since rhs non-tensor python scalar is supported
  7578. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
  7579. )),
  7580. BinaryUfuncInfo('clamp_min',
  7581. ref=_clamp_min_numpy,
  7582. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  7583. supports_forward_ad=True,
  7584. supports_rhs_python_scalar=False,
  7585. supports_fwgrad_bwgrad=True,
  7586. rhs_make_tensor_kwargs=dict(exclude_zero=False),
  7587. skips=(
  7588. # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
  7589. DecorateInfo(unittest.expectedFailure,
  7590. 'TestBinaryUfuncs',
  7591. 'test_type_promotion',
  7592. device_type='cuda'),
  7593. # dispatch to lazy test failed
  7594. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
  7595. # test error disabled since rhs non-tensor python scalar is supported
  7596. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
  7597. )),
  7598. BinaryUfuncInfo('mul',
  7599. aliases=('multiply',),
  7600. dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool),
  7601. assert_autodiffed=True,
  7602. supports_forward_ad=True,
  7603. supports_fwgrad_bwgrad=True,
  7604. supports_two_python_scalars=True),
  7605. BinaryUfuncInfo('sub',
  7606. # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
  7607. ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
  7608. aliases=('subtract',),
  7609. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf),
  7610. assert_autodiffed=True,
  7611. supports_forward_ad=True,
  7612. supports_fwgrad_bwgrad=True,
  7613. sample_inputs_func=sample_inputs_add_sub,
  7614. supports_two_python_scalars=True,
  7615. decorators=(
  7616. DecorateInfo(
  7617. toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
  7618. torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
  7619. torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
  7620. 'TestBinaryUfuncs', 'test_reference_numerics'),
  7621. DecorateInfo(
  7622. toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
  7623. 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
  7624. DecorateInfo(
  7625. toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
  7626. 'TestDecomp', 'test_comprehensive', device_type='cpu'),
  7627. DecorateInfo(
  7628. toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
  7629. 'TestDecomp', 'test_quick', device_type='cpu'),
  7630. ),
  7631. skips=(
  7632. DecorateInfo(unittest.skip("Skipped!"),
  7633. 'TestBinaryUfuncs',
  7634. 'test_reference_numerics',
  7635. dtypes=(torch.uint8,)),
  7636. DecorateInfo(unittest.skip("Skipped!"),
  7637. 'TestBinaryUfuncs',
  7638. 'test_reference_numerics_small_values',
  7639. dtypes=(torch.uint8,)),
  7640. )),
  7641. OpInfo('addmm',
  7642. # This addmm OpInfo is for when alpha and beta are not both equal to 1.
  7643. # alpha=beta=1 is tested in the following opinfo, because that special case will
  7644. # trigger addmm being decomposed by a jit pass.
  7645. dtypes=all_types_and_complex_and(torch.bfloat16),
  7646. dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7647. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7648. assert_autodiffed=True,
  7649. supports_forward_ad=True,
  7650. supports_fwgrad_bwgrad=True,
  7651. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  7652. sample_inputs_func=sample_inputs_addmm,
  7653. skips=(
  7654. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  7655. DecorateInfo(
  7656. unittest.skip("Skipped!"),
  7657. 'TestSchemaCheckModeOpInfo',
  7658. 'test_schema_correctness',
  7659. dtypes=(torch.complex64, torch.complex128)),
  7660. )),
  7661. OpInfo('addmm',
  7662. # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
  7663. variant_test_name='decomposed',
  7664. dtypes=all_types_and_complex_and(torch.bfloat16),
  7665. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7666. assert_autodiffed=True,
  7667. supports_forward_ad=True,
  7668. supports_fwgrad_bwgrad=True,
  7669. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  7670. autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
  7671. sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
  7672. skips=(
  7673. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  7674. DecorateInfo(
  7675. unittest.skip("Skipped!"),
  7676. 'TestSchemaCheckModeOpInfo',
  7677. 'test_schema_correctness',
  7678. dtypes=(torch.complex64, torch.complex128)),
  7679. # https://github.com/pytorch/pytorch/issues/71784
  7680. DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
  7681. device_type='cpu', dtypes=(torch.float16,)),
  7682. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.float16,)),
  7683. )),
  7684. OpInfo('addmv',
  7685. dtypes=all_types_and_complex_and(torch.bfloat16),
  7686. dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
  7687. torch.bfloat16),
  7688. supports_forward_ad=True,
  7689. supports_fwgrad_bwgrad=True,
  7690. sample_inputs_func=sample_inputs_addmv),
  7691. OpInfo('addbmm',
  7692. ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
  7693. np.multiply(np.asarray(alpha, dtype=batch1.dtype),
  7694. np.sum(np.matmul(batch1, batch2), axis=0))),
  7695. dtypes=all_types_and_complex_and(torch.bfloat16),
  7696. dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
  7697. *[torch.bfloat16]
  7698. if SM53OrLater or TEST_WITH_ROCM else []),
  7699. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  7700. gradcheck_fast_mode=True,
  7701. supports_forward_ad=True,
  7702. supports_fwgrad_bwgrad=True,
  7703. decorators=[
  7704. DecorateInfo(
  7705. toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05),
  7706. torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
  7707. 'TestCommon', 'test_numpy_refs'),
  7708. # MPS has slightly worse precision. Is this acceptable?
  7709. DecorateInfo(
  7710. toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04),
  7711. torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
  7712. 'TestCommon', 'test_numpy_ref_mps'),
  7713. DecorateInfo(
  7714. toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
  7715. 'TestConsistency',
  7716. 'test_output_match',
  7717. ),
  7718. DecorateInfo(
  7719. toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}),
  7720. 'TestCommon', 'test_out'),
  7721. ],
  7722. skips=(
  7723. # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
  7724. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
  7725. # addbmm does not correctly warn when resizing out= inputs
  7726. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  7727. # https://github.com/pytorch/pytorch/issues/55907
  7728. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
  7729. ),
  7730. sample_inputs_func=sample_inputs_addbmm),
  7731. OpInfo('baddbmm',
  7732. dtypes=all_types_and_complex_and(torch.bfloat16),
  7733. dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
  7734. torch.bfloat16),
  7735. backward_dtypesIfCUDA=floating_types_and(torch.float16,
  7736. *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [],
  7737. torch.complex64, torch.complex128),
  7738. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  7739. gradcheck_fast_mode=True,
  7740. supports_forward_ad=True,
  7741. supports_fwgrad_bwgrad=True,
  7742. decorators=[
  7743. DecorateInfo(
  7744. toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
  7745. 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
  7746. DecorateInfo(
  7747. toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
  7748. 'TestMathBits', 'test_conj_view', device_type='cuda')],
  7749. sample_inputs_func=sample_inputs_baddbmm,
  7750. skips=(
  7751. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  7752. DecorateInfo(
  7753. unittest.skip("Skipped!"),
  7754. 'TestSchemaCheckModeOpInfo',
  7755. 'test_schema_correctness',
  7756. dtypes=(torch.complex64, torch.complex128)),
  7757. )),
  7758. OpInfo('dot',
  7759. dtypes=all_types_and_complex_and(torch.bfloat16),
  7760. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7761. assert_autodiffed=True,
  7762. sample_inputs_func=sample_inputs_dot_vdot,
  7763. supports_forward_ad=True,
  7764. supports_fwgrad_bwgrad=True,
  7765. skips=(
  7766. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  7767. DecorateInfo(
  7768. unittest.skip("Skipped!"),
  7769. 'TestSchemaCheckModeOpInfo',
  7770. 'test_schema_correctness',
  7771. dtypes=(torch.complex64, torch.complex128)),
  7772. )),
  7773. OpInfo('vdot',
  7774. dtypes=all_types_and_complex_and(torch.bfloat16),
  7775. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7776. sample_inputs_func=sample_inputs_dot_vdot,
  7777. supports_forward_ad=True,
  7778. supports_fwgrad_bwgrad=True,
  7779. skips=(
  7780. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  7781. DecorateInfo(
  7782. unittest.skip("Skipped!"),
  7783. 'TestSchemaCheckModeOpInfo',
  7784. 'test_schema_correctness',
  7785. dtypes=(torch.complex64, torch.complex128)),
  7786. )),
  7787. OpInfo('bmm',
  7788. dtypes=all_types_and_complex_and(torch.bfloat16),
  7789. dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
  7790. *[torch.bfloat16]
  7791. if SM53OrLater or TEST_WITH_ROCM else []),
  7792. assert_autodiffed=True,
  7793. assert_jit_shape_analysis=True,
  7794. supports_forward_ad=True,
  7795. supports_fwgrad_bwgrad=True,
  7796. skips=(
  7797. # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
  7798. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
  7799. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
  7800. "TestCommon", "test_out")
  7801. ),
  7802. sample_inputs_func=sample_inputs_bmm),
  7803. OpInfo('mv',
  7804. dtypes=all_types_and_complex_and(torch.bfloat16),
  7805. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7806. assert_autodiffed=True,
  7807. supports_forward_ad=True,
  7808. supports_fwgrad_bwgrad=True,
  7809. sample_inputs_func=sample_inputs_mv),
  7810. OpInfo('addr',
  7811. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  7812. backward_dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7813. backward_dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  7814. # Reference: https://github.com/pytorch/pytorch/issues/50747
  7815. supports_forward_ad=True,
  7816. supports_fwgrad_bwgrad=True,
  7817. skips=(
  7818. # Reference: https://github.com/pytorch/pytorch/issues/50747
  7819. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
  7820. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
  7821. ),
  7822. sample_inputs_func=sample_inputs_addr,
  7823. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
  7824. OpInfo('addcmul',
  7825. dtypes=all_types_and_complex_and(torch.bfloat16),
  7826. dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
  7827. assert_autodiffed=True,
  7828. supports_forward_ad=True,
  7829. supports_fwgrad_bwgrad=True,
  7830. skips=(
  7831. # TODO: update sample inputs with for_inplace_variant kwarg to support this test
  7832. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  7833. ),
  7834. sample_inputs_func=sample_inputs_addcmul_addcdiv,
  7835. reference_inputs_func=partial(
  7836. reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
  7837. OpInfo('addcdiv',
  7838. dtypes=floating_and_complex_types_and(torch.bfloat16),
  7839. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7840. supports_forward_ad=True,
  7841. supports_fwgrad_bwgrad=True,
  7842. skips=(
  7843. # TODO: update sample inputs with for_inplace_variant kwarg to support this test
  7844. DecorateInfo(unittest.expectedFailure,
  7845. 'TestCommon',
  7846. 'test_variant_consistency_eager'),
  7847. ),
  7848. sample_inputs_func=sample_inputs_addcmul_addcdiv,
  7849. reference_inputs_func=partial(
  7850. reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
  7851. UnaryUfuncInfo('asin',
  7852. aliases=('arcsin', ),
  7853. ref=np.arcsin,
  7854. domain=(-1, 1),
  7855. supports_sparse=True,
  7856. supports_sparse_csr=True,
  7857. supports_sparse_csc=True,
  7858. supports_sparse_bsr=True,
  7859. supports_sparse_bsc=True,
  7860. supports_forward_ad=True,
  7861. supports_fwgrad_bwgrad=True,
  7862. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7863. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7864. assert_autodiffed=True,
  7865. decorators=[
  7866. DecorateInfo(
  7867. toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
  7868. 'TestUnaryUfuncs', device_type='cuda'),
  7869. precisionOverride({torch.bfloat16: 1e-2}),
  7870. ],
  7871. skips=(
  7872. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7873. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7874. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7875. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7876. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7877. device_type='cuda', dtypes=[torch.cdouble],
  7878. active_if=IS_WINDOWS),
  7879. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7880. device_type='cuda', dtypes=[torch.cdouble],
  7881. active_if=IS_WINDOWS),
  7882. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  7883. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  7884. )),
  7885. # NOTE: derivative for inplace asinh is not implemented
  7886. UnaryUfuncInfo('asinh',
  7887. aliases=('arcsinh', ),
  7888. ref=np.arcsinh,
  7889. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7890. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7891. decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
  7892. supports_inplace_autograd=False,
  7893. supports_forward_ad=True,
  7894. supports_fwgrad_bwgrad=True,
  7895. supports_sparse=True,
  7896. supports_sparse_csr=True,
  7897. supports_sparse_csc=True,
  7898. supports_sparse_bsr=True,
  7899. supports_sparse_bsc=True,
  7900. skips=(
  7901. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7902. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7903. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7904. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7905. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7906. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7907. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  7908. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7909. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7910. device_type='cuda', dtypes=[torch.cdouble],
  7911. active_if=IS_WINDOWS),
  7912. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7913. device_type='cuda', dtypes=[torch.cdouble],
  7914. active_if=IS_WINDOWS),
  7915. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  7916. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  7917. )),
  7918. UnaryUfuncInfo('atan',
  7919. aliases=('arctan', ),
  7920. ref=np.arctan,
  7921. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7922. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7923. assert_autodiffed=True,
  7924. supports_forward_ad=True,
  7925. supports_fwgrad_bwgrad=True,
  7926. supports_sparse=True,
  7927. supports_sparse_csr=True,
  7928. supports_sparse_csc=True,
  7929. supports_sparse_bsr=True,
  7930. supports_sparse_bsc=True,
  7931. decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
  7932. skips=(
  7933. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7934. active_if=TEST_WITH_ROCM, device_type='cuda', dtypes=[torch.complex64, torch.complex128]),
  7935. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7936. active_if=TEST_WITH_ROCM, device_type='cuda', dtypes=[torch.complex128]),
  7937. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7938. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7939. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7940. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7941. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7942. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7943. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7944. device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
  7945. active_if=IS_WINDOWS),
  7946. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7947. device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
  7948. active_if=IS_WINDOWS),
  7949. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  7950. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  7951. )),
  7952. BinaryUfuncInfo('atan2',
  7953. aliases=('arctan2',),
  7954. dtypes=all_types_and(torch.bool, torch.bfloat16),
  7955. dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
  7956. supports_forward_ad=True,
  7957. supports_fwgrad_bwgrad=True,
  7958. promotes_int_to_float=True,
  7959. supports_rhs_python_scalar=False,
  7960. skips=(
  7961. # Incorrectly attempts to use a scalar for the second argument
  7962. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
  7963. )),
  7964. UnaryUfuncInfo('atanh',
  7965. aliases=('arctanh', ),
  7966. ref=np.arctanh,
  7967. domain=(-1, 1),
  7968. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  7969. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  7970. decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
  7971. supports_inplace_autograd=False,
  7972. supports_forward_ad=True,
  7973. supports_fwgrad_bwgrad=True,
  7974. supports_sparse=True,
  7975. supports_sparse_csr=True,
  7976. supports_sparse_csc=True,
  7977. supports_sparse_bsr=True,
  7978. supports_sparse_bsc=True,
  7979. skips=(
  7980. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  7981. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7982. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7983. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7984. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7985. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  7986. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7987. device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
  7988. active_if=IS_WINDOWS),
  7989. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  7990. device_type='cuda', dtypes=[torch.cfloat],
  7991. active_if=IS_WINDOWS),
  7992. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  7993. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  7994. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  7995. active_if=TEST_WITH_ROCM, device_type='cuda', dtypes=[torch.complex128]),
  7996. )),
  7997. OpInfo('allclose',
  7998. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  7999. ref=np.allclose,
  8000. supports_autograd=False,
  8001. supports_forward_ad=False,
  8002. sample_inputs_func=sample_inputs_allclose,
  8003. skips=(
  8004. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  8005. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  8006. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
  8007. ),
  8008. supports_out=False),
  8009. OpInfo('broadcast_to',
  8010. ref=np.broadcast_to,
  8011. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  8012. supports_out=False,
  8013. supports_forward_ad=True,
  8014. supports_fwgrad_bwgrad=True,
  8015. # See https://github.com/pytorch/pytorch/pull/78358
  8016. check_batched_forward_grad=False,
  8017. sample_inputs_func=sample_inputs_broadcast_to),
  8018. OpInfo('broadcast_shapes',
  8019. op=torch.broadcast_shapes,
  8020. ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
  8021. dtypes=_dispatch_dtypes((torch.float32,)),
  8022. supports_out=False,
  8023. supports_gradgrad=False,
  8024. assert_autodiffed=False,
  8025. supports_autograd=False,
  8026. supports_scripting=False,
  8027. sample_inputs_func=sample_inputs_broadcast_shapes,
  8028. skips=(
  8029. # https://github.com/pytorch/pytorch/issues/64997
  8030. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  8031. # skip dtype tests since broadcast_shape is not device dependent.
  8032. # having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success
  8033. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
  8034. # skip these tests since we have non tensor input
  8035. DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
  8036. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  8037. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  8038. )),
  8039. OpInfo('broadcast_tensors',
  8040. ref=np.broadcast_arrays,
  8041. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  8042. sample_inputs_func=sample_inputs_broadcast_tensors,
  8043. reference_inputs_func=reference_inputs_broadcast_tensors,
  8044. supports_out=False,
  8045. supports_forward_ad=True,
  8046. supports_fwgrad_bwgrad=True,
  8047. # See https://github.com/pytorch/pytorch/pull/78358
  8048. check_batched_forward_grad=False,
  8049. skips=(
  8050. # https://github.com/pytorch/pytorch/issues/64997
  8051. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  8052. # JIT does not support variadic tensors.
  8053. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  8054. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
  8055. # please report a bug to PyTorch.
  8056. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
  8057. )),
  8058. OpInfo('block_diag',
  8059. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  8060. supports_out=False,
  8061. supports_forward_ad=True,
  8062. supports_fwgrad_bwgrad=True,
  8063. # Default batching rule in core doesn't work for ops with TensorList args
  8064. check_batched_forward_grad=False,
  8065. skips=(
  8066. # https://github.com/pytorch/pytorch/issues/64997
  8067. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  8068. # JIT does not support variadic tensors.
  8069. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  8070. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
  8071. # please report a bug to PyTorch.
  8072. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
  8073. ),
  8074. sample_inputs_func=sample_inputs_block_diag),
  8075. UnaryUfuncInfo('bitwise_not',
  8076. ref=np.bitwise_not,
  8077. dtypes=integral_types_and(torch.bool),
  8078. operator_variant=operator.invert,
  8079. supports_autograd=False),
  8080. BinaryUfuncInfo('bitwise_left_shift',
  8081. op=torch.bitwise_left_shift,
  8082. dtypes=integral_types(),
  8083. dtypesIfCUDA=integral_types(),
  8084. operator_variant=operator.lshift,
  8085. inplace_operator_variant=operator.ilshift,
  8086. supports_autograd=False,
  8087. supports_one_python_scalar=True,
  8088. rhs_make_tensor_kwargs=dict(low=0),
  8089. skips=(
  8090. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
  8091. # https://github.com/pytorch/pytorch/issues/70904
  8092. DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
  8093. )),
  8094. BinaryUfuncInfo('bitwise_right_shift',
  8095. op=torch.bitwise_right_shift,
  8096. dtypes=integral_types(),
  8097. dtypesIfCUDA=integral_types(),
  8098. operator_variant=operator.rshift,
  8099. inplace_operator_variant=operator.irshift,
  8100. supports_autograd=False,
  8101. supports_one_python_scalar=True,
  8102. rhs_make_tensor_kwargs=dict(low=0),
  8103. skips=(
  8104. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
  8105. # https://github.com/pytorch/pytorch/issues/70904
  8106. DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
  8107. )),
  8108. OpInfo('combinations',
  8109. op=torch.combinations,
  8110. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  8111. supports_forward_ad=True,
  8112. supports_fwgrad_bwgrad=True,
  8113. # See https://github.com/pytorch/pytorch/pull/78358
  8114. check_batched_forward_grad=False,
  8115. supports_out=False,
  8116. sample_inputs_func=sample_inputs_combinations),
  8117. OpInfo('cartesian_prod',
  8118. op=torch.cartesian_prod,
  8119. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  8120. supports_out=False,
  8121. supports_forward_ad=True,
  8122. supports_fwgrad_bwgrad=True,
  8123. # See https://github.com/pytorch/pytorch/pull/78358
  8124. check_batched_forward_grad=False,
  8125. sample_inputs_func=sample_inputs_cartesian_prod,
  8126. skips=(
  8127. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  8128. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  8129. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  8130. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  8131. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
  8132. DecorateInfo(unittest.expectedFailure,
  8133. 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  8134. )),
  8135. OpInfo('cdist',
  8136. dtypes=floating_types(),
  8137. supports_out=False,
  8138. supports_gradgrad=False,
  8139. assert_autodiffed=False,
  8140. sample_inputs_func=sample_inputs_cdist),
  8141. UnaryUfuncInfo('ceil',
  8142. ref=np.ceil,
  8143. dtypes=all_types_and(torch.bfloat16),
  8144. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  8145. supports_forward_ad=True,
  8146. supports_fwgrad_bwgrad=True,
  8147. skips=(
  8148. DecorateInfo(unittest.expectedFailure,
  8149. 'TestNNCOpInfo',
  8150. 'test_nnc_correctness',
  8151. dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
  8152. DecorateInfo(unittest.expectedFailure,
  8153. 'TestCudaFuserOpInfo',
  8154. 'test_nvfuser_correctness',
  8155. dtypes=(torch.int32, torch.int64),
  8156. active_if=not TEST_WITH_ROCM),
  8157. ),
  8158. supports_sparse=True,
  8159. supports_sparse_csr=True,
  8160. supports_sparse_csc=True,
  8161. supports_sparse_bsr=True,
  8162. supports_sparse_bsc=True,
  8163. assert_autodiffed=True),
  8164. OpInfo('cholesky',
  8165. dtypes=floating_and_complex_types(),
  8166. sample_inputs_func=sample_inputs_linalg_cholesky,
  8167. gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
  8168. decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
  8169. OpInfo('cholesky_inverse',
  8170. dtypes=floating_and_complex_types(),
  8171. backward_dtypes=floating_and_complex_types(),
  8172. # https://github.com/pytorch/pytorch/issues/80411
  8173. gradcheck_fast_mode=True,
  8174. supports_fwgrad_bwgrad=True,
  8175. supports_forward_ad=True,
  8176. check_batched_gradgrad=True,
  8177. sample_inputs_func=sample_inputs_linalg_cholesky_inverse,
  8178. gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal,
  8179. decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
  8180. skips=(
  8181. # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),)
  8182. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)),
  8183. OpInfo('cholesky_solve',
  8184. op=torch.cholesky_solve,
  8185. dtypes=floating_and_complex_types(),
  8186. sample_inputs_func=sample_inputs_cholesky_solve,
  8187. check_batched_gradgrad=False,
  8188. supports_forward_ad=True,
  8189. supports_fwgrad_bwgrad=True,
  8190. gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
  8191. decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
  8192. OpInfo('chunk',
  8193. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  8194. sample_inputs_func=sample_inputs_chunk,
  8195. reference_inputs_func=reference_inputs_chunk,
  8196. supports_forward_ad=True,
  8197. supports_fwgrad_bwgrad=True,
  8198. supports_out=False),
  8199. OpInfo('clone',
  8200. ref=np.copy,
  8201. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  8202. sample_inputs_func=sample_inputs_clone_contiguous,
  8203. reference_inputs_func=reference_inputs_clone_contiguous,
  8204. supports_forward_ad=True,
  8205. supports_fwgrad_bwgrad=True,
  8206. supports_out=False,
  8207. skips=(
  8208. # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format'
  8209. # (NumPy reference needs to be extended with memory_format)
  8210. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'),
  8211. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
  8212. ),),
    OpInfo('contiguous',
           op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
           sample_inputs_func=sample_inputs_clone_contiguous,
           reference_inputs_func=reference_inputs_clone_contiguous,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           autodiff_fusible_nodes=['aten::contiguous'],
           assert_jit_shape_analysis=True,
           supports_out=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
           )),
    OpInfo('sum_to_size',
           op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs),
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           sample_inputs_func=sample_inputs_sum_to_size,
           error_inputs_func=error_inputs_sum_to_size,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False,
           skips=(
               # lambda impl
               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)),
           )),
    OpInfo('clamp',
           aliases=('clip',),
           ref=_clamp_numpy,
           dtypes=all_types_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_clamp,
           reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp),
           assert_autodiffed=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
               # nvFuser and NNC appear to not handle boolean clamp
               DecorateInfo(unittest.expectedFailure,
                            'TestCudaFuserOpInfo',
                            'test_nvfuser_correctness',
                            dtypes=(torch.bool,)),
               DecorateInfo(unittest.expectedFailure,
                            'TestNNCOpInfo',
                            'test_nnc_correctness',
                            dtypes=(torch.bool,)),
           )),
    UnaryUfuncInfo('positive',
                   ref=np.positive,
                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
                   supports_out=False,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   ),
    UnaryUfuncInfo('conj',
                   ref=np.conj,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
                                                    torch.half, torch.chalf),
                   supports_sparse=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   # See https://github.com/pytorch/pytorch/pull/78358
                   check_batched_forward_grad=False,
                   supports_out=False),
    UnaryUfuncInfo('conj_physical',
                   decomp_aten_name='_conj_physical',
                   ref=np.conj,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
                                                    torch.half, torch.chalf),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   skips=(
                       # RuntimeError: inputSet && outputSet
                       # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118,
                       # please report a bug to PyTorch.
                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
                       DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"),
                                    'TestSparseUnaryUfuncs', 'test_inplace'),
                   )),
    OpInfo('resolve_conj',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_view_as_real,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False,
           ),
    OpInfo('resolve_neg',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
           sample_inputs_func=sample_inputs_view_as_real,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False,
           ),
    OpInfo('view_as_real',
           dtypes=complex_types(),
           supports_forward_ad=True,
           supports_out=False,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_view_as_real,
           test_conjugated_samples=False,
           ),
    OpInfo('view_as_complex',
           dtypes=floating_types_and(torch.half),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           test_neg_view=False,
           sample_inputs_func=sample_inputs_view_as_complex,
           skips=(
               # RuntimeError: Tensor must have a last dimension with stride 1
               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
               # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
               # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness', dtypes=(torch.half,)),
           )),
    BinaryUfuncInfo('complex',
                    dtypes=floating_types_and(torch.half),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_rhs_python_scalar=False,
                    error_inputs_func=error_inputs_complex,
                    skips=(
                        # Test doesn't account for complex's type promotion semantics
                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'),
                    )),
    BinaryUfuncInfo('copysign',
                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                    promotes_int_to_float=True,
                    # https://github.com/pytorch/pytorch/issues/80411
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True),
    OpInfo('corrcoef',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_corrcoef,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # See https://github.com/pytorch/pytorch/pull/78358
           check_batched_forward_grad=False,
           skips=(
               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
               DecorateInfo(
                   unittest.skip("Skipped!"),
                   'TestSchemaCheckModeOpInfo',
                   'test_schema_correctness',
                   dtypes=(torch.complex64, torch.complex128)),
           ),
           supports_out=False),
    UnaryUfuncInfo('cos',
                   ref=np.cos,
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                   assert_autodiffed=True,
                   handles_large_floats=False,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
                       # This fails on CUDA but passes on ROCm
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=(torch.cdouble,), device_type='cuda'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    device_type='cpu',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
                       # AssertionError: Tensor-likes are not close!
                       # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    device_type='cuda',
                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
                   )),
    UnaryUfuncInfo('cosh',
                   ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   skips=(
                       # Reference: https://github.com/pytorch/pytorch/issues/48641
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    device_type='cpu', dtypes=[torch.int8]),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=[torch.cdouble]),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    device_type='cpu',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    device_type='cpu',
                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
                       # AssertionError: Tensor-likes are not close!
                       # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
                       # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    device_type='cuda',
                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
                   )),
    OpInfo('cov',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           backward_dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_cov,
           error_inputs_func=error_inputs_cov,
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
               DecorateInfo(
                   unittest.skip("Skipped!"),
                   'TestSchemaCheckModeOpInfo',
                   'test_schema_correctness',
                   dtypes=(torch.complex64, torch.complex128)),
               # Float did not match double
               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
               # Jacobian mismatch
               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
               DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
               # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
               # RuntimeError:
               # undefined value tensor:
               # File "<string>", line 3
               # def the_method(i0):
               # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950
               # ~~~~~~ <--- HERE
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
           )),
    OpInfo('cross',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half),
           sample_inputs_func=sample_inputs_cross,
           supports_fwgrad_bwgrad=True,
           supports_out=True,
           supports_forward_ad=True),
    OpInfo('cumsum',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
               # cumsum does not correctly handle the dtype of the out= argument
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
           ),
           sample_inputs_func=sample_inputs_cumulative_ops),
    OpInfo('cumprod',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
               # cumprod does not correctly handle the dtype of the out= argument
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
           ),
           # gradgradcheck fails in fast_mode=True: #56275
           sample_inputs_func=sample_inputs_cumprod,
           gradcheck_fast_mode=False),
    OpInfo('cummax',
           dtypes=all_types_and(torch.bool, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
    OpInfo('cummin',
           dtypes=all_types_and(torch.bool, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(
           ),
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
    UnaryUfuncInfo('deg2rad',
                   ref=np.radians,
                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
                                                  torch.float16: 7e-1}),),
                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   skips=(
                       # Reference: https://github.com/pytorch/pytorch/pull/51283#issuecomment-770614273
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=[torch.bfloat16]),
                   )),
    OpInfo('diff',
           op=torch.diff,
           # np.diff has np._NoValue as default values for prepend and append, compare_with_reference breaks if prepend/append
           # are set as None when converting to numpy
           ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: (
               np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append)
           ),
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_diff,
           # See https://github.com/pytorch/pytorch/pull/78358
           check_batched_forward_grad=False,
           skips=(
           )),
    BinaryUfuncInfo('div',
                    aliases=('divide',),
                    variant_test_name='no_rounding_mode',
                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    promotes_int_to_float=True,
                    supports_fwgrad_bwgrad=True,
                    supports_two_python_scalars=True,
                    assert_autodiffed=True,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True),),
    BinaryUfuncInfo('div',
                    aliases=('divide',),
                    variant_test_name='trunc_rounding',
                    dtypes=all_types_and(torch.half, torch.bfloat16),
                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")),
                    # https://github.com/pytorch/pytorch/issues/80411
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    promotes_int_to_float=True,
                    supports_fwgrad_bwgrad=True,
                    supports_two_python_scalars=True,
                    assert_autodiffed=True,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
                    skips=(
                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
                    )),
    BinaryUfuncInfo('div',
                    aliases=('divide',),
                    variant_test_name='floor_rounding',
                    dtypes=all_types_and(torch.half, torch.bfloat16),
                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")),
                    # https://github.com/pytorch/pytorch/issues/80411
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    promotes_int_to_float=True,
                    supports_fwgrad_bwgrad=True,
                    supports_two_python_scalars=True,
                    assert_autodiffed=True,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
                    skips=(
                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
                    )),
    BinaryUfuncInfo('true_divide',
                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                    supports_forward_ad=True,
                    promotes_int_to_float=True,
                    supports_fwgrad_bwgrad=True,
                    supports_two_python_scalars=True,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
    OpInfo('equal',
           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
           ref=lambda input, other: (input == other).all(),
           sample_inputs_func=sample_inputs_equal,
           supports_autograd=False,
           supports_tracing=False,
           skips=(
           )),
    UnaryUfuncInfo('exp',
                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                   skips=(
                       # Reference: https://github.com/pytorch/pytorch/pull/50093#pullrequestreview-561791547
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=[torch.bfloat16]),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=[torch.bfloat16, torch.cdouble]),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                    dtypes=[torch.bfloat16]),
                       # Reference: https://github.com/pytorch/pytorch/issues/48010
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
                   ),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True),
    OpInfo('expand',
           op=lambda self, shape: self.expand(shape),
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_expand,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           assert_jit_shape_analysis=True,
           supports_out=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
           )),
    OpInfo('expand_as',
           op=lambda self, other: self.expand_as(other),
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_expand_as,
           supports_out=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),),
           ),
    OpInfo('diag',
           ref=np.diag,
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           check_batched_forward_grad=False,
           sample_inputs_func=sample_inputs_diag,
           error_inputs_func=error_inputs_diag),
    OpInfo('diag_embed',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
           supports_out=False,
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_diagonal_diag_embed,
           reference_inputs_func=reference_inputs_diagonal_diag_embed,
           error_inputs_func=error_inputs_diagonal_diag_embed),
    OpInfo('diagonal',
           # They are not strictly aliases as they have diverging defaults, but we can see them as aliases for testing purposes
           # If we add tests that test the function against the alias, make linalg.diagonal into its own OpInfo
           aliases=('linalg.diagonal',),
           aten_backward_name='diagonal_backward',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_diagonal_diag_embed,
           reference_inputs_func=reference_inputs_diagonal_diag_embed,
           error_inputs_func=error_inputs_diagonal_diag_embed),
    OpInfo('diagonal_copy',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_diagonal_diag_embed,
           reference_inputs_func=reference_inputs_diagonal_diag_embed,
           error_inputs_func=error_inputs_diagonal_diag_embed),
    OpInfo('diagonal_scatter',
           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_diagonal_scatter),
    BinaryUfuncInfo('eq',
                    ref=np.equal,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
                    always_returns_bool=True,
                    supports_autograd=False,
                    sample_inputs_func=sample_inputs_comparison_ops,
                    skips=(
                    )),
    BinaryUfuncInfo('fmax',
                    op=torch.fmax,
                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_rhs_python_scalar=False,
                    skips=(
                        # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
                    )),
    BinaryUfuncInfo('fmin',
                    op=torch.fmin,
                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_rhs_python_scalar=False,
                    skips=(
                        # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
                    )),
    BinaryUfuncInfo('fmod',
                    ref=np.fmod,
                    dtypes=all_types_and(torch.float16, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
                    # https://github.com/pytorch/pytorch/issues/80411
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    assert_autodiffed=None,
                    rhs_make_tensor_kwargs={'exclude_zero': True},
                    decorators=(
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_contig_vs_every_other',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_non_contig',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_reference_numerics',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_reference_numerics_small_values',
                                     dtypes=(torch.uint8,)),
                    )),
    BinaryUfuncInfo('remainder',
                    ref=np.remainder,
                    dtypes=all_types_and(torch.float16, torch.bfloat16),
                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
                    # https://github.com/pytorch/pytorch/issues/80411
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    assert_autodiffed=None,
                    operator_variant=operator.mod,
                    inplace_operator_variant=operator.imod,
                    supports_one_python_scalar=True,
                    rhs_make_tensor_kwargs={'exclude_zero': True},
                    decorators=(
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_contig_vs_every_other',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_non_contig',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_reference_numerics',
                                     dtypes=(torch.bfloat16,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                                     'test_reference_numerics_small_values',
                                     dtypes=(torch.uint8,)),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo',
                                     'test_nnc_correctness',
                                     dtypes=(torch.bfloat16,)),
                        # Fails on XLA
                        # False is not true : Tensors failed to compare as equal!
                        # Attempted to compare equality of tensors with different dtypes
                        DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
                    )),
    UnaryUfuncInfo('frac',
                   ref=lambda x: np.modf(x)[0],
                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
                   dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)),
                       # 76047
                       DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
                                    dtypes=(torch.bfloat16, torch.float32, torch.float64)),
                   )),
    OpInfo('stft',
           decorators=[
               skipCPUIfNoFFT,
               DecorateInfo(unittest.skip("Skipped! stft does not match the native function"),
                            'TestJit', 'test_variant_consistency_jit'),
           ],
           dtypes=floating_and_complex_types(),
           sample_inputs_func=sample_inputs_stft,
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           check_batched_forward_grad=False,
           check_batched_grad=False,
           check_batched_gradgrad=False,
           supports_out=False,
           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
           ),
    OpInfo('istft',
           dtypes=complex_types(),
           sample_inputs_func=sample_inputs_istft,
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           check_batched_forward_grad=False,
           check_batched_grad=False,
           check_batched_gradgrad=False,
           supports_out=False,
           decorators=(
               DecorateInfo(unittest.skip("Skipped! istft does not match the native function"),
                            'TestJit', 'test_variant_consistency_jit'),
           ),
           skips=(
               skipCPUIfNoFFT,
               # gradcheck fails on ROCm (gh-68429)
               # grad is computed improperly (probably for weights tensor)
               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
               # Pre-existing condition (calls .item); needs to be fixed
               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
           )),
    UnaryUfuncInfo('floor',
                   ref=np.floor,
                   dtypes=all_types_and(torch.bfloat16),
                   dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   skips=(
                       DecorateInfo(unittest.expectedFailure,
                                    'TestNNCOpInfo',
                                    'test_nnc_correctness',
                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
                       DecorateInfo(unittest.expectedFailure,
                                    'TestCudaFuserOpInfo',
                                    'test_nvfuser_correctness',
                                    dtypes=(torch.int32, torch.int64),
                                    active_if=not TEST_WITH_ROCM),
                   ),
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   assert_autodiffed=True),
    OpInfo('flip',
           op=torch.flip,
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_flip,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False),
    OpInfo('fliplr',
           op=torch.fliplr,
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_fliplr_flipud,
           error_inputs_func=error_inputs_fliplr,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False),
    OpInfo('flipud',
           op=torch.flipud,
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_fliplr_flipud,
           error_inputs_func=error_inputs_flipud,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           supports_out=False),
    OpInfo('sparse.sampled_addmm',
           dtypes=floating_and_complex_types(),
           supports_autograd=True,
           sample_inputs_func=sample_inputs_sparse_sampled_addmm,
           decorators=[
               skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
                               or (_get_torch_rocm_version() >= (5, 2))),
                          "cusparseSDDMM was added in 11.2.1"),
               skipCPUIfNoMklSparse, ],
           skips=(
               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
               # RuntimeError: Sparse CSR tensors do not have strides.
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
               # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # RuntimeError: unsupported memory format option Preserve
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
           )),
    OpInfo('sparse.mm',
           dtypes=floating_types_and(torch.bfloat16),
           variant_test_name='reduce',
           supports_autograd=True,
           supports_out=False,
           supports_gradgrad=False,
           supports_forward_ad=False,
           sample_inputs_func=sample_inputs_sparse_mm_reduce,
           decorators=[onlyCPU],
           skips=(
               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
               # RuntimeError: Sparse CSR tensors do not have strides.
               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # RuntimeError: unsupported memory format option Preserve
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
           )),
    UnaryUfuncInfo('i0',
                   ref=np_unary_ufunc_integer_promotion_wrapper(
                       scipy.special.i0) if TEST_SCIPY else None,
                   aliases=('special.i0',),
                   decorators=(precisionOverride({torch.bfloat16: 3e-1,
                                                  torch.float16: 5e-1}),),
                   dtypes=all_types_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                   backward_dtypes=floating_types(),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   sample_inputs_func=sample_inputs_i0_i1,
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=(torch.int8,)),
                   )),
    BinaryUfuncInfo('floor_divide',
                    ref=_floor_divide_np,
                    dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_autograd=False,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
                    supports_two_python_scalars=True,
                    skips=(
                        # AssertionError: Results of original model and exported/imported version of model differed
                        DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
                        # bfloat16 floor_divide compared with a float32 reference works inconsistently
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
                                     dtypes=(torch.bfloat16,)),
                        # int8 floor divide has different results for -128 // -1 vs. NumPy
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
                                     dtypes=(torch.int8,)),
                        # The following tests fail on some jobs
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
                                     dtypes=(torch.float16,)),
                        DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
                                     'TestBinaryUfuncs', 'test_reference_numerics'),
                    )),
    UnaryUfuncInfo('frexp',
                   op=torch.frexp,
                   ref=np.frexp,
                   dtypes=floating_types_and(torch.half, torch.bfloat16),
                   dtypesIfCUDA=floating_types_and(torch.half),
                   # skip testing torch.frexp as it is not supported by ROCm platform yet
                   decorators=[],
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   skips=(
                       # Skips the tests below because torch.frexp returns a tuple-like (mantissa, exponent) output,
                       # while these tests currently require the output to be a single tensor.
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
                       # Skips test_reference_numerics due to an error in Windows CI:
                       # np.frexp returns the exponent with np.intc dtype on the Windows platform,
                       # and np.intc does not have a corresponding torch dtype.
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                    active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    active_if=IS_WINDOWS),
                   )),
    UnaryUfuncInfo('log1p',
                   ref=np.log1p,
                   aliases=('special.log1p',),
                   domain=(-1, None),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   assert_autodiffed=True),
    BinaryUfuncInfo('ge',
                    ref=np.greater_equal,
                    aliases=('greater_equal',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    OpInfo('geqrf',
           dtypes=floating_and_complex_types(),
           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
           supports_autograd=False,
           skips=(
               # FIXME: geqrf can't forward with complex inputs that require grad
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
               # Strides are not the same!
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
           )),
    BinaryUfuncInfo('gt',
                    ref=np.greater,
                    aliases=('greater',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    UnaryUfuncInfo('imag',
                   ref=np.imag,
                   dtypes=complex_types_and(torch.chalf),
                   supports_out=False,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   # See https://github.com/pytorch/pytorch/issues/66357
                   # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.
                   check_batched_forward_grad=False,
                   skips=(
                       # Skip since real and imag don't have out variants.
                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
                   )),
    OpInfo('gradient',
           dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
                                                 torch.int32, torch.int64,
                                                 torch.bfloat16, torch.half),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # See https://github.com/pytorch/pytorch/pull/78358
           check_batched_forward_grad=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
               # following tests give a runtime error with undefined value tensor
               # see discussion : https://github.com/pytorch/pytorch/issues/56660
               # RuntimeError:
               # Arguments for call are not valid.
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)),  # noqa: B950
               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
           ),
           supports_inplace_autograd=False,
           sample_inputs_func=sample_inputs_gradient,
           error_inputs_func=error_inputs_gradient),
    OpInfo('isin',
           dtypes=all_types(),
           dtypesIfCUDA=all_types_and(torch.half),
           supports_autograd=False,
           sample_inputs_func=sample_inputs_isin),
    OpInfo('kthvalue',
           dtypes=all_types_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.float16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_kthvalue,
           error_inputs_func=error_inputs_kthvalue),
    BinaryUfuncInfo('le',
                    ref=np.less_equal,
                    aliases=('less_equal',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    OpInfo('linspace',
           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
           is_factory_function=True,
           supports_out=True,
           supports_autograd=False,
           error_inputs_func=error_inputs_linspace,
           sample_inputs_func=sample_inputs_linspace,
           skips=(
               # Tests that assume input is a tensor or sequence of tensors
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
               # Same failure as arange: cannot find linspace in captured graph
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
                            dtypes=(torch.cfloat,), device_type="cuda"),
           )),
    OpInfo('logspace',
           dtypes=all_types_and_complex_and(torch.bfloat16),
           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
           is_factory_function=True,
           supports_out=True,
           supports_autograd=False,
           error_inputs_func=error_inputs_linspace,
           sample_inputs_func=sample_inputs_logpace,
           skips=(
               # Tests that assume input is a tensor or sequence of tensors
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
               # Same failure as arange: cannot find linspace in captured graph
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
               # Off-by-one issue when casting floats to ints
               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
                            dtypes=(torch.cfloat,), device_type="cuda"),
           )),
    UnaryUfuncInfo('log',
                   ref=np.log,
                   domain=(0, None),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
                                    active_if=IS_WINDOWS),
                   ),
                   # log(z)->-inf for |z|->0
                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
    UnaryUfuncInfo('log10',
                   ref=np.log10,
                   domain=(0, None),
                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   assert_autodiffed=True,
                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
                                    active_if=IS_WINDOWS),
                   ),
                   # log10(z)->-inf for |z|->0
                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
    UnaryUfuncInfo('log2',
                   ref=np.log2,
                   domain=(0, None),
                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                   assert_autodiffed=True,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=[torch.cfloat, torch.cdouble]),
                   ),
                   # log2(z)->-inf for |z|->0
                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
    BinaryUfuncInfo('ldexp',
                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
                    gradcheck_fast_mode=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_inplace_autograd=False,
                    promotes_int_to_float=True,
                    supports_out=True,
                    supports_rhs_python_scalar=False,
                    skips=(
                        # RuntimeError: mul(): functions with out=... arguments don't support
                        # automatic differentiation, but one of the arguments requires grad
                        # https://github.com/pytorch/pytorch/issues/68966
                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
                    ),
                    decorators=[
                        DecorateInfo(
                            toleranceOverride({
                                torch.complex64: tol(atol=1e-05, rtol=1e-05)
                            }),
                            'TestCommon', device_type='cpu',
                        ),
                    ], ),
    BinaryUfuncInfo('logaddexp',
                    dtypes=floating_types_and(torch.bfloat16),
                    dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
                    dtypesIfROCM=floating_types_and(torch.bfloat16, torch.float16),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_rhs_python_scalar=False,
                    skips=(
                        # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat'
                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
                    )),
    OpInfo('logaddexp2',
           dtypes=floating_types_and(torch.bfloat16),
           dtypesIfCUDA=floating_types_and(torch.bfloat16),
           dtypesIfROCM=floating_types_and(torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_logaddexp),
    UnaryUfuncInfo('logical_not',
                   ref=np.logical_not,
                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
                                                  torch.float16: 5e-1}),),
                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                   supports_autograd=False,
                   skips=(
                       # The function variant always returns BoolTensor
                       # while the inplace variant preserves the input dtype.
                       # >>> t = torch.randn(3)
                       # >>> torch.logical_not(t)
                       # tensor([False, False, False])
                       # >>> torch.logical_not(t).dtype
                       # torch.bool
                       # >>> t.logical_not_().dtype
                       # torch.float32
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency',
                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
                   )),
    BinaryUfuncInfo('lt',
                    ref=np.less,
                    aliases=('less',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    OpInfo('lu_unpack',
           op=torch.lu_unpack,
           dtypes=floating_and_complex_types(),
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           skips=(skipCPUIfNoLapack,),
           sample_inputs_func=sample_inputs_lu_unpack),
    OpInfo('lu',
           op=torch.lu,
           dtypes=floating_and_complex_types(),
           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
           gradcheck_fast_mode=True,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # https://github.com/pytorch/pytorch/issues/66357
           check_batched_forward_grad=False,
           sample_inputs_func=sample_inputs_lu,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
           skips=(
               # we skip jit tests because `lu` is a torch function
               # RuntimeError:
               # 'Tensor (inferred)' object has no attribute or method 'lu'.:
               # File "<string>", line 3
               # def the_method(i0):
               # return i0.lu(True, True)
               # ~~~~~ <--- HERE
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
           )),
    OpInfo('lu_solve',
           op=torch.lu_solve,
           dtypes=floating_and_complex_types(),
           supports_forward_ad=True,
           # See https://github.com/pytorch/pytorch/issues/66357
           check_batched_forward_grad=False,
           supports_fwgrad_bwgrad=True,
           sample_inputs_func=sample_inputs_lu_solve,
           skips=(
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
                            device_type='mps', dtypes=[torch.float32]),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
                            device_type='mps', dtypes=[torch.float32]),
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
                            device_type='mps', dtypes=[torch.float32]),
               DecorateInfo(unittest.skip("Tests different backward paths"),
                            "TestCommon", "test_floating_inputs_are_differentiable"),),
           decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]),
  9358. OpInfo('masked_fill',
  9359. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  9360. sample_inputs_func=sample_inputs_masked_fill,
  9361. error_inputs_func=error_inputs_masked_fill,
  9362. supports_forward_ad=True,
  9363. supports_fwgrad_bwgrad=True,
  9364. check_batched_forward_grad=False,
  9365. supports_out=False),
  9366. OpInfo('masked_scatter',
  9367. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  9368. sample_inputs_func=sample_inputs_masked_scatter,
  9369. supports_forward_ad=True,
  9370. supports_fwgrad_bwgrad=True,
  9371. # https://github.com/pytorch/pytorch/issues/66357
  9372. check_batched_forward_grad=False,
  9373. supports_out=False,
  9374. skips=(
  9375. )),
  9376. OpInfo('masked_select',
  9377. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  9378. supports_forward_ad=True,
  9379. supports_fwgrad_bwgrad=True,
  9380. sample_inputs_func=sample_inputs_masked_select,
  9381. error_inputs_func=error_inputs_masked_select,
  9382. skips=(
9383. # Compiler issue on ROCm. Might need to skip until ROCm 5.5
  9384. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
  9385. dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
  9386. )),
  9387. OpInfo('matrix_exp',
  9388. dtypes=floating_and_complex_types_and(torch.bfloat16),
  9389. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  9390. aliases=('linalg.matrix_exp',),
  9391. sample_inputs_func=sample_inputs_matrix_exp,
9392. # Needs to construct a 2n x 2n matrix by copy_-ing into it
  9393. check_batched_grad=False,
  9394. check_batched_gradgrad=False,
  9395. supports_forward_ad=True,
  9396. supports_fwgrad_bwgrad=True,
  9397. # https://github.com/pytorch/pytorch/issues/66357
  9398. check_batched_forward_grad=False,
  9399. skips=(
  9400. # times out
  9401. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  9402. ),
  9403. supports_out=False,
  9404. ),
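# Quick sanity example for matrix_exp: the exponential of the zero matrix is the identity.
# >>> torch.matrix_exp(torch.zeros(2, 2))
# tensor([[1., 0.],
#         [0., 1.]])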
  9405. OpInfo('matmul',
  9406. aliases=('linalg.matmul',),
  9407. dtypes=all_types_and_complex_and(torch.bfloat16),
  9408. dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
  9409. *[torch.bfloat16]
  9410. if SM53OrLater or TEST_WITH_ROCM else []),
  9411. assert_autodiffed=True,
  9412. assert_jit_shape_analysis=True,
  9413. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  9414. gradcheck_fast_mode=True,
  9415. supports_forward_ad=True,
  9416. supports_fwgrad_bwgrad=True,
  9417. check_batched_forward_grad=False,
  9418. sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False),
  9419. decorators=[
9420. # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
  9421. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
  9422. # ROCm intermittently fails the test with standard atol/rtol
  9423. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
  9424. 'TestCommon', 'test_noncontiguous_samples', device_type='cuda',
  9425. active_if=TEST_WITH_ROCM),
  9426. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
  9427. 'TestCommon', 'test_out', device_type='cuda',
  9428. active_if=TEST_WITH_ROCM),
  9429. # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the
  9430. # backward on CPU
  9431. DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}),
  9432. 'TestCommon', 'test_noncontiguous_samples',
  9433. device_type='cpu'),
  9434. DecorateInfo(
  9435. toleranceOverride({
  9436. torch.float32: tol(atol=1e-5, rtol=1e-5),
  9437. torch.complex64: tol(atol=1e-5, rtol=1e-5),
  9438. }),
  9439. "TestDecomp", "test_comprehensive", device_type="cuda",
  9440. ),
  9441. ],
  9442. skips=(
  9443. # Strides are not the same!
  9444. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  9445. # https://github.com/pytorch/pytorch/issues/67470
  9446. DecorateInfo(unittest.skip("67470!"),
  9447. 'TestCommon', 'test_noncontiguous_samples',
  9448. device_type='cpu', dtypes=(torch.long,)),
  9449. # AssertionError: False is not true : Tensors failed to compare as equal!
  9450. DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo',
  9451. device_type='xla', dtypes=(torch.long,)),
  9452. # https://github.com/pytorch/pytorch/issues/71774
  9453. DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
  9454. device_type='cpu', dtypes=(torch.long,)),
  9455. )),
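# The matmul samples exercise batched/broadcasted shapes, e.g.:
# >>> torch.matmul(torch.randn(2, 3, 4), torch.randn(4, 5)).shape
# torch.Size([2, 3, 5])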
  9456. OpInfo('max',
  9457. variant_test_name='reduction_with_dim',
  9458. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9459. sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
  9460. supports_fwgrad_bwgrad=True,
  9461. skips=(
  9462. ),
  9463. supports_forward_ad=True),
  9464. OpInfo('max',
  9465. variant_test_name='reduction_no_dim',
  9466. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9467. supports_out=True,
  9468. supports_forward_ad=True,
  9469. supports_fwgrad_bwgrad=True,
  9470. sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
  9471. skips=(
  9472. )),
  9473. OpInfo('median',
  9474. dtypes=all_types_and(torch.bfloat16),
  9475. dtypesIfCUDA=all_types_and(torch.float16),
  9476. # TODO: some signatures of median do support out
  9477. supports_out=False,
  9478. supports_forward_ad=True,
  9479. supports_fwgrad_bwgrad=True,
  9480. error_inputs_func=error_inputs_median,
  9481. sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
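# For reference, the dim overload of median returns a (values, indices) namedtuple:
# >>> torch.median(torch.tensor([[1., 3., 2.]]), dim=1)
# torch.return_types.median(
# values=tensor([2.]),
# indices=tensor([2]))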
  9482. OpInfo('nanmedian',
  9483. dtypes=all_types_and(torch.bfloat16),
  9484. dtypesIfCUDA=all_types_and(torch.float16),
  9485. # TODO: some signatures of nanmedian do support out
  9486. supports_out=False,
  9487. supports_forward_ad=True,
  9488. supports_fwgrad_bwgrad=True,
  9489. sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
  9490. OpInfo('var_mean',
  9491. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  9492. sample_inputs_func=sample_inputs_std_var,
  9493. # TODO: some signatures of var_mean do support out
  9494. supports_out=False,
  9495. supports_forward_ad=True,
  9496. check_batched_forward_grad=False,
  9497. supports_fwgrad_bwgrad=True,
  9498. decorators=(
  9499. DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
  9500. "TestDecomp", "test_comprehensive", device_type="cuda"),
  9501. )),
  9502. OpInfo('var_mean',
  9503. variant_test_name='unbiased',
  9504. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  9505. sample_inputs_func=sample_inputs_std_var_unbiased,
  9506. # TODO: some signatures of var_mean do support out
  9507. supports_out=False,
  9508. supports_forward_ad=True,
  9509. check_batched_forward_grad=False,
  9510. supports_fwgrad_bwgrad=True,
  9511. decorators=(
  9512. DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
  9513. "TestDecomp", "test_comprehensive", device_type="cuda"),
  9514. )),
  9515. OpInfo('std_mean',
  9516. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  9517. sample_inputs_func=sample_inputs_std_var,
  9518. # TODO: some signatures of std_mean do support out
  9519. supports_out=False,
  9520. supports_forward_ad=True,
  9521. check_batched_forward_grad=False,
  9522. supports_fwgrad_bwgrad=True,
  9523. decorators=(
  9524. DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
  9525. "TestDecomp", "test_comprehensive", device_type="cuda"),
  9526. )),
  9527. OpInfo('std_mean',
  9528. variant_test_name='unbiased',
  9529. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  9530. sample_inputs_func=sample_inputs_std_var_unbiased,
9531. # TODO: some signatures of std_mean do support out
  9532. supports_out=False,
  9533. supports_forward_ad=True,
  9534. check_batched_forward_grad=False,
  9535. supports_fwgrad_bwgrad=True,
  9536. decorators=(
  9537. DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
  9538. "TestDecomp", "test_comprehensive", device_type="cuda"),
  9539. )),
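# var_mean/std_mean return the statistic together with the mean as a pair, e.g.:
# >>> std, mean = torch.std_mean(torch.tensor([1., 2., 3., 4.]))
# >>> std, mean
# (tensor(1.2910), tensor(2.5000))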
  9540. OpInfo('meshgrid',
  9541. variant_test_name='variadic_tensors',
  9542. ref=np.meshgrid,
  9543. dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
  9544. sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
  9545. skips=[
  9546. # JIT does not support variadic tensors.
  9547. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  9548. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
  9549. # please report a bug to PyTorch.
  9550. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  9551. # meshgrid is defined in torch.functional to take a
  9552. # variadic list of tensors. Variadic parameters are not
  9553. # compatible with the normalize operator tests.
  9554. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  9555. # Skip operator schema test because this is a functional and not an operator
  9556. DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  9557. ],
  9558. supports_out=False,
  9559. supports_fwgrad_bwgrad=True,
  9560. supports_forward_ad=True,
  9561. # See https://github.com/pytorch/pytorch/pull/78358
  9562. check_batched_forward_grad=False,),
  9563. OpInfo('meshgrid',
  9564. variant_test_name='list_of_tensors',
9565. # Unlike the variant above, we do not use np.meshgrid as a
9566. # ref since it does not officially support a list of numpy
9567. # arrays.
  9568. dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
  9569. sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'),
  9570. skips=[
  9571. # meshgrid is defined in torch.functional to take a
  9572. # variadic list of tensors. Variadic parameters are not
  9573. # compatible with the normalize operator tests.
  9574. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  9575. ],
  9576. assert_autodiffed=True,
  9577. supports_out=False,
  9578. autodiff_nonfusible_nodes=[],
  9579. supports_fwgrad_bwgrad=True,
  9580. supports_forward_ad=True,
  9581. # See https://github.com/pytorch/pytorch/pull/78358
  9582. check_batched_forward_grad=False,),
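# Tiny sketch of the meshgrid variants above (both accept an `indexing` kwarg):
# >>> x, y = torch.meshgrid(torch.arange(2), torch.arange(3), indexing='ij')
# >>> x.shape, y.shape
# (torch.Size([2, 3]), torch.Size([2, 3]))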
  9583. OpInfo('min',
  9584. variant_test_name='reduction_with_dim',
  9585. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9586. sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
  9587. supports_fwgrad_bwgrad=True,
  9588. supports_forward_ad=True,
  9589. skips=(
  9590. )),
  9591. OpInfo('min',
  9592. variant_test_name='reduction_no_dim',
  9593. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9594. supports_out=False,
  9595. supports_forward_ad=True,
  9596. supports_fwgrad_bwgrad=True,
  9597. sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
  9598. skips=(
  9599. )),
  9600. OpInfo('quantile',
  9601. dtypes=floating_types(),
  9602. sample_inputs_func=sample_inputs_reduction_quantile,
  9603. supports_forward_ad=True,
  9604. supports_fwgrad_bwgrad=True,
  9605. skips=(
  9606. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  9607. ),
  9608. # See https://github.com/pytorch/pytorch/issues/66357
  9609. # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
  9610. # does not have a batching rule in core
  9611. check_batched_forward_grad=False),
  9612. OpInfo('nanquantile',
  9613. dtypes=floating_types(),
  9614. sample_inputs_func=sample_inputs_reduction_quantile,
  9615. supports_forward_ad=True,
  9616. supports_fwgrad_bwgrad=True,
  9617. skips=(
  9618. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  9619. ),
  9620. # See https://github.com/pytorch/pytorch/issues/66357
  9621. # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
  9622. # does not have a batching rule in core
  9623. check_batched_forward_grad=False),
  9624. BinaryUfuncInfo(
  9625. 'max',
  9626. aliases=('maximum',),
  9627. variant_test_name='binary',
  9628. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9629. supports_forward_ad=True,
  9630. supports_fwgrad_bwgrad=True,
  9631. assert_autodiffed=True,
  9632. ref=np.maximum,
  9633. supports_rhs_python_scalar=False,
  9634. skips=(
  9635. # Incorrectly attempts to use a scalar for the second argument
  9636. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
  9637. # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
  9638. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
  9639. )),
  9640. BinaryUfuncInfo(
  9641. 'maximum',
  9642. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9643. supports_forward_ad=True,
  9644. supports_fwgrad_bwgrad=True,
  9645. ref=np.maximum,
  9646. supports_rhs_python_scalar=False,
  9647. skips=(
  9648. # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
  9649. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
  9650. )),
  9651. BinaryUfuncInfo(
  9652. 'min',
  9653. aliases=('minimum',),
  9654. variant_test_name='binary',
  9655. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9656. supports_forward_ad=True,
  9657. supports_fwgrad_bwgrad=True,
  9658. assert_autodiffed=True,
  9659. ref=np.minimum,
  9660. supports_rhs_python_scalar=False,
  9661. skips=(
  9662. # Incorrectly attempts to use a scalar for the second argument
  9663. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
  9664. # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
  9665. DecorateInfo(unittest.expectedFailure,
  9666. 'TestBinaryUfuncs',
  9667. 'test_type_promotion',
  9668. device_type='cuda'),
  9669. )),
  9670. BinaryUfuncInfo(
  9671. 'minimum',
  9672. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  9673. supports_forward_ad=True,
  9674. supports_fwgrad_bwgrad=True,
  9675. ref=np.minimum,
  9676. supports_rhs_python_scalar=False,
  9677. skips=(
  9678. # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
  9679. DecorateInfo(unittest.expectedFailure,
  9680. 'TestBinaryUfuncs',
  9681. 'test_type_promotion',
  9682. device_type='cuda'),
  9683. ),
  9684. ),
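# Elementwise reference behavior for the binary max/min entries above
# (matches np.maximum/np.minimum):
# >>> torch.maximum(torch.tensor([1., 4.]), torch.tensor([3., 2.]))
# tensor([3., 4.])
# >>> torch.minimum(torch.tensor([1., 4.]), torch.tensor([3., 2.]))
# tensor([1., 2.])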
  9685. BinaryUfuncInfo('logical_and',
  9686. ref=np.logical_and,
  9687. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  9688. supports_autograd=False,
  9689. always_returns_bool=True,
  9690. supports_rhs_python_scalar=False),
  9691. BinaryUfuncInfo('logical_or',
  9692. ref=np.logical_or,
  9693. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  9694. supports_autograd=False,
  9695. always_returns_bool=True,
  9696. supports_rhs_python_scalar=False),
  9697. BinaryUfuncInfo('logical_xor',
  9698. ref=np.logical_xor,
  9699. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  9700. supports_autograd=False,
  9701. always_returns_bool=True,
  9702. supports_rhs_python_scalar=False,
  9703. skips=(
  9704. )),
  9705. BinaryUfuncInfo('bitwise_and',
  9706. ref=np.bitwise_and,
  9707. dtypes=integral_types_and(torch.bool),
  9708. operator_variant=operator.and_,
  9709. inplace_operator_variant=operator.iand,
  9710. supports_autograd=False,
  9711. supports_one_python_scalar=True,
  9712. skips=(
  9713. # RuntimeError: "bitwise_and_cuda" not implemented for 'Half'
  9714. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
  9715. 'test_type_promotion', device_type='cuda'),
  9716. )),
  9717. BinaryUfuncInfo('bitwise_or',
  9718. ref=np.bitwise_or,
  9719. dtypes=integral_types_and(torch.bool),
  9720. operator_variant=operator.or_,
  9721. inplace_operator_variant=operator.ior,
  9722. supports_autograd=False,
  9723. supports_one_python_scalar=True,
  9724. skips=(
  9725. # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half'
  9726. DecorateInfo(unittest.expectedFailure,
  9727. 'TestBinaryUfuncs',
  9728. 'test_type_promotion',
  9729. device_type='cuda'),
  9730. )),
  9731. BinaryUfuncInfo('bitwise_xor',
  9732. ref=np.bitwise_xor,
  9733. dtypes=integral_types_and(torch.bool),
  9734. operator_variant=operator.xor,
  9735. inplace_operator_variant=operator.ixor,
  9736. supports_autograd=False,
  9737. supports_one_python_scalar=True,
  9738. skips=(
  9739. # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half'
  9740. DecorateInfo(unittest.expectedFailure,
  9741. 'TestBinaryUfuncs',
  9742. 'test_type_promotion',
  9743. device_type='cuda'),
  9744. )),
  9745. BinaryUfuncInfo('heaviside',
  9746. ref=lambda a, b: (
  9747. # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64
  9748. np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b)
  9749. ),
  9750. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  9751. supports_autograd=False,
  9752. supports_rhs_python_scalar=False,
  9753. skips=(
  9754. # RuntimeError: heaviside is not yet implemented for tensors with different dtypes.
  9755. DecorateInfo(unittest.expectedFailure,
  9756. 'TestBinaryUfuncs',
  9757. 'test_type_promotion'),
  9758. # PyTorch's heaviside does not appear to propagate NaNs
  9759. DecorateInfo(unittest.skip("Skipped!"),
  9760. 'TestBinaryUfuncs',
  9761. 'test_reference_numerics_extremal_values'),
  9762. )),
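# heaviside semantics for reference (the second argument supplies the value used at 0):
# >>> torch.heaviside(torch.tensor([-1.5, 0., 2.]), torch.tensor([0.5]))
# tensor([0.0000, 0.5000, 1.0000])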
  9763. BinaryUfuncInfo('lcm',
  9764. ref=np.lcm,
  9765. dtypes=integral_types_and(),
  9766. supports_autograd=False,
  9767. supports_rhs_python_scalar=False),
  9768. BinaryUfuncInfo('gcd',
  9769. ref=np.gcd,
  9770. dtypes=integral_types_and(),
  9771. supports_autograd=False,
  9772. supports_rhs_python_scalar=False,
  9773. skips=(
  9774. DecorateInfo(unittest.expectedFailure,
  9775. 'TestBinaryUfuncs',
  9776. 'test_reference_numerics_small_values',
  9777. dtypes=(torch.int8,)),)),
  9778. BinaryUfuncInfo('isclose',
  9779. ref=np.isclose,
  9780. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  9781. sample_inputs_func=sample_inputs_isclose,
  9782. error_inputs_func=error_inputs_isclose,
  9783. supports_autograd=False,
  9784. supports_out=False,
  9785. supports_rhs_python_scalar=False,
  9786. skips=(
  9787. DecorateInfo(unittest.expectedFailure,
  9788. 'TestCommon',
  9789. 'test_numpy_refs', dtypes=(torch.complex128,)),
  9790. # RuntimeError: Short did not match Int
  9791. DecorateInfo(unittest.expectedFailure,
  9792. 'TestBinaryUfuncs',
  9793. 'test_type_promotion'),
  9794. DecorateInfo(unittest.skip("Skipped!"),
  9795. 'TestBinaryUfuncs',
  9796. 'test_reference_numerics_extremal_values'),
  9797. )),
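# isclose quick sketch (with the default rtol=1e-05, atol=1e-08):
# >>> torch.isclose(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.1]))
# tensor([ True, False])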
9798. # `softmax` supports different dtypes based on whether the `dtype` argument
9799. # is passed or not. Hence two OpInfo entries, one with dtype and the other without (see the REPL sketch after them).
  9800. # https://github.com/pytorch/pytorch/issues/68752
  9801. OpInfo('softmax',
  9802. aliases=('special.softmax', 'nn.functional.softmax',),
  9803. aten_name='softmax',
  9804. aten_backward_name='_softmax_backward_data',
  9805. dtypes=floating_types_and(torch.bfloat16),
  9806. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  9807. sample_inputs_func=sample_inputs_softmax_variant,
  9808. assert_jit_shape_analysis=True,
  9809. assert_autodiffed=True,
  9810. supports_forward_ad=True,
  9811. supports_fwgrad_bwgrad=True,
  9812. supports_out=True),
  9813. OpInfo('softmax',
  9814. aliases=('special.softmax', 'nn.functional.softmax',),
  9815. variant_test_name="with_dtype",
  9816. aten_name='softmax',
  9817. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  9818. sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
  9819. assert_autodiffed=True,
  9820. supports_forward_ad=True,
  9821. supports_fwgrad_bwgrad=True,
  9822. supports_out=True),
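# REPL sketch of the dtype-vs-no-dtype split referenced above (illustrative):
# >>> torch.softmax(torch.tensor([1., 2., 3.]), dim=0).dtype
# torch.float32
# >>> torch.softmax(torch.tensor([1, 2, 3]), dim=0, dtype=torch.float64).dtype
# torch.float64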
  9823. OpInfo(
  9824. '_softmax_backward_data',
  9825. op=torch.ops.aten._softmax_backward_data,
  9826. aten_name='_softmax_backward_data',
  9827. dtypes=floating_types_and(torch.bfloat16),
  9828. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  9829. sample_inputs_func=sample_inputs_softmax_backward_data,
  9830. assert_autodiffed=True,
  9831. supports_forward_ad=True,
  9832. supports_fwgrad_bwgrad=True,
  9833. supports_out=False,
  9834. skips=(
  9835. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'),
  9836. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  9837. DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-4, rtol=2e-3),
  9838. torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
  9839. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  9840. ),
  9841. ),
9842. # `softmin` supports different dtypes based on whether the `dtype` argument
9843. # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
  9844. # https://github.com/pytorch/pytorch/issues/68752
  9845. OpInfo('nn.functional.softmin',
  9846. aten_name='softmin',
  9847. dtypes=floating_types_and(torch.bfloat16),
  9848. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  9849. sample_inputs_func=sample_inputs_softmax_variant,
  9850. assert_jit_shape_analysis=False,
  9851. assert_autodiffed=False,
  9852. supports_forward_ad=True,
  9853. supports_fwgrad_bwgrad=True,
  9854. supports_out=False),
  9855. OpInfo('nn.functional.softmin',
  9856. variant_test_name="with_dtype",
  9857. aten_name='softmin',
  9858. dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
  9859. sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
  9860. assert_autodiffed=False,
  9861. supports_forward_ad=True,
  9862. supports_fwgrad_bwgrad=True,
  9863. supports_out=False),
  9864. OpInfo(
  9865. "nn.functional.cross_entropy",
  9866. dtypes=floating_types_and(torch.bfloat16),
  9867. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  9868. sample_inputs_func=sample_inputs_cross_entropy,
  9869. supports_out=False,
  9870. supports_forward_ad=True,
  9871. supports_fwgrad_bwgrad=True,
  9872. decorators=(
  9873. DecorateInfo(
  9874. toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),
  9875. "TestJit",
  9876. "test_variant_consistency_jit",
  9877. device_type="cpu",
  9878. ),
  9879. ),
  9880. skips=(
  9881. # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536
  9882. # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked
  9883. # 1536 bytes CUDA memory on device 0
  9884. DecorateInfo(
  9885. unittest.expectedFailure,
  9886. "TestJit",
  9887. "test_variant_consistency_jit",
  9888. device_type="cuda",
  9889. ),
  9890. )
  9891. ),
  9892. OpInfo('nn.functional.normalize',
  9893. dtypes=floating_and_complex_types_and(torch.bfloat16),
  9894. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  9895. sample_inputs_func=sample_inputs_normalize,
  9896. supports_forward_ad=True,
  9897. supports_fwgrad_bwgrad=True),
  9898. OpInfo('aminmax',
  9899. ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
  9900. dtypes=all_types_and(torch.bool),
  9901. dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  9902. decorators=(onlyNativeDeviceTypes,),
  9903. supports_autograd=False,
  9904. sample_inputs_func=sample_inputs_aminmax,
  9905. error_inputs_func=error_inputs_aminmax_amax_amin,
  9906. skips=(
  9907. # AssertionError: Resizing an out= argument with no elements threw a resize warning!
  9908. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
  9909. )),
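# aminmax returns a (min, max) namedtuple, mirroring the numpy-based ref above:
# >>> torch.aminmax(torch.tensor([1, -2, 3]))
# torch.return_types.aminmax(
# min=tensor(-2),
# max=tensor(3))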
  9910. OpInfo('as_strided',
  9911. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  9912. supports_out=False,
  9913. supports_forward_ad=True,
  9914. supports_fwgrad_bwgrad=True,
  9915. # vmap does not support inplace views
  9916. check_inplace_batched_forward_grad=False,
  9917. sample_inputs_func=sample_inputs_as_strided,
  9918. skips=(
  9919. # Note: This xfail is fine -- it's inherent to how as_strided works
  9920. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
  9921. # AssertionError: False is not true : Scalars failed to compare as equal!
  9922. DecorateInfo(unittest.skip("Errors when storage_offset is included"),
  9923. 'TestCommon', 'test_variant_consistency_eager'),
  9924. # Not close
  9925. DecorateInfo(unittest.skip("Errors when storage_offset is included"),
  9926. 'TestCommon', 'test_complex_half_reference_testing'),
  9927. # Not close
  9928. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
  9929. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
  9930. DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
  9931. DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
  9932. )),
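# as_strided view semantics referenced above, e.g. overlapping 2x2 windows over a 1-D tensor:
# >>> torch.arange(4.).as_strided((2, 2), (1, 1))
# tensor([[0., 1.],
#         [1., 2.]])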
  9933. OpInfo('as_strided',
  9934. variant_test_name='partial_views',
  9935. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  9936. supports_out=False,
  9937. supports_forward_ad=True,
  9938. supports_fwgrad_bwgrad=True,
  9939. # vmap does not support inplace views
  9940. check_inplace_batched_forward_grad=False,
  9941. sample_inputs_func=sample_inputs_as_strided_partial_views,
  9942. skips=(
  9943. # Note: This xfail is fine -- it's inherent to how as_strided works
  9944. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
  9945. # RuntimeError: This operator is not Composite Compliant: the
  9946. # storage_offset of the tensor was modified directly without
  9947. # going through the PyTorch dispatcher.
  9948. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance'),
  9949. # These fail because the test changes the input's in-memory layout
  9950. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
  9951. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  9952. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
  9953. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  9954. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
  9955. dtypes=(torch.complex64, torch.complex128)),
  9956. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
  9957. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
  9958. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'),
  9959. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'),
  9960. DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo',
  9961. 'test_make_fx_symbolic_exhaustive_inplace'),
  9962. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
  9963. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  9964. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
9965. # These fail but are also flaky
  9966. DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'),
  9967. DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon',
  9968. 'test_non_standard_bool_values'),
  9969. )),
  9970. OpInfo('as_strided_scatter',
  9971. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  9972. supports_out=False,
  9973. supports_forward_ad=True,
  9974. supports_fwgrad_bwgrad=True,
  9975. # vmap does not support inplace views
  9976. check_inplace_batched_forward_grad=False,
  9977. sample_inputs_func=sample_inputs_as_strided_scatter,
  9978. error_inputs_func=error_inputs_as_strided_scatter,
  9979. skips=(
  9980. DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950
  9981. DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950
  9982. DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'),
  9983. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
  9984. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
  9985. DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
  9986. # AssertionError: Tensor-likes are not close! (new_empty_strided.default)
  9987. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)),
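# as_strided_scatter embeds `src` into a copy of the input at the strided-view locations, e.g.:
# >>> torch.as_strided_scatter(torch.zeros(4), torch.ones(2), (2,), (2,))
# tensor([1., 0., 1., 0.])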
  9988. OpInfo('native_layer_norm',
  9989. aten_name='native_layer_norm',
  9990. ref=reference_native_layer_norm,
  9991. dtypes=floating_types_and(torch.bfloat16),
  9992. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  9993. supports_out=False,
  9994. assert_jit_shape_analysis=True,
  9995. supports_fwgrad_bwgrad=True,
  9996. sample_inputs_func=sample_inputs_native_layer_norm,
  9997. error_inputs_func=error_inputs_native_layer_norm,
  9998. skips=(
  9999. # IndexError: tuple index out of range
  10000. DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'),
  10001. # Tests fail when weight=None and bias is defined
  10002. # https://github.com/pytorch/pytorch/issues/79705
  10003. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
  10004. # JIT test also tries to compute double backward, which fails
  10005. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
10006. # Extremal value issue on aten::native_layer_norm, which returns 'nan' for mean on 'inf' inputs,
10007. # possibly because of the Welford implementation.
  10008. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  10009. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  10010. )),
  10011. OpInfo('native_batch_norm',
  10012. aten_name='native_batch_norm',
  10013. dtypes=floating_types_and(torch.bfloat16),
  10014. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10015. supports_forward_ad=True,
  10016. supports_fwgrad_bwgrad=True,
  10017. assert_jit_shape_analysis=True,
  10018. sample_inputs_func=sample_inputs_native_batch_norm,
  10019. skips=(
  10020. # NotImplementedError: Could not run
  10021. # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
  10022. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
  10023. # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
  10024. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
  10025. # Problem with _get_numerical_jacobian
  10026. # IndexError: tuple index out of range
  10027. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
  10028. # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
  10029. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10030. # https://github.com/pytorch/pytorch/issues/85960
  10031. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
  10032. # AssertionError: Booleans mismatch: True is not False
  10033. DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
  10034. DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
  10035. DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
  10036. "TestCompositeCompliance", "test_forward_ad"),
10037. # Extremal value issue on aten::native_batch_norm, which returns 'nan' for mean on 'inf' inputs,
10038. # possibly because of the Welford implementation.
  10039. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  10040. )
  10041. ),
  10042. OpInfo('_native_batch_norm_legit',
  10043. aten_name='_native_batch_norm_legit',
  10044. dtypes=floating_types_and(torch.bfloat16),
  10045. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10046. supports_forward_ad=True,
  10047. supports_fwgrad_bwgrad=True,
  10048. assert_jit_shape_analysis=True,
  10049. sample_inputs_func=sample_inputs__native_batch_norm_legit,
  10050. skips=(
  10051. # NotImplementedError: Could not run
  10052. # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
  10053. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
  10054. # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
  10055. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
  10056. # Problem with _get_numerical_jacobian
  10057. # IndexError: tuple index out of range
  10058. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
  10059. # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
  10060. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10061. # https://github.com/pytorch/pytorch/issues/85960
  10062. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
  10063. DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
  10064. "TestCompositeCompliance", "test_forward_ad"),
10065. # Extremal value issue on aten::native_batch_norm, which returns 'nan' for mean on 'inf' inputs,
10066. # possibly because of the Welford implementation.
  10067. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  10068. )
  10069. ),
  10070. OpInfo('nn.functional.cosine_similarity',
  10071. aten_name="cosine_similarity",
  10072. dtypes=floating_types_and(torch.bfloat16),
  10073. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10074. supports_out=False,
  10075. supports_forward_ad=True,
  10076. supports_fwgrad_bwgrad=True,
  10077. sample_inputs_func=sample_inputs_cosine_similarity),
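# cosine_similarity reduces over `dim` (default 1), e.g. orthogonal rows give 0:
# >>> torch.nn.functional.cosine_similarity(torch.tensor([[1., 0.]]), torch.tensor([[0., 1.]]))
# tensor([0.])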
  10078. OpInfo('nn.functional.adaptive_avg_pool1d',
  10079. dtypes=floating_types_and(torch.bfloat16),
  10080. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10081. supports_out=False,
  10082. supports_forward_ad=True,
  10083. supports_fwgrad_bwgrad=True,
  10084. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10085. error_inputs_func=error_inputs_adaptive_avg_pool1d,
  10086. sample_inputs_func=sample_inputs_adaptive_avg_pool1d),
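# Adaptive pooling derives the bin sizes from the requested output length, e.g.:
# >>> torch.nn.functional.adaptive_avg_pool1d(torch.arange(4.).reshape(1, 1, 4), 2)
# tensor([[[0.5000, 2.5000]]])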
  10087. OpInfo('nn.functional.adaptive_avg_pool2d',
  10088. dtypes=floating_types_and(torch.bfloat16),
  10089. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10090. decorators=(
  10091. # RuntimeError:
  10092. # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor):
  10093. # Expected a value of type 'List[int]' for argument 'output_size' but
  10094. # instead found type 'Tuple[NoneType, int]'. :
  10095. # File "<string>", line 3
  10096. # def the_method(i0):
  10097. # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7))
  10098. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  10099. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10100. ),
  10101. supports_out=False,
  10102. supports_forward_ad=True,
  10103. supports_fwgrad_bwgrad=True,
  10104. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10105. error_inputs_func=error_inputs_adaptive_avg_pool2d,
  10106. sample_inputs_func=sample_inputs_adaptive_avg_pool2d),
  10107. OpInfo('nn.functional.adaptive_avg_pool3d',
  10108. dtypes=floating_types_and(torch.half),
  10109. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10110. decorators=(
  10111. # RuntimeError:
  10112. # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor):
  10113. # Expected a value of type 'List[int]' for argument 'output_size' but
  10114. # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
  10115. # File "<string>", line 3
  10116. #
  10117. # def the_method(i0):
  10118. # return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None))
  10119. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  10120. #
  10121. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10122. ),
  10123. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10124. gradcheck_fast_mode=True,
  10125. supports_out=False,
  10126. supports_forward_ad=True,
  10127. supports_fwgrad_bwgrad=True,
  10128. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10129. error_inputs_func=error_inputs_adaptive_avg_pool3d,
  10130. sample_inputs_func=sample_inputs_adaptive_avg_pool3d),
  10131. OpInfo('nn.functional.adaptive_max_pool1d',
  10132. dtypes=floating_types_and(torch.bfloat16),
  10133. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10134. supports_out=False,
  10135. supports_forward_ad=True,
  10136. supports_fwgrad_bwgrad=True,
  10137. # got: Batching rule not implemented for aten::flatten.using_ints
  10138. check_batched_forward_grad=False,
  10139. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10140. error_inputs_func=error_inputs_adaptive_max_pool1d,
  10141. sample_inputs_func=sample_inputs_adaptive_max_pool1d),
  10142. OpInfo('nn.functional.adaptive_max_pool2d',
  10143. dtypes=floating_types_and(torch.bfloat16),
  10144. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10145. decorators=(
  10146. # RuntimeError:
  10147. # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor):
  10148. # Expected a value of type 'List[int]' for argument 'output_size' but
  10149. # instead found type 'Tuple[NoneType, int]'. :
  10150. # File "<string>", line 3
  10151. # def the_method(i0):
  10152. # return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7))
  10153. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  10154. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10155. ),
  10156. supports_out=False,
  10157. supports_forward_ad=True,
  10158. supports_fwgrad_bwgrad=True,
  10159. # got: Batching rule not implemented for aten::flatten.using_ints
  10160. check_batched_forward_grad=False,
  10161. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10162. error_inputs_func=error_inputs_adaptive_max_pool2d,
  10163. sample_inputs_func=sample_inputs_adaptive_max_pool2d),
  10164. OpInfo('nn.functional.adaptive_max_pool3d',
  10165. dtypes=floating_types(),
  10166. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10167. decorators=(
  10168. # RuntimeError:
  10169. # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor):
  10170. # Expected a value of type 'List[int]' for argument 'output_size' but
  10171. # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
  10172. # File "<string>", line 3
  10173. #
  10174. # def the_method(i0):
  10175. # return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None))
  10176. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  10177. #
  10178. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10179. ),
  10180. supports_out=False,
  10181. supports_forward_ad=True,
  10182. supports_fwgrad_bwgrad=True,
  10183. # got: Batching rule not implemented for aten::flatten.using_ints
  10184. check_batched_forward_grad=False,
  10185. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10186. error_inputs_func=error_inputs_adaptive_max_pool3d,
  10187. sample_inputs_func=sample_inputs_adaptive_max_pool3d),
  10188. OpInfo('nn.functional.avg_pool1d',
  10189. aten_name='avg_pool1d',
  10190. supports_autograd=True,
  10191. supports_out=False,
  10192. supports_forward_ad=True,
  10193. supports_fwgrad_bwgrad=True,
  10194. dtypes=floating_types_and(torch.int64, torch.bfloat16),
  10195. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10196. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10197. error_inputs_func=error_inputs_avg_pool1d,
  10198. sample_inputs_func=sample_inputs_avgpool1d),
  10199. OpInfo('nn.functional.avg_pool3d',
  10200. aten_name='avg_pool3d',
  10201. supports_autograd=True,
  10202. supports_forward_ad=True,
  10203. supports_fwgrad_bwgrad=True,
  10204. dtypes=floating_types_and(torch.int64),
  10205. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10206. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10207. error_inputs_func=error_inputs_avg_pool3d,
  10208. sample_inputs_func=sample_inputs_avgpool3d,
  10209. skips=(
  10210. # AssertionError: Tensor-likes are not close!
  10211. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
  10212. )),
  10213. OpInfo(
  10214. "nn.functional.binary_cross_entropy_with_logits",
  10215. aten_name="binary_cross_entropy_with_logits",
  10216. supports_autograd=True,
  10217. supports_forward_ad=True,
  10218. supports_fwgrad_bwgrad=True,
  10219. supports_out=False,
  10220. dtypes=floating_types_and(torch.bfloat16),
  10221. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10222. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10223. sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
  10224. skips=(
  10225. DecorateInfo(
  10226. unittest.skip("Skipped!"),
  10227. 'TestJit',
  10228. 'test_variant_consistency_jit',
  10229. dtypes=(torch.float32,)
  10230. ),
  10231. ),
  10232. ),
  10233. UnaryUfuncInfo(
  10234. 'nn.functional.relu',
  10235. aten_name="relu",
  10236. ref=lambda a: np.where(a <= 0, 0, a),
  10237. supports_autograd=True,
  10238. supports_sparse=True,
  10239. supports_sparse_csr=True,
  10240. supports_sparse_csc=True,
  10241. supports_sparse_bsr=True,
  10242. supports_sparse_bsc=True,
  10243. dtypes=all_types_and(torch.bfloat16),
  10244. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  10245. sample_inputs_func=sample_inputs_nn_activation_relu,
  10246. supports_out=False,
  10247. supports_fwgrad_bwgrad=True,
  10248. supports_forward_ad=True),
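# The ref lambda above mirrors, e.g.:
# >>> torch.nn.functional.relu(torch.tensor([-1., 0., 2.]))
# tensor([0., 0., 2.])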
  10249. OpInfo('nn.functional.conv_transpose1d',
10250. # `ref` for this function is the backward of the
10251. # corresponding `conv*d`
  10252. ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d),
  10253. aten_name='conv_transpose1d',
  10254. aliases=('conv_transpose1d',),
  10255. dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
  10256. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
  10257. torch.bfloat16),
  10258. sample_inputs_func=sample_inputs_conv_transpose1d,
  10259. supports_forward_ad=True,
  10260. supports_fwgrad_bwgrad=True,
  10261. assert_jit_shape_analysis=True,
  10262. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10263. decorators=(
  10264. DecorateInfo(
  10265. toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
  10266. 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
  10267. DecorateInfo(
  10268. toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
  10269. 'TestCommon', 'test_complex_half_reference_testing'),
  10270. DecorateInfo(
  10271. toleranceOverride({torch.complex32: tol(atol=1e-5, rtol=5e-3)}),
  10272. "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
  10273. DecorateInfo(
  10274. toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
  10275. 'TestCommon', 'test_numpy_ref_mps'),
  10276. ),
  10277. skips=(
  10278. # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
  10279. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
  10280. dtypes=(torch.complex64,)),
  10281. # RuntimeError: UNSUPPORTED DTYPE: complex
  10282. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
  10283. dtypes=(torch.complex64, torch.complex128)),
  10284. # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
  10285. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
  10286. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
  10287. dtypes=(torch.float,)),
  10288. # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
  10289. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
  10290. dtypes=(torch.int64,)),
  10291. ),
  10292. supports_out=False,),
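# Shape sketch for conv_transpose1d (note the (C_in, C_out, kernel) weight layout):
# >>> x = torch.randn(1, 4, 5)   # (N, C_in, L)
# >>> w = torch.randn(4, 8, 3)   # (C_in, C_out, kernel)
# >>> torch.nn.functional.conv_transpose1d(x, w).shape
# torch.Size([1, 8, 7])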
  10293. OpInfo('nn.functional.conv_transpose2d',
  10294. aten_name='conv_transpose2d',
  10295. aliases=('conv_transpose2d',),
10296. # `ref` for this function is the backward of the
10297. # corresponding `conv*d`
  10298. ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d),
  10299. dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
  10300. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
  10301. torch.bfloat16),
  10302. sample_inputs_func=sample_inputs_conv_transpose2d,
  10303. # Runs very slowly on slow-gradcheck for complex.
  10304. gradcheck_fast_mode=True,
  10305. supports_forward_ad=True,
  10306. supports_fwgrad_bwgrad=True,
  10307. assert_jit_shape_analysis=True,
  10308. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10309. decorators=[
  10310. DecorateInfo(
  10311. toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
  10312. 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
  10313. DecorateInfo(
  10314. toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }),
  10315. 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
  10316. DecorateInfo(
  10317. toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
  10318. "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
  10319. DecorateInfo(
  10320. toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }),
  10321. 'TestCommon', 'test_complex_half_reference_testing')],
  10322. skips=(
  10323. # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
  10324. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
  10325. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  10326. # RuntimeError: UNSUPPORTED DTYPE: complex
  10327. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
  10328. dtypes=(torch.complex64, torch.complex128)),
  10329. # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
  10330. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
  10331. dtypes=(torch.int64,)),
  10332. # Reference: https://github.com/pytorch/pytorch/issues/86356
  10333. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
  10334. dtypes=(torch.double, torch.cdouble)),
  10335. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  10336. # AssertionError: None mismatch: torch.complex64 is not None
  10337. DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules',
  10338. dtypes=(torch.complex64, torch.complex128)),
  10339. ),
  10340. supports_out=False,),
  10341. OpInfo('nn.functional.conv_transpose3d',
  10342. aten_name='conv_transpose3d',
  10343. aliases=('conv_transpose3d',),
10344. # `ref` for this function is the backward of the
10345. # corresponding `conv*d`
  10346. ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
  10347. dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
  10348. dtypesIfCUDA=floating_and_complex_types_and(
  10349. torch.float16, torch.chalf, torch.bfloat16),
  10350. sample_inputs_func=sample_inputs_conv_transpose3d,
  10351. supports_forward_ad=True,
  10352. supports_fwgrad_bwgrad=True,
  10353. assert_jit_shape_analysis=True,
  10354. # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
  10355. gradcheck_fast_mode=True,
  10356. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10357. decorators=[
  10358. DecorateInfo(
  10359. toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
  10360. torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
  10361. 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
  10362. DecorateInfo(
  10363. toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
  10364. 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
  10365. DecorateInfo(
  10366. toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
  10367. torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
  10368. 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
  10369. DecorateInfo(
  10370. toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
  10371. 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
  10372. active_if=TEST_CUDNN),
  10373. DecorateInfo(
  10374. toleranceOverride({torch.complex32: tol(atol=5e-2, rtol=5e-2)}),
  10375. "TestCudaFuserOpInfo", "test_nvfuser_correctness"),
  10376. DecorateInfo(
  10377. toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
  10378. "TestMathBits", "test_conj_view", device_type='cuda'),
  10379. DecorateInfo(
  10380. toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
  10381. 'TestCommon', 'test_complex_half_reference_testing')],
  10382. skips=(
  10383. # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
  10384. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
  10385. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  10386. DecorateInfo(unittest.skip("Skipped! 75029"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  10387. DecorateInfo(unittest.skip("Skipped! 75363"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  10388. # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
  10389. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
  10390. dtypes=(torch.int64,)),
  10391. # Reference: https://github.com/pytorch/pytorch/issues/86356
  10392. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
  10393. dtypes=(torch.double, torch.cdouble)),
  10394. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  10395. # RuntimeError: UNSUPPORTED DTYPE: complex
  10396. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
  10397. dtypes=(torch.complex64, torch.complex128)),
  10398. ),
  10399. supports_out=False,),
  10400. OpInfo('nn.functional.conv1d',
  10401. aliases=('conv1d',),
  10402. aten_name='conv1d',
  10403. dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
  10404. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
  10405. torch.bfloat16),
  10406. sample_inputs_func=sample_inputs_conv1d,
  10407. error_inputs_func=error_inputs_conv1d,
  10408. supports_forward_ad=True,
  10409. supports_fwgrad_bwgrad=True,
  10410. assert_jit_shape_analysis=True,
  10411. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10412. decorators=(
  10413. DecorateInfo(
  10414. toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}),
  10415. 'TestCommon', 'test_complex_half_reference_testing'
  10416. ),
  10417. DecorateInfo(
  10418. toleranceOverride({torch.chalf: tol(atol=1e-3, rtol=1e-3)}),
  10419. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness',
  10420. ),
  10421. DecorateInfo(
  10422. toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
  10423. 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
  10424. ),
  10425. ),
  10426. skips=(
  10427. # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
  10428. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
  10429. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  10430. # Ref: https://github.com/pytorch/pytorch/issues/75309
  10431. # AssertionError: None mismatch: torch.complex128 is not None
  10432. DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
  10433. 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
  10434. # Ref: https://github.com/pytorch/pytorch/issues/75309
  10435. # RuntimeError: UNSUPPORTED DTYPE: complex
  10436. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
  10437. 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
  10438. ),
  10439. supports_expanded_weight=True,
  10440. supports_out=False,),
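# Shape sketch for conv1d with the default stride/padding:
# >>> x = torch.randn(1, 4, 5)   # (N, C_in, L)
# >>> w = torch.randn(8, 4, 3)   # (C_out, C_in, kernel)
# >>> torch.nn.functional.conv1d(x, w).shape
# torch.Size([1, 8, 3])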
  10441. OpInfo('nn.functional.conv2d',
  10442. aliases=('conv2d',),
  10443. aten_name='conv2d',
  10444. dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16),
  10445. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
  10446. torch.bfloat16),
  10447. sample_inputs_func=partial(sample_inputs_conv2d),
  10448. error_inputs_func=error_inputs_conv2d,
  10449. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10450. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10451. gradcheck_fast_mode=True,
  10452. supports_forward_ad=True,
  10453. supports_fwgrad_bwgrad=True,
  10454. assert_jit_shape_analysis=True,
  10455. decorators=(
  10456. DecorateInfo(
  10457. toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}),
  10458. 'TestCommon', 'test_complex_half_reference_testing',
  10459. ),
  10460. DecorateInfo(
  10461. toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=1e-2)}),
  10462. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness',
  10463. ),
  10464. ),
  10465. skips=(
  10466. # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
  10467. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
  10468. DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'),
  10469. # Ref: https://github.com/pytorch/pytorch/issues/75309
  10470. # AssertionError: None mismatch: torch.complex128 is not None
  10471. DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
  10472. 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
  10473. # RuntimeError: UNSUPPORTED DTYPE: complex
  10474. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
  10475. 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
  10476. ),
  10477. supports_expanded_weight=True,
  10478. supports_out=False,),
  10479. OpInfo('nn.functional.group_norm',
  10480. aten_name='group_norm',
  10481. aliases=('group_norm',),
  10482. ref=reference_group_norm,
  10483. dtypes=floating_types_and(torch.bfloat16),
  10484. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10485. supports_out=False,
  10486. supports_forward_ad=True,
  10487. supports_fwgrad_bwgrad=True,
  10488. error_inputs_func=error_inputs_group_norm,
  10489. decorators=[
  10490. # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
  10491. # Consider making it a parameter or input, or detaching the gradient
  10492. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,))
  10493. ],
  10494. sample_inputs_func=sample_inputs_group_norm,
  10495. reference_inputs_func=reference_inputs_group_norm,
  10496. supports_expanded_weight=True,),
  10497. OpInfo('nn.functional.instance_norm',
10498. # No ref because instance_norm often has numerical instability (large numbers or NaN)
  10499. dtypes=floating_types_and(torch.bfloat16),
  10500. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10501. supports_out=False,
  10502. supports_forward_ad=True,
  10503. supports_fwgrad_bwgrad=True,
  10504. decorators=[
  10505. # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
  10506. # Consider making it a parameter or input, or detaching the gradient
  10507. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  10508. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad',
  10509. active_if=TEST_WITH_ROCM)
  10510. ],
  10511. sample_inputs_func=sample_inputs_instance_norm,
  10512. supports_expanded_weight=True,),
  10513. OpInfo('nn.functional.layer_norm',
  10514. aten_name='layer_norm',
  10515. aten_backward_name='layer_norm_backward',
  10516. aliases=('layer_norm',),
  10517. ref=reference_layer_norm,
  10518. dtypes=floating_types_and(torch.bfloat16),
  10519. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10520. supports_out=False,
  10521. supports_forward_ad=True,
  10522. supports_fwgrad_bwgrad=True,
  10523. assert_jit_shape_analysis=True,
  10524. decorators=[
  10525. DecorateInfo(
  10526. toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
  10527. 'TestCommon', 'test_numpy_refs'
  10528. ),
  10529. DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'),
  10530. ],
  10531. sample_inputs_func=sample_inputs_layer_norm,
  10532. supports_expanded_weight=True,),
  10533. OpInfo('nn.functional.local_response_norm',
  10534. dtypes=floating_types_and(torch.int64, torch.bfloat16),
  10535. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10536. supports_out=False,
  10537. supports_forward_ad=True,
  10538. supports_fwgrad_bwgrad=True,
  10539. decorators=[
  10540. # RuntimeError: falseINTERNAL ASSERT FAILED at
  10541. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
  10542. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  10543. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  10544. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  10545. ],
  10546. sample_inputs_func=sample_inputs_local_response_norm,),
  10547. OpInfo('constant_pad_nd',
  10548. supports_forward_ad=True,
  10549. supports_fwgrad_bwgrad=True,
  10550. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
  10551. sample_inputs_func=sample_inputs_constant_pad_nd,
  10552. supports_out=False,
  10553. skips=(
  10554. # bool can't be passed to Scalar arguments in JIT tracer because
  10555. # BoolType is not a subtype of ScalarType.
  10556. DecorateInfo(
  10557. unittest.expectedFailure, 'TestNNCOpInfo',
  10558. 'test_nnc_correctness', dtypes=(torch.bool,)),
  10559. DecorateInfo(
  10560. unittest.expectedFailure, 'TestCudaFuserOpInfo',
  10561. 'test_nvfuser_correctness', dtypes=(torch.bool,)),
  10562. )),
  10563. OpInfo('nn.functional.pad',
  10564. variant_test_name='constant',
  10565. aten_name='constant_pad_nd',
  10566. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10567. gradcheck_fast_mode=True,
  10568. supports_forward_ad=True,
  10569. supports_fwgrad_bwgrad=True,
  10570. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
  10571. sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
  10572. supports_out=False),
  10573. OpInfo('nn.functional.pad',
  10574. variant_test_name='reflect',
  10575. supports_forward_ad=True,
  10576. supports_fwgrad_bwgrad=True,
  10577. dtypes=floating_and_complex_types(),
  10578. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  10579. sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
  10580. skips=(
  10581. # Doesn't have a corresponding aten operator.
  10582. # RuntimeError: falseINTERNAL ASSERT FAILED at
  10583. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
  10584. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  10585. ),
  10586. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10587. supports_out=False),
  10588. OpInfo('nn.functional.pad',
  10589. variant_test_name='replicate',
  10590. supports_forward_ad=True,
  10591. supports_fwgrad_bwgrad=True,
  10592. dtypes=floating_and_complex_types(),
  10593. dtypesIfCUDA=floating_and_complex_types_and(torch.half),
  10594. sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
  10595. skips=(
  10596. # Doesn't have a corresponding aten operator.
  10597. # RuntimeError: falseINTERNAL ASSERT FAILED at
  10598. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
  10599. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  10600. ),
  10601. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10602. supports_out=False),
  10603. OpInfo('nn.functional.pad',
  10604. variant_test_name='circular',
  10605. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
  10606. sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
  10607. supports_forward_ad=True,
  10608. supports_fwgrad_bwgrad=True,
  10609. check_batched_grad=False,
  10610. # https://github.com/pytorch/pytorch/issues/66357
  10611. check_batched_forward_grad=False,
  10612. skips=(
  10613. # Doesn't have a corresponding aten operator.
  10614. # RuntimeError: falseINTERNAL ASSERT FAILED at
  10615. # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
  10616. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
  10617. # Difference from <type> is larger with decomposition new_empty_strided.default than original on output 0
  10618. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),
  10619. ),
  10620. supports_out=False),
  10621. OpInfo('nn.functional.hardswish',
  10622. aten_name="hardswish",
  10623. aten_backward_name='hardswish_backward',
  10624. supports_autograd=True,
  10625. assert_autodiffed=True,
  10626. sample_inputs_func=sample_inputs_hardswish,
  10627. dtypes=floating_types_and(torch.bfloat16),
  10628. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10629. supports_gradgrad=True,
  10630. supports_forward_ad=True,
  10631. supports_fwgrad_bwgrad=True,
  10632. supports_out=False,
  10633. autodiff_nonfusible_nodes=["aten::hardswish"]),
  10634. OpInfo('nn.functional.unfold',
  10635. aten_name='im2col',
  10636. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  10637. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  10638. sample_inputs_func=sample_inputs_nn_unfold,
  10639. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10640. gradcheck_fast_mode=True,
  10641. supports_forward_ad=True,
  10642. supports_fwgrad_bwgrad=True,
  10643. supports_out=False,
  10644. skips=(
  10645. # NOTE: this failure may not reproduce consistently on different systems
  10646. # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185
  10647. DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'),
  10648. )),
  10649. OpInfo('nn.functional.interpolate',
  10650. aten_name="interpolate",
  10651. variant_test_name='nearest',
  10652. supports_autograd=True,
  10653. supports_fwgrad_bwgrad=True,
  10654. supports_forward_ad=True,
  10655. dtypes=floating_types_and(torch.uint8, torch.bfloat16),
  10656. dtypesIfCUDA=floating_types_and(torch.half, torch.uint8),
  10657. sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'),
  10658. skips=(
  10659. # RuntimeError: false
  10660. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10661. # please report a bug to PyTorch.
  10662. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10663. ),
  10664. supports_out=False),
  10665. OpInfo('nn.functional.interpolate',
  10666. aten_name="interpolate",
  10667. variant_test_name='linear',
  10668. supports_autograd=True,
  10669. supports_fwgrad_bwgrad=True,
  10670. supports_forward_ad=True,
  10671. dtypes=floating_types_and(torch.bfloat16),
  10672. dtypesIfCUDA=floating_types_and(torch.half),
  10673. sample_inputs_func=partial(sample_inputs_interpolate, 'linear'),
  10674. skips=(
  10675. # RuntimeError: false
  10676. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10677. # please report a bug to PyTorch.
  10678. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10679. ),
  10680. supports_out=False),
  10681. OpInfo('nn.functional.interpolate',
  10682. aten_name="interpolate",
  10683. variant_test_name='bilinear',
  10684. supports_fwgrad_bwgrad=True,
  10685. supports_autograd=True,
  10686. supports_forward_ad=True,
  10687. dtypes=floating_types_and(torch.uint8, torch.bfloat16),
  10688. dtypesIfCUDA=floating_types_and(torch.half),
  10689. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10690. sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'),
  10691. skips=(
  10692. # RuntimeError: false
  10693. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10694. # please report a bug to PyTorch.
  10695. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10696. ),
  10697. supports_out=False),
  10698. OpInfo('nn.functional.interpolate',
  10699. aten_name="interpolate",
  10700. variant_test_name='bicubic',
  10701. supports_autograd=True,
  10702. supports_forward_ad=True,
  10703. supports_fwgrad_bwgrad=True,
  10704. dtypes=floating_types_and(torch.bfloat16),
  10705. dtypesIfCUDA=floating_types_and(torch.half),
  10706. sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'),
  10707. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10708. skips=(
  10709. # RuntimeError: false
  10710. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10711. # please report a bug to PyTorch.
  10712. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10713. ),
  10714. supports_out=False),
  10715. OpInfo('nn.functional.interpolate',
  10716. aten_name="interpolate",
  10717. variant_test_name='trilinear',
  10718. supports_autograd=True,
  10719. supports_forward_ad=True,
  10720. supports_fwgrad_bwgrad=True,
  10721. dtypes=floating_types_and(torch.bfloat16),
  10722. dtypesIfCUDA=floating_types_and(torch.half),
  10723. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10724. sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'),
  10725. skips=(
  10726. # RuntimeError: false
  10727. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10728. # please report a bug to PyTorch.
  10729. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10730. ),
  10731. supports_out=False),
  10732. OpInfo('nn.functional.interpolate',
  10733. aten_name="interpolate",
  10734. variant_test_name='area',
  10735. supports_autograd=True,
  10736. supports_forward_ad=True,
  10737. supports_fwgrad_bwgrad=True,
  10738. dtypes=floating_types(),
  10739. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  10740. sample_inputs_func=partial(sample_inputs_interpolate, 'area'),
  10741. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10742. skips=(
  10743. # RuntimeError: false
  10744. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10745. # please report a bug to PyTorch.
  10746. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10747. ),
  10748. supports_out=False),
  10749. OpInfo('nn.functional.upsample_bilinear',
  10750. supports_autograd=True,
  10751. supports_forward_ad=True,
  10752. supports_fwgrad_bwgrad=True,
  10753. dtypes=floating_types_and(torch.uint8, torch.bfloat16),
  10754. dtypesIfCUDA=floating_types_and(torch.half),
  10755. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10756. sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'),
  10757. skips=(
  10758. # RuntimeError: false
  10759. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10760. # please report a bug to PyTorch.
  10761. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10762. ),
  10763. supports_out=False),
  10764. OpInfo(
  10765. "nn.functional.soft_margin_loss",
  10766. dtypes=floating_types_and(torch.bfloat16),
  10767. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  10768. supports_out=False,
  10769. supports_forward_ad=True,
  10770. # doesn't support grad on target
  10771. sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False),
  10772. error_inputs_func=error_inputs_soft_margin_loss,
  10773. ),
  10774. OpInfo('nn.functional.upsample_nearest',
  10775. supports_autograd=True,
  10776. supports_forward_ad=True,
  10777. supports_fwgrad_bwgrad=True,
  10778. dtypes=floating_types_and(torch.uint8, torch.bfloat16),
  10779. dtypesIfCUDA=floating_types_and(torch.half, torch.uint8),
  10780. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10781. sample_inputs_func=partial(sample_inputs_upsample, 'nearest'),
  10782. skips=(
  10783. # RuntimeError: false
  10784. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
  10785. # please report a bug to PyTorch.
  10786. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  10787. ),
  10788. supports_out=False),
  10789. OpInfo(
  10790. "nn.functional.margin_ranking_loss",
  10791. dtypes=all_types_and(torch.bfloat16),
  10792. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  10793. supports_out=False,
  10794. sample_inputs_func=sample_inputs_margin_ranking_loss,
  10795. error_inputs_func=error_inputs_margin_ranking_loss,
  10796. reference_inputs_func=reference_inputs_margin_ranking_loss,
  10797. supports_forward_ad=True,
  10798. supports_fwgrad_bwgrad=True),
  10799. OpInfo(
  10800. "nn.functional.multi_margin_loss",
  10801. dtypes=floating_types(),
  10802. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  10803. supports_out=False,
  10804. supports_gradgrad=False,
  10805. sample_inputs_func=sample_inputs_multi_margin_loss,
  10806. ),
  10807. OpInfo(
  10808. "nn.functional.multilabel_margin_loss",
  10809. dtypes=floating_types(),
  10810. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  10811. supports_out=False,
  10812. supports_gradgrad=False,
  10813. sample_inputs_func=sample_inputs_multilabel_margin_loss
  10814. ),
  10815. OpInfo('nn.functional.leaky_relu',
  10816. aliases=None,
  10817. aten_name="leaky_relu",
  10818. aten_backward_name='leaky_relu_backward',
  10819. sample_inputs_func=sample_inputs_leaky_relu,
  10820. dtypes=floating_types_and(torch.bfloat16),
  10821. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10822. inplace_variant=lambda x, negative_slope=0.01:
  10823. torch.nn.functional.leaky_relu(x, negative_slope, inplace=True),
  10824. supports_autograd=True,
  10825. assert_autodiffed=True,
  10826. supports_gradgrad=True,
  10827. supports_out=False,
  10828. supports_forward_ad=True,
  10829. supports_fwgrad_bwgrad=True,
  10830. autodiff_nonfusible_nodes=["aten::leaky_relu"]),
  10831. OpInfo(
  10832. "nn.functional.multilabel_soft_margin_loss",
  10833. supports_out=False,
  10834. dtypes=floating_types_and(torch.bfloat16),
  10835. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10836. sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
  10837. supports_forward_ad=True,
  10838. decorators=(
  10839. DecorateInfo(
  10840. toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
  10841. "TestJit",
  10842. "test_variant_consistency_jit",
  10843. ),
  10844. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  10845. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  10846. ),
  10847. skips=(
  10848. # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096
  10849. # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32
  10850. # leaked 4096 bytes CUDA memory on device 0
  10851. DecorateInfo(
  10852. # Skip instead of expectedFailure because this fails
  10853. # locally for me but passes in CI.
  10854. unittest.skip("Skipped!"),
  10855. "TestJit",
  10856. "test_variant_consistency_jit",
  10857. device_type="cuda",
  10858. ),
  10859. ),
  10860. ),
  10861. OpInfo('nn.functional.avg_pool2d',
  10862. aten_name='avg_pool2d',
  10863. supports_autograd=True,
  10864. supports_forward_ad=True,
  10865. supports_fwgrad_bwgrad=True,
  10866. dtypes=floating_types_and(torch.int64, torch.bfloat16),
  10867. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10868. error_inputs_func=error_inputs_avg_pool2d,
  10869. sample_inputs_func=sample_inputs_avgpool2d,
  10870. skips=(
  10871. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
  10872. )),
  10873. OpInfo('nn.functional.fractional_max_pool2d',
  10874. supports_autograd=True,
  10875. supports_out=False,
  10876. supports_forward_ad=True,
  10877. supports_fwgrad_bwgrad=True,
  10878. op=lambda input, *args, **kwargs:
  10879. wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
10880. # vmap does not support random operations (see the commented sketch after this entry)
  10881. check_batched_forward_grad=False,
  10882. dtypes=floating_types(),
  10883. dtypesIfCUDA=floating_types_and(torch.float16),
  10884. test_neg_view=False,
  10885. sample_inputs_func=sample_inputs_fractional_max_pool2d,
  10886. decorators=(
  10887. # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
  10888. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  10889. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  10890. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
  10891. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
  10892. skips=(
  10893. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
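# A brief, commented-out sketch (not part of op_db; shapes chosen only for
# illustration) of the randomness noted in the fractional_max_pool2d entry
# above: the op draws its pooling regions from the RNG, so repeated eager
# calls generally disagree unless the generator is reseeded in between,
# which is why the entry wraps the op in wrapper_set_seed and skips the
# output-comparison tests.
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.randn(1, 1, 8, 8)
#   a = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4))
#   b = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4))
#   # a and b typically differ: different random pooling regions were drawn.
#   torch.manual_seed(0)
#   c = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4))
#   torch.manual_seed(0)
#   d = F.fractional_max_pool2d(x, kernel_size=2, output_size=(4, 4))
#   torch.testing.assert_close(c, d)  # same seed -> same regions -> same output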
  10894. OpInfo('nn.functional.fractional_max_pool3d',
  10895. supports_autograd=True,
  10896. supports_out=False,
  10897. supports_forward_ad=True,
  10898. supports_fwgrad_bwgrad=True,
  10899. op=lambda input, *args, **kwargs:
  10900. wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
  10901. # vmap does not support random operations
  10902. check_batched_forward_grad=False,
  10903. dtypes=floating_types(),
  10904. dtypesIfCUDA=floating_types_and(torch.float16),
  10905. test_neg_view=False,
  10906. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10907. sample_inputs_func=sample_inputs_fractional_max_pool3d,
  10908. decorators=(
  10909. # FIXME: both derivatives are implemented incorrectly
  10910. # https://github.com/pytorch/pytorch/issues/69322
  10911. # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
  10912. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  10913. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  10914. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
  10915. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
  10916. skips=(
  10917. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
  10918. OpInfo('nn.functional.max_pool1d',
  10919. aten_name='max_pool1d',
  10920. supports_autograd=True,
  10921. supports_out=False,
  10922. supports_forward_ad=True,
  10923. supports_fwgrad_bwgrad=True,
  10924. # got: Batching rule not implemented for aten::flatten.using_ints
  10925. check_batched_forward_grad=False,
  10926. # TODO: add shape checks
  10927. assert_jit_shape_analysis=False,
  10928. dtypes=floating_types_and(torch.bfloat16),
  10929. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10930. skips=(
  10931. # Pre-existing condition; Needs to be fixed
  10932. DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
  10933. 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
  10934. # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
  10935. # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
  10936. # to actually allocate memory
  10937. DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
  10938. ),
  10939. error_inputs_func=error_inputs_max_pool1d,
  10940. sample_inputs_func=sample_inputs_max_pool),
  10941. OpInfo('nn.functional.max_pool2d',
  10942. aten_name='max_pool2d',
  10943. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10944. gradcheck_fast_mode=True,
  10945. # Vmap is not happy with non-contiguous (channels_last) inputs
  10946. check_batched_gradgrad=False,
  10947. supports_out=False,
  10948. supports_forward_ad=True,
  10949. supports_fwgrad_bwgrad=True,
  10950. # got: Batching rule not implemented for aten::flatten.using_ints
  10951. check_batched_forward_grad=False,
  10952. assert_jit_shape_analysis=True,
  10953. dtypes=floating_types_and(torch.bfloat16),
  10954. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10955. error_inputs_func=error_inputs_max_pool2d,
  10956. sample_inputs_func=sample_inputs_max_pool),
  10957. OpInfo('max_pool2d_with_indices_backward',
  10958. op=max_pool2d_backward,
  10959. # We've defined a custom op, so there's no corresponding aten op
  10960. aten_name=None,
  10961. method_variant=None,
  10962. inplace_variant=None,
  10963. operator_variant=None,
  10964. inplace_operator_variant=None,
  10965. check_batched_gradgrad=False,
  10966. supports_forward_ad=True,
  10967. supports_fwgrad_bwgrad=True,
  10968. check_batched_forward_grad=False,
  10969. assert_jit_shape_analysis=False,
  10970. dtypes=floating_types_and(torch.bfloat16),
  10971. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10972. sample_inputs_func=sample_inputs_max_pool,
  10973. skips=(
  10974. # We've defined a custom op here, and we don't handle the case where we receive an out kwarg
  10975. DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
  10976. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  10977. # FX failed to normalize op - add the op to the op_skip list.
  10978. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  10979. # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected)
  10980. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')
  10981. )),
  10982. OpInfo('nn.functional.max_pool3d',
  10983. aten_name='max_pool3d',
  10984. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  10985. gradcheck_fast_mode=True,
  10986. supports_out=False,
  10987. supports_forward_ad=True,
  10988. supports_fwgrad_bwgrad=True,
  10989. # got: Batching rule not implemented for aten::flatten.using_ints
  10990. check_batched_forward_grad=False,
  10991. # TODO: add shape checks
  10992. assert_jit_shape_analysis=False,
  10993. dtypes=floating_types(),
  10994. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  10995. # TODO: investigate nondeterminism
  10996. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  10997. error_inputs_func=error_inputs_max_pool3d,
  10998. sample_inputs_func=sample_inputs_max_pool),
  10999. OpInfo('nn.functional.max_unpool1d',
  11000. aten_name='max_unpool1d',
  11001. supports_autograd=True,
  11002. supports_forward_ad=True,
  11003. supports_fwgrad_bwgrad=True,
  11004. supports_out=False,
  11005. assert_jit_shape_analysis=False,
  11006. dtypes=floating_types(),
  11007. dtypesIfCUDA=floating_types_and(torch.float16),
  11008. sample_inputs_func=sample_inputs_max_unpool,
  11009. skips=(
  11010. # Gradients are tested in `variant_test_name=grad` below.
  11011. # We skip tests here because there is non-determinism in backward
  11012. # with gather, when there are writes into the same memory location,
  11013. # and if there are several indices pointing to the same memory,
  11014. # gradcheck is oblivious about that and cannot perturb them all at once
11015. # (see sample_inputs_max_unpool_grad to find out more; a commented sketch of the issue follows this entry).
  11016. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
  11017. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
  11018. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
  11019. active_if=(not IS_MACOS)),
  11020. DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad',
  11021. device_type='cpu'),
  11022. )),
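# A brief, commented-out sketch (not part of op_db; values chosen only for
# illustration) of the gradcheck problem described in the max_unpool entries:
# when two indices point at the same output location, the forward keeps only
# one of the colliding values (which one is implementation-defined), while
# the analytic backward gathers grad_output at those indices for *both*
# inputs. The numeric Jacobian therefore sees no sensitivity for the "losing"
# element and gradcheck cannot perturb the aliased elements together, so
# gradients are tested via the `grad` variants instead, whose samples use
# unique indices (see sample_inputs_max_unpool_grad).
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.tensor([[[1.0, 2.0]]], requires_grad=True)
#   idx = torch.tensor([[[3, 3]]])               # both elements map to output slot 3
#   out = F.max_unpool1d(x, idx, kernel_size=2)  # only one of 1.0/2.0 survives in out[..., 3]
#   out.sum().backward()
#   x.grad                                       # analytic grad routes to both elements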
  11023. OpInfo('nn.functional.max_unpool1d',
  11024. variant_test_name='grad',
  11025. aten_name='max_unpool1d',
  11026. supports_autograd=True,
  11027. supports_forward_ad=True,
  11028. supports_fwgrad_bwgrad=True,
  11029. supports_out=False,
  11030. assert_jit_shape_analysis=False,
  11031. dtypes=floating_types(),
  11032. dtypesIfCUDA=floating_types_and(torch.float16),
  11033. sample_inputs_func=sample_inputs_max_unpool_grad),
  11034. OpInfo('nn.functional.max_unpool2d',
  11035. aten_name='max_unpool2d',
  11036. supports_autograd=True,
  11037. supports_forward_ad=True,
  11038. supports_fwgrad_bwgrad=True,
  11039. supports_out=False,
  11040. assert_jit_shape_analysis=False,
  11041. dtypes=floating_types(),
  11042. dtypesIfCUDA=floating_types_and(torch.float16),
  11043. sample_inputs_func=sample_inputs_max_unpool,
  11044. skips=(
  11045. # Gradients are tested in `variant_test_name=grad` below.
  11046. # We skip tests here because there is non-determinism in backward
  11047. # with gather, when there are writes into the same memory location,
  11048. # and if there are several indices pointing to the same memory,
  11049. # gradcheck is oblivious about that and cannot perturb them all at once
  11050. # (see sample_inputs_max_unpool_grad to find out more).
  11051. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
  11052. active_if=(not IS_MACOS)),
  11053. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
  11054. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
  11055. DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
  11056. )),
  11057. OpInfo('nn.functional.max_unpool2d',
  11058. variant_test_name='grad',
  11059. aten_name='max_unpool2d',
  11060. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  11061. gradcheck_fast_mode=True,
  11062. supports_forward_ad=True,
  11063. supports_fwgrad_bwgrad=True,
  11064. # Vmap is not happy with non-contiguous (channels_last) inputs
  11065. check_batched_grad=False,
  11066. supports_out=False,
  11067. assert_jit_shape_analysis=False,
  11068. dtypes=floating_types(),
  11069. dtypesIfCUDA=floating_types_and(torch.float16),
  11070. sample_inputs_func=sample_inputs_max_unpool_grad),
  11071. OpInfo('nn.functional.max_unpool3d',
  11072. aten_name='max_unpool3d',
  11073. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  11074. gradcheck_fast_mode=True,
  11075. supports_forward_ad=True,
  11076. supports_fwgrad_bwgrad=True,
  11077. supports_out=False,
  11078. assert_jit_shape_analysis=False,
  11079. dtypes=floating_types(),
  11080. dtypesIfCUDA=floating_types_and(torch.float16),
  11081. sample_inputs_func=sample_inputs_max_unpool,
  11082. skips=(
  11083. # Gradients are tested in `variant_test_name=grad` below.
  11084. # We skip tests here because there is non-determinism in backward
  11085. # with gather, when there are writes into the same memory location,
  11086. # and if there are several indices pointing to the same memory,
  11087. # gradcheck is oblivious about that and cannot perturb them all at once
  11088. # (see sample_inputs_max_unpool_grad to find out more).
  11089. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
  11090. active_if=(not IS_MACOS)),
  11091. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
  11092. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
  11093. DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
  11094. )),
  11095. OpInfo('nn.functional.max_unpool3d',
  11096. variant_test_name='grad',
  11097. aten_name='max_unpool3d',
  11098. supports_autograd=True,
  11099. supports_forward_ad=True,
  11100. supports_fwgrad_bwgrad=True,
  11101. supports_out=False,
  11102. assert_jit_shape_analysis=False,
  11103. dtypes=floating_types(),
  11104. dtypesIfCUDA=floating_types_and(torch.float16),
  11105. sample_inputs_func=sample_inputs_max_unpool_grad),
  11106. OpInfo('nn.functional.linear',
  11107. aten_name='linear',
  11108. supports_autograd=True,
  11109. sample_inputs_func=sample_inputs_linear,
  11110. dtypes=all_types_and_complex_and(torch.bfloat16),
  11111. dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  11112. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  11113. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
11114. # linear calls mm under the hood, which is nondeterministic on CUDA (see the commented note after this entry)
  11115. # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
  11116. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  11117. supports_forward_ad=True,
  11118. supports_fwgrad_bwgrad=True,
  11119. # See https://github.com/pytorch/pytorch/issues/66357
  11120. check_batched_forward_grad=False,
  11121. supports_expanded_weight=True,
  11122. decorators=(
  11123. # Strides are not the same!
  11124. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  11125. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  11126. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  11127. )),
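# A brief, commented-out note (not part of op_db) on the determinism comment
# in the nn.functional.linear entry above: gradcheck compares repeated
# evaluations, so kernels that are not bitwise reproducible on CUDA need the
# GRADCHECK_NONDET_TOL tolerance. One way to surface such kernels is
# torch.use_deterministic_algorithms, which errors (or warns, with
# warn_only=True) when only a nondeterministic implementation is available.
#
#   import torch
#
#   torch.use_deterministic_algorithms(True, warn_only=True)
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   x = torch.randn(8, 4, device=device)
#   w = torch.randn(3, 4, device=device)
#   y = torch.nn.functional.linear(x, w)  # dispatches to mm/addmm under the hood
#   torch.use_deterministic_algorithms(False)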
  11128. OpInfo('nn.functional.bilinear',
  11129. aten_name='bilinear',
  11130. supports_autograd=True,
  11131. sample_inputs_func=sample_inputs_bilinear,
  11132. dtypes=all_types_and(torch.bfloat16),
  11133. dtypesIfCUDA=floating_types_and(torch.float16,
  11134. *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []),
  11135. skips=(
  11136. # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
  11137. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
  11138. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
  11139. ),
  11140. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  11141. gradcheck_fast_mode=True,
  11142. supports_forward_ad=True,
  11143. supports_fwgrad_bwgrad=True,
  11144. supports_out=False),
  11145. OpInfo('nn.functional.glu',
  11146. aten_name='glu',
  11147. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  11148. gradcheck_fast_mode=True,
  11149. sample_inputs_func=sample_inputs_glu,
  11150. dtypes=floating_types_and(torch.bfloat16),
  11151. dtypesIfROCM=floating_types_and(torch.float16, torch.bfloat16),
  11152. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11153. backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11154. supports_forward_ad=True,
  11155. supports_fwgrad_bwgrad=True,
  11156. supports_out=False),
  11157. UnaryUfuncInfo(
  11158. 'nn.functional.elu',
  11159. aten_backward_name='elu_backward',
  11160. ref=lambda x, alpha=1.0, inplace=False:
  11161. np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)),
  11162. dtypes=floating_types_and(torch.bfloat16),
  11163. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11164. supports_forward_ad=True,
  11165. supports_fwgrad_bwgrad=True,
  11166. supports_autograd=True,
  11167. assert_autodiffed=False,
  11168. supports_gradgrad=True,
  11169. supports_out=False,
  11170. sample_kwargs=lambda device, dtype, input:
  11171. ({'alpha': 0.8}, {'alpha': 0.8}),
  11172. inplace_variant=lambda x, alpha=1.0:
  11173. torch.nn.functional.elu(x, alpha, inplace=True),
  11174. decorators=[
  11175. DecorateInfo(
  11176. toleranceOverride({
  11177. torch.float16: tol(atol=1e-03, rtol=1.2e-03),
  11178. torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
  11179. }),
  11180. 'TestUnaryUfuncs', device_type='cuda',
  11181. ), ],
  11182. ),
  11183. # Marked as a Unary function because it has some rather odd broadcasting semantics in its
11184. # second argument (a commented example of the broadcasting follows this entry)
  11185. UnaryUfuncInfo(
  11186. 'nn.functional.prelu',
  11187. aten_backward_name='_prelu_kernel_backward',
  11188. ref=lambda x, weight:
  11189. np.maximum(0., x) + np.minimum(0., x) *
  11190. (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])),
  11191. dtypes=floating_types_and(torch.bfloat16),
  11192. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11193. supports_forward_ad=True,
  11194. supports_fwgrad_bwgrad=True,
  11195. supports_autograd=True,
  11196. assert_autodiffed=False,
  11197. supports_gradgrad=True,
  11198. supports_out=False,
  11199. # test_reference_numerics only tests the case when the weight tensor is a scalar
  11200. sample_kwargs=sample_kwargs_prelu_scalar_weight,
  11201. error_inputs_func=error_inputs_prelu,
  11202. sample_inputs_func=sample_inputs_prelu,
  11203. reference_inputs_func=reference_inputs_prelu,
  11204. decorators=[
  11205. # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
  11206. # Consider making it a parameter or input, or detaching the gradient
  11207. # https://github.com/pytorch/pytorch/issues/68752
  11208. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ],
  11209. ),
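# A brief, commented-out example (not part of op_db; values chosen only for
# illustration) of the broadcasting semantics mentioned above: for inputs
# with ndim >= 2, prelu broadcasts its weight along dim 1 (the channel
# dimension), one negative slope per channel, matching the ref lambda in the
# entry above.
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.tensor([[[-1.0, 2.0],
#                      [-3.0, 4.0]]])   # shape (1, 2, 2); dim 1 has 2 channels
#   w = torch.tensor([0.5, 0.1])        # one slope per channel
#   F.prelu(x, w)
#   # tensor([[[-0.5000,  2.0000],
#   #          [-0.3000,  4.0000]]])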
  11210. UnaryUfuncInfo(
  11211. 'nn.functional.celu',
  11212. ref=lambda x, alpha=1.0, inplace=False:
  11213. np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)),
  11214. dtypes=floating_types_and(torch.bfloat16),
  11215. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11216. supports_forward_ad=True,
  11217. supports_fwgrad_bwgrad=True,
  11218. supports_autograd=True,
  11219. assert_autodiffed=False,
  11220. supports_gradgrad=True,
  11221. supports_out=False,
  11222. sample_kwargs=lambda device, dtype, input:
  11223. ({'alpha': 0.8}, {'alpha': 0.8}),
  11224. inplace_variant=lambda x, alpha=1.0:
  11225. torch.nn.functional.celu(x, alpha, inplace=True),
  11226. decorators=[
  11227. DecorateInfo(
  11228. toleranceOverride({
  11229. torch.float16: tol(atol=1e-03, rtol=1.2e-03),
  11230. torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
  11231. }),
  11232. 'TestUnaryUfuncs', device_type='cuda',
  11233. ), ],
  11234. ),
  11235. UnaryUfuncInfo(
  11236. 'nn.functional.rrelu',
  11237. aten_backward_name='rrelu_with_noise_backward',
  11238. op=lambda input, *args, **kwargs:
  11239. wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs),
  11240. inplace_variant=lambda input, *args, **kwargs:
  11241. wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs),
  11242. dtypes=floating_types_and(torch.bfloat16),
  11243. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11244. gradcheck_wrapper=wrapper_set_seed,
  11245. supports_forward_ad=True,
  11246. supports_fwgrad_bwgrad=True,
  11247. supports_out=False,
  11248. sample_kwargs=lambda device, dtype, input:
  11249. (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)),
  11250. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs=dict(lower=0., upper=1., training=True)),
  11251. error_inputs_func=error_inputs_rrelu,
  11252. decorators=(
  11253. DecorateInfo(
  11254. toleranceOverride({
  11255. torch.float16: tol(atol=1e-03, rtol=1.2e-03),
  11256. torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
  11257. }),
  11258. 'TestUnaryUfuncs', device_type='cuda',
  11259. ),),
  11260. skips=(
  11261. # lambda impl
  11262. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  11263. # lambda impl
  11264. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  11265. # In-place operations do not play well with forward AD
  11266. # https://github.com/pytorch/pytorch/issues/77447
  11267. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients',
  11268. 'test_inplace_forward_mode_AD'),
11269. # The noise vector generated in these tests is not the same elementwise (see the commented sketch after this entry)
  11270. DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
  11271. DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
  11272. DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
  11273. DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
  11274. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
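# A brief, commented-out sketch (not part of op_db) of the "Different noise"
# skips in the rrelu entry above: in training mode rrelu samples a fresh
# random slope per element, so repeated calls only agree when the RNG is
# reseeded in between (which is what the wrapper_set_seed wrappers rely on).
#
#   import torch
#   import torch.nn.functional as F
#
#   x = -torch.ones(4)
#   torch.manual_seed(0)
#   a = F.rrelu(x, lower=0., upper=1., training=True)
#   torch.manual_seed(0)
#   b = F.rrelu(x, lower=0., upper=1., training=True)
#   c = F.rrelu(x, lower=0., upper=1., training=True)
#   torch.testing.assert_close(a, b)  # same seed -> same noise
#   # a and c generally differ elementwise.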
  11275. UnaryUfuncInfo(
  11276. 'nn.functional.selu',
  11277. ref=lambda x, inplace=False:
  11278. 1.0507009873554804934193349852946 * (
  11279. np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1))
  11280. ),
  11281. dtypes=floating_types_and(torch.bfloat16),
  11282. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11283. supports_forward_ad=True, # depends on 'elu'
  11284. supports_fwgrad_bwgrad=True,
  11285. supports_autograd=True,
  11286. assert_autodiffed=False,
  11287. supports_gradgrad=True,
  11288. supports_out=False,
  11289. inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True),
  11290. decorators=[
  11291. DecorateInfo(
  11292. toleranceOverride({
  11293. torch.float16: tol(atol=1e-2, rtol=1.8e-2),
  11294. torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2)
  11295. }),
  11296. 'TestUnaryUfuncs', device_type='cuda',
  11297. ), ],
  11298. ),
  11299. OpInfo(
  11300. 'nn.functional.scaled_dot_product_attention',
  11301. op=lambda *args, **kwargs:
  11302. wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
  11303. sample_inputs_func=sample_inputs_scaled_dot_product_attention,
  11304. dtypes=floating_types_and(torch.bfloat16),
  11305. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11306. supports_out=False,
  11307. supports_forward_ad=False,
  11308. supports_fwgrad_bwgrad=True,
  11309. check_batched_forward_grad=False,
  11310. decorators=[DecorateInfo(toleranceOverride(
  11311. {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon', device_type='cuda',), ],
  11312. skips=(
  11313. # This is only failing on Linux Bionic 3.10 Cuda 11.6
  11314. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
  11315. device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)),
  11316. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples',
  11317. device_type='cuda', dtypes=(torch.float32,)),
  11318. # AssertionError: JIT Test does not execute any logic
  11319. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  11320. # Forward works for dtype=float64 which is the math path
  11321. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
  11322. # OpInfo was implemented with a lambda
  11323. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  11324. # See [Note] SDPA_flash's meta function returns incorrect Philox seed and offset
  11325. DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp',
  11326. device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA and SM80OrLater),
  11327. DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace',
  11328. device_type='cuda', dtypes=(torch.float16, torch.bfloat16),
  11329. active_if=PLATFORM_SUPPORTS_FUSED_SDPA and SM80OrLater),
  11330. DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
  11331. device_type='cuda', dtypes=(torch.float16, torch.bfloat16),
  11332. active_if=PLATFORM_SUPPORTS_FUSED_SDPA and SM80OrLater),
  11333. # TODO Need to understand what this is testing and why it doesn't work
  11334. DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
  11335. DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),
  11336. # TODO skip this for now since we can't skip on runtime arch support
  11337. DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'),),
  11338. ),
  11339. UnaryUfuncInfo(
  11340. 'nn.functional.silu',
  11341. aten_backward_name='silu_backward',
  11342. ref=lambda x, inplace=False: x / (1 + np.exp(-x)),
  11343. dtypes=floating_types_and(torch.bfloat16),
  11344. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11345. supports_forward_ad=True,
  11346. supports_autograd=True,
  11347. supports_fwgrad_bwgrad=True,
  11348. assert_autodiffed=True,
  11349. supports_out=False,
  11350. inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
  11351. decorators=[
  11352. DecorateInfo(
  11353. toleranceOverride({
  11354. torch.float16: tol(atol=1e-3, rtol=1e-3),
  11355. torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
  11356. }),
  11357. 'TestUnaryUfuncs', device_type='cuda',
  11358. ), ],
  11359. skips=(
  11360. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  11361. dtypes=(torch.cfloat,), device_type='cpu'),
  11362. ),
  11363. autodiff_nonfusible_nodes=["aten::silu"],
  11364. ),
  11365. # TODO: combine this with the nn.functional.silu OpInfo when
  11366. # complex autodiff for silu is supported or when
  11367. # the forward bug is fixed
11368. # Note: silu errors when given inputs that require grad
11369. # but whose dtype does not support grad.
11370. # This is why the dtypes list above passes test_dtypes:
11371. # it gets lucky and fails in forward
11372. # because test_dtypes sets requires_grad to True
  11373. # THIS IS A BUG
  11374. UnaryUfuncInfo(
  11375. 'nn.functional.silu',
  11376. variant_test_name='complex',
  11377. ref=lambda x, inplace=False:
  11378. x / (1 + np.exp(-x)),
  11379. dtypes=complex_types(),
  11380. dtypesIfCUDA=empty_types(),
  11381. supports_forward_ad=False,
  11382. supports_autograd=False,
  11383. assert_autodiffed=False,
  11384. supports_out=False,
  11385. inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
  11386. decorators=[
  11387. DecorateInfo(
  11388. toleranceOverride({
  11389. torch.float16: tol(atol=1e-3, rtol=1e-3),
  11390. torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
  11391. }),
  11392. 'TestUnaryUfuncs', device_type='cuda',
  11393. ), ],
  11394. skips=(
  11395. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  11396. dtypes=(torch.cfloat,), device_type='cpu'),
  11397. # FIXME: intentionally misreports dtypes
  11398. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
  11399. # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
  11400. DecorateInfo(unittest.skip("Skipped!"),
  11401. 'TestUnaryUfuncs', 'test_reference_numerics_large',
  11402. dtypes=(torch.complex64, torch.cdouble)),
  11403. DecorateInfo(unittest.skip("Skipped!"),
  11404. 'TestUnaryUfuncs', 'test_reference_numerics_small',
  11405. dtypes=(torch.complex64,)),
  11406. DecorateInfo(unittest.skip("Skipped!"),
  11407. 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  11408. dtypes=(torch.complex64,)))),
  11409. UnaryUfuncInfo(
  11410. 'nn.functional.hardsigmoid',
  11411. aten_backward_name='hardsigmoid_backward',
  11412. ref=reference_hardsigmoid,
  11413. dtypes=floating_types_and(torch.bfloat16),
  11414. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11415. supports_autograd=True,
  11416. assert_autodiffed=False,
  11417. supports_gradgrad=False,
  11418. supports_forward_ad=True,
  11419. supports_out=False,
  11420. inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True),
  11421. decorators=[
  11422. DecorateInfo(
  11423. toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ],
  11424. skips=[
11425. # we still want to test that the first derivative works even though the second derivative isn't supported
  11426. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"),
  11427. # produces 0 instead of nan on ROCM
  11428. DecorateInfo(unittest.expectedFailure,
  11429. 'TestUnaryUfuncs', "test_reference_numerics_extremal",
  11430. device_type='cuda',
  11431. active_if=(TEST_WITH_ROCM)), ]
  11432. ),
  11433. UnaryUfuncInfo(
  11434. 'nn.functional.logsigmoid',
  11435. aten_name="log_sigmoid",
  11436. aten_backward_name='log_sigmoid_backward',
  11437. ref=reference_logsigmoid,
  11438. dtypes=floating_types_and(torch.bfloat16),
  11439. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11440. supports_autograd=True,
  11441. assert_autodiffed=False,
  11442. supports_forward_ad=True,
  11443. supports_gradgrad=True,
  11444. # autodiff_nonfusible_nodes=["aten::log_sigmoid"],
  11445. decorators=[
  11446. DecorateInfo(
  11447. precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
  11448. 'TestUnaryUfuncs', 'test_reference_numerics_small'),
  11449. DecorateInfo(
  11450. precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
  11451. 'TestUnaryUfuncs', 'test_reference_numerics_large'),
  11452. DecorateInfo(
  11453. precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
  11454. 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  11455. ],
  11456. skips=(
  11457. # Resized a non-empty tensor but did not warn about it.
  11458. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'),
  11459. ),
  11460. ),
  11461. UnaryUfuncInfo(
  11462. 'nn.functional.mish',
  11463. aten_backward_name='mish_backward',
  11464. ref=lambda x: x * np.tanh(reference_softplus(x)),
  11465. dtypes=floating_types_and(torch.bfloat16),
  11466. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11467. supports_forward_ad=True,
  11468. supports_fwgrad_bwgrad=True,
  11469. supports_autograd=True,
  11470. assert_autodiffed=False,
  11471. supports_gradgrad=True,
  11472. supports_out=False,
  11473. inplace_variant=partial(torch.nn.functional.mish, inplace=True),
  11474. decorators=[
  11475. DecorateInfo(
  11476. toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs', device_type='cuda',), ],
  11477. ),
  11478. UnaryUfuncInfo(
  11479. 'nn.functional.softsign',
  11480. ref=lambda x: x / (np.abs(x) + 1),
  11481. dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
  11482. dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
  11483. supports_forward_ad=True,
  11484. supports_fwgrad_bwgrad=True,
  11485. supports_autograd=True,
  11486. assert_autodiffed=False,
  11487. supports_gradgrad=True,
  11488. supports_out=False,
  11489. decorators=[
  11490. DecorateInfo(
  11491. toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ],
  11492. skips=(
  11493. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  11494. dtypes=(torch.int, torch.int8)),
  11495. # pytorch computes (0+nanj), numpy computes (-5e-18-1j) for input (-501.-1.0000e+20j)
  11496. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
  11497. "test_reference_numerics_large", dtypes=(torch.complex64,), device_type='cpu',
  11498. active_if=not IS_MACOS and not IS_WINDOWS),),
  11499. ),
  11500. UnaryUfuncInfo(
  11501. 'nn.functional.tanhshrink',
  11502. ref=lambda x: x - np.tanh(x),
  11503. dtypes=all_types_and_complex_and(torch.bfloat16),
  11504. dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
  11505. supports_forward_ad=True,
  11506. supports_fwgrad_bwgrad=True,
  11507. supports_autograd=True,
  11508. assert_autodiffed=False,
  11509. supports_gradgrad=True,
  11510. supports_out=False,
  11511. decorators=[
  11512. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  11513. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  11514. DecorateInfo(
  11515. toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',),
  11516. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  11517. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  11518. ],
  11519. skips=(
  11520. # in each case, pytorch will produce a nan while numpy will not
  11521. DecorateInfo(unittest.expectedFailure,
  11522. 'TestUnaryUfuncs', "test_reference_numerics_small",
  11523. dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)),
  11524. DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
  11525. 'TestUnaryUfuncs', "test_reference_numerics_large",
  11526. dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)),
  11527. DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
  11528. 'TestUnaryUfuncs', "test_reference_numerics_extremal",
  11529. dtypes=(torch.complex64, torch.complex128), device_type='cpu',
  11530. active_if=(IS_MACOS or IS_WINDOWS)),
  11531. ),
  11532. ),
  11533. UnaryUfuncInfo(
  11534. 'nn.functional.threshold',
  11535. ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
  11536. dtypes=all_types_and(torch.bfloat16),
  11537. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  11538. inplace_variant=lambda x, threshold, value:
  11539. torch.nn.functional.threshold(x, threshold, value, inplace=True),
  11540. supports_forward_ad=True,
  11541. supports_fwgrad_bwgrad=True,
  11542. assert_autodiffed=False,
  11543. supports_gradgrad=True,
  11544. supports_out=False,
  11545. sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'),
  11546. 'value': -9},
  11547. {'threshold': float.fromhex('0x1.3ap-3'),
  11548. 'value': -9}),
  11549. # TODO(whc) should not need sample_inputs_func, but without it
  11550. # kwargs aren't being hooked up properly
  11551. sample_inputs_func=sample_inputs_threshold,
  11552. ),
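# A short, commented-out note (not part of op_db) on the hex literal used in
# the nn.functional.threshold sample_kwargs above:
#   float.fromhex('0x1.3ap-3') == (1 + 0x3a / 0x100) * 2**-3
#                              == 1.2265625 / 8
#                              == 0.1533203125
# i.e. an exactly representable threshold. Per the ref above, elements
# <= 0.1533203125 are replaced with the value -9, e.g.:
#
#   import torch
#   torch.nn.functional.threshold(torch.tensor([0.1, 0.2]), 0.1533203125, -9.)
#   # tensor([-9.0000,  0.2000])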
  11553. OpInfo(
  11554. "nn.functional.triplet_margin_loss",
  11555. sample_inputs_func=sample_inputs_triplet_margin_loss,
  11556. error_inputs_func=error_inputs_triplet_margin_loss,
  11557. dtypes=all_types_and_complex_and(torch.bfloat16),
  11558. dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
  11559. supports_out=False,
  11560. supports_forward_ad=True,
  11561. supports_fwgrad_bwgrad=True,
  11562. ),
  11563. OpInfo(
  11564. "nn.functional.triplet_margin_with_distance_loss",
  11565. sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
  11566. error_inputs_func=error_inputs_triplet_margin_loss,
  11567. dtypes=all_types_and_complex_and(torch.bfloat16),
  11568. dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
  11569. supports_out=False,
  11570. supports_forward_ad=True,
  11571. supports_fwgrad_bwgrad=True,
  11572. skips=(
11573. # This test cannot handle a callable passed to `distance_function`. If we used
11574. # `distance_function=None`, the test would pass fine.
  11575. DecorateInfo(
  11576. unittest.expectedFailure,
  11577. "TestJit",
  11578. "test_variant_consistency_jit",
  11579. ),
  11580. DecorateInfo(
  11581. unittest.expectedFailure,
  11582. "TestNormalizeOperators",
  11583. "test_normalize_operator_exhaustive",
  11584. ),
  11585. ),
  11586. ),
  11587. BinaryUfuncInfo('nextafter',
  11588. dtypes=floating_types_and(torch.bfloat16),
  11589. supports_autograd=False,
  11590. supports_rhs_python_scalar=False),
  11591. OpInfo(
  11592. "to",
  11593. op=lambda x, *args, **kwargs: x.to(*args, **kwargs),
  11594. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  11595. supports_forward_ad=True,
  11596. supports_fwgrad_bwgrad=True,
  11597. supports_out=False,
  11598. sample_inputs_func=sample_inputs_to,
  11599. skips=(
  11600. # RuntimeError: undefined value cpu
  11601. DecorateInfo(
  11602. unittest.skip("Skipped!"),
  11603. "TestJit",
  11604. "test_variant_consistency_jit",
  11605. device_type="cpu",
  11606. ),
  11607. # NotImplementedError: Cannot copy out of meta tensor; no data!
  11608. DecorateInfo(
  11609. unittest.skip("Skipped!"),
  11610. "TestMeta",
  11611. "test_meta_outplace",
  11612. ),
  11613. # https://github.com/pytorch/pytorch/issues/84335
  11614. DecorateInfo(
  11615. unittest.skip("Skipped!"),
  11616. "TestProxyTensorOpInfo",
  11617. "test_make_fx_symbolic_exhaustive",
  11618. ),
  11619. DecorateInfo(
  11620. unittest.skip("Skipped!"),
  11621. "TestNormalizeOperators",
  11622. "test_normalize_operator_exhaustive",
  11623. ),
  11624. ),
  11625. ),
  11626. OpInfo('topk',
  11627. dtypes=all_types_and(torch.bfloat16),
  11628. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
  11629. supports_forward_ad=True,
  11630. supports_fwgrad_bwgrad=True,
  11631. assert_jit_shape_analysis=True,
  11632. sample_inputs_func=sample_inputs_topk),
  11633. # Multiple variants for batch_norm to test with and without cuDNN disabled
  11634. # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
  11635. OpInfo('nn.functional.batch_norm',
  11636. aten_name='batch_norm',
  11637. dtypes=floating_types_and(torch.bfloat16),
  11638. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11639. supports_out=False,
  11640. supports_forward_ad=True,
  11641. supports_fwgrad_bwgrad=True,
  11642. assert_jit_shape_analysis=True,
  11643. sample_inputs_func=sample_inputs_batch_norm,
  11644. skips=(
  11645. # see https://github.com/pytorch/pytorch/issues/71286
  11646. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
  11647. DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
  11648. device_type='cpu', dtypes=(torch.bfloat16,)),
  11649. # Trying to use forward AD with miopen_batch_norm that does not support it
  11650. # because it has not been implemented yet.
  11651. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad',
  11652. device_type="cuda", active_if=TEST_WITH_ROCM),
  11653. DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}),
  11654. 'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"),
  11655. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  11656. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  11657. )),
  11658. # This variant tests batch_norm with cuDNN disabled only on CUDA devices
  11659. OpInfo('nn.functional.batch_norm',
  11660. variant_test_name='without_cudnn',
  11661. aten_name='batch_norm',
  11662. dtypes=empty_types(),
  11663. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11664. supports_out=False,
  11665. supports_forward_ad=True,
  11666. supports_fwgrad_bwgrad=True,
  11667. decorators=[onlyCUDA, disablecuDNN],
  11668. skips=(
  11669. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  11670. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  11671. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}),
  11672. 'TestJit', 'test_variant_consistency_jit'),
  11673. ),
  11674. sample_inputs_func=sample_inputs_batch_norm),
  11675. OpInfo(
  11676. "nn.functional.binary_cross_entropy",
  11677. aten_backward_name='binary_cross_entropy_backward',
  11678. sample_inputs_func=sample_inputs_binary_cross_entropy,
  11679. dtypes=floating_types(),
  11680. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11681. supports_out=False,
  11682. gradcheck_fast_mode=False,
  11683. supports_autograd=True,
  11684. supports_forward_ad=True,
  11685. supports_fwgrad_bwgrad=True,
  11686. decorators=(
  11687. # RuntimeError: expected int at position 0, but got: Tensor
  11688. DecorateInfo(
  11689. unittest.skip("Skipped!"),
  11690. "TestCudaFuserOpInfo",
  11691. ),
  11692. # RuntimeError: expected int at position 0, but got: Tensor
  11693. DecorateInfo(
  11694. unittest.skip("Skipped!"),
  11695. "TestNNCOpInfo",
  11696. "test_nnc_correctness",
  11697. ),
  11698. DecorateInfo(
  11699. toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
  11700. "TestJit",
  11701. "test_variant_consistency_jit",
  11702. ),
  11703. ),
  11704. skips=(
  11705. # RuntimeError: expected int at position 0, but got: Tensor
  11706. DecorateInfo(
  11707. unittest.expectedFailure,
  11708. "TestJit",
  11709. "test_variant_consistency_jit",
  11710. ),
  11711. ),
  11712. ),
# We have to add two OpInfo entries for `igamma` and `igammac`. The first is the
# standard entry; the second is to run gradcheck tests on the second argument.
  11715. BinaryUfuncInfo('igamma',
  11716. dtypes=floating_types_and(torch.bfloat16, torch.float16),
  11717. aliases=('torch.special.gammainc',),
  11718. dtypesIfCUDA=floating_types(),
  11719. # TODO: FIXME
  11720. supports_rhs_python_scalar=False,
  11721. supports_autograd=False,
  11722. skips=(
  11723. # FIXME: incorrectly tries to pass a rhs scalar
  11724. DecorateInfo(unittest.expectedFailure, 'TestJit',
  11725. 'test_jit_alias_remapping'),
  11726. )),
# TODO: FIXME, ideally by implementing grad for both inputs
  11728. # BinaryUfuncInfo('igamma',
  11729. # variant_test_name='grad_other',
  11730. # # Since autograd formula is implemented only for other and
  11731. # # gradcheck test verifies the formula for input in SampleInput,
  11732. # # we permute the arguments.
  11733. # op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs),
  11734. # inplace_variant=None,
  11735. # method_variant=None,
  11736. # supports_rhs_python_scalar=False,
  11737. # rhs_make_tensor_kwargs=dict(requires_grad=False),
  11738. # dtypes=floating_types_and(torch.bfloat16, torch.float16),
  11739. # backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
  11740. # dtypesIfCUDA=floating_types(),
  11741. # backward_dtypesIfCUDA=floating_types(),
  11742. # supports_inplace_autograd=False,
  11743. # skips=(
  11744. # # Derivative wrt first tensor not implemented
  11745. # DecorateInfo(unittest.expectedFailure, "TestCommon",
  11746. # "test_floating_inputs_are_differentiable"),"),
  11747. # # test does not work with passing lambda for op
  11748. # # AssertionError: False is not true : Tensors failed to compare as equal!
  11749. # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# # test fails as we permute the arguments for the function variant
# # but not for the inplace or method variants.
  11752. # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
  11753. # # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float
  11754. # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
  11755. # )),
  11756. BinaryUfuncInfo('igammac',
  11757. dtypes=floating_types_and(torch.bfloat16, torch.float16),
  11758. aliases=('torch.special.gammaincc',),
  11759. dtypesIfCUDA=floating_types(),
  11760. supports_autograd=False,
  11761. supports_rhs_python_scalar=False,
  11762. skips=(
  11763. # FIXME: incorrectly tries to pass a rhs scalar
  11764. DecorateInfo(unittest.expectedFailure, 'TestJit',
  11765. 'test_jit_alias_remapping'),
  11766. )),
  11767. # TODO: FIXME, ideally by implementing grad for both inputs
  11768. # BinaryUfuncInfo('igammac',
  11769. # variant_test_name='grad_other',
  11770. # # Since autograd formula is implemented only for other and
  11771. # # gradcheck test verifies the formula for input in SampleInput,
  11772. # # we permute the arguments
  11773. # op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs),
  11774. # inplace_variant=None,
  11775. # method_variant=None,
  11776. # supports_rhs_python_scalar=False,
  11777. # rhs_make_tensor_kwargs=dict(requires_grad=False),
  11778. # dtypes=floating_types_and(torch.bfloat16, torch.float16),
  11779. # backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
  11780. # dtypesIfCUDA=floating_types(),
  11781. # backward_dtypesIfCUDA=floating_types(),
  11782. # supports_inplace_autograd=False,
  11783. # decorators=[
  11784. # # Derivative wrt first tensor not implemented
  11785. # DecorateInfo(unittest.expectedFailure, "TestCommon",
  11786. # "test_floating_inputs_are_differentiable"),
  11787. # ],
  11788. # skips=(
  11789. # # test does not work with passing lambda for op
  11790. # # AssertionError: False is not true : Tensors failed to compare as equal!
  11791. # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# # test fails as we permute the arguments for the function variant
# # but not for the inplace or method variants.
  11794. # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
  11795. # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float
  11796. # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
  11797. # )),
  11798. UnaryUfuncInfo('nn.functional.softshrink',
  11799. aten_name="softshrink",
  11800. aten_backward_name='softshrink_backward',
  11801. dtypes=floating_types_and(torch.bfloat16),
  11802. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11803. supports_forward_ad=True,
  11804. supports_fwgrad_bwgrad=True,
  11805. assert_autodiffed=False,
  11806. sample_inputs_func=sample_inputs_softshrink,
  11807. error_inputs_func=error_inputs_softshrink),
  11808. UnaryUfuncInfo('nn.functional.hardshrink',
  11809. aten_name="hardshrink",
  11810. aten_backward_name='hardshrink_backward',
  11811. dtypes=floating_types_and(torch.bfloat16,),
  11812. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11813. assert_autodiffed=True,
  11814. sample_inputs_func=sample_inputs_hardshrink,
  11815. supports_forward_ad=True,
  11816. supports_fwgrad_bwgrad=True,
  11817. autodiff_nonfusible_nodes=["aten::hardshrink"]),
  11818. UnaryUfuncInfo('nn.functional.hardtanh',
  11819. aten_name="hardtanh",
  11820. aten_backward_name='hardtanh_backward',
  11821. dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16),
  11822. backward_dtypes=all_types_and(torch.bfloat16),
  11823. dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.float16,
  11824. torch.bfloat16),
  11825. backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11826. assert_autodiffed=True,
  11827. sample_inputs_func=sample_inputs_hardtanh,
  11828. supports_out=False,
  11829. supports_forward_ad=True,
  11830. supports_fwgrad_bwgrad=True,
  11831. autodiff_nonfusible_nodes=["aten::hardtanh"]),
  11832. OpInfo('nn.functional.gelu',
  11833. aten_name="gelu",
  11834. aten_backward_name='gelu_backward',
  11835. ref=reference_gelu if TEST_SCIPY else None,
  11836. error_inputs_func=error_inputs_gelu,
  11837. supports_autograd=True,
  11838. assert_autodiffed=True,
  11839. sample_inputs_func=sample_inputs_gelu,
  11840. dtypes=floating_types_and(torch.bfloat16),
  11841. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  11842. supports_gradgrad=True,
  11843. supports_forward_ad=True,
  11844. supports_fwgrad_bwgrad=True,
  11845. autodiff_nonfusible_nodes=["aten::gelu"],
  11846. skips=(
  11847. # AssertionError: Tensor-likes are not close!
  11848. # May not replicate in CI
  11849. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
  11850. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  11851. )),
  11852. UnaryUfuncInfo('nn.functional.relu6',
  11853. aten_name="relu6",
  11854. dtypes=all_types_and(torch.bfloat16),
  11855. backward_dtypes=floating_types_and(torch.bfloat16),
  11856. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  11857. backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  11858. assert_autodiffed=True,
  11859. supports_out=False,
  11860. supports_forward_ad=True,
  11861. supports_fwgrad_bwgrad=True,
  11862. autodiff_nonfusible_nodes=["aten::relu6"]),
  11863. OpInfo('mm',
  11864. dtypes=all_types_and_complex_and(torch.bfloat16),
  11865. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  11866. assert_autodiffed=True,
  11867. supports_forward_ad=True,
  11868. supports_fwgrad_bwgrad=True,
  11869. sample_inputs_func=sample_inputs_mm,
  11870. skips=(
  11871. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  11872. DecorateInfo(
  11873. unittest.skip("Skipped!"),
  11874. 'TestSchemaCheckModeOpInfo',
  11875. 'test_schema_correctness',
  11876. dtypes=(torch.complex64, torch.complex128)),
  11877. )),
  11878. OpInfo('mode',
  11879. op=torch.mode,
  11880. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  11881. supports_forward_ad=True,
  11882. supports_fwgrad_bwgrad=True,
  11883. skips=(
  11884. # Resized a non-empty tensor but did not warn about it
  11885. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  11886. ),
  11887. sample_inputs_func=sample_inputs_mode,),
  11888. make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
  11889. domain=(1, None),
  11890. skips=skips_mvlgamma() + (
  11891. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  11892. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  11893. dtypes=(torch.float16, torch.int8)),
  11894. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  11895. dtypes=(torch.int8,)),
  11896. ),
  11897. sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
  11898. make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3',
  11899. domain=(2, None),
  11900. skips=skips_mvlgamma() + (
  11901. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  11902. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  11903. dtypes=(torch.float16, torch.int8)),
  11904. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  11905. dtypes=(torch.int8,)),
  11906. ),
  11907. sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
  11908. make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5',
  11909. domain=(3, None),
  11910. skips=skips_mvlgamma() + (
  11911. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  11912. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  11913. dtypes=(torch.float16, torch.int8)),
  11914. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  11915. dtypes=(torch.int8,)),
  11916. ),
  11917. sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
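# Illustrative sketch: the sample_kwargs lambdas above return two dicts, one of kwargs
# for the torch operator and one for the reference. 'p' is torch.mvlgamma's order
# argument; 'd' is assumed here to be the matching dimension parameter of the reference
# (e.g. scipy.special.multigammaln(a, d), if that is what the shared reference uses),
# so torch.mvlgamma(x, p=3) is compared against a multivariate log-gamma of dimension 3.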
  11918. BinaryUfuncInfo('ne',
  11919. ref=np.not_equal,
  11920. aliases=('not_equal',),
  11921. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  11922. always_returns_bool=True,
  11923. supports_autograd=False,
  11924. skips=(
  11925. )),
  11926. OpInfo('narrow',
  11927. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  11928. supports_out=False,
  11929. supports_forward_ad=True,
  11930. supports_fwgrad_bwgrad=True,
  11931. sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True),
  11932. reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True),
  11933. error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False),
  11934. skips=(
  11935. # Use of .item()
  11936. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
  11937. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
  11938. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
  11939. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  11940. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  11941. )),
  11942. OpInfo('narrow_copy',
  11943. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  11944. supports_out=True,
  11945. supports_forward_ad=False,
  11946. supports_fwgrad_bwgrad=False,
  11947. supports_autograd=False,
  11948. # https://github.com/pytorch/pytorch/issues/86931
  11949. sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False),
  11950. reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False),
  11951. error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False),
  11952. skips=(
  11953. # https://github.com/pytorch/pytorch/issues/84577
  11954. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  11955. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  11956. # Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels
  11957. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
  11958. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
  11959. )),
  11960. OpInfo('view_copy',
  11961. dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
  11962. ref=lambda x, newshape: np.reshape(x, newshape).copy(),
  11963. supports_out=True,
  11964. supports_forward_ad=True,
  11965. supports_fwgrad_bwgrad=True,
  11966. supports_autograd=True,
  11967. sample_inputs_func=sample_inputs_view_reshape,
  11968. error_inputs_func=error_inputs_view_reshape),
  11969. UnaryUfuncInfo('neg',
  11970. aliases=('negative', ),
  11971. ref=np.negative,
  11972. dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
  11973. error_inputs_func=error_inputs_neg,
  11974. supports_forward_ad=True,
  11975. supports_fwgrad_bwgrad=True,
  11976. supports_sparse=True,
  11977. supports_sparse_csr=True,
  11978. supports_sparse_csc=True,
  11979. supports_sparse_bsr=True,
  11980. supports_sparse_bsc=True,
  11981. assert_autodiffed=True),
  11982. OpInfo('dist',
  11983. op=torch.dist,
  11984. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  11985. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  11986. gradcheck_fast_mode=True,
  11987. supports_out=False,
  11988. supports_forward_ad=True,
  11989. # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
  11990. # Could not allocate memory to change Tensor SizesAndStrides!
  11991. check_batched_forward_grad=False,
  11992. supports_fwgrad_bwgrad=True,
  11993. sample_inputs_func=sample_inputs_dist),
  11994. OpInfo('outer',
  11995. op=torch.outer,
  11996. aliases=('ger', ),
  11997. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  11998. supports_forward_ad=True,
  11999. supports_fwgrad_bwgrad=True,
  12000. # See https://github.com/pytorch/pytorch/pull/78358
  12001. check_batched_forward_grad=False,
  12002. sample_inputs_func=sample_inputs_outer,),
  12003. OpInfo('ormqr',
  12004. op=torch.ormqr,
  12005. dtypes=floating_and_complex_types(),
  12006. # https://github.com/pytorch/pytorch/issues/80411
  12007. gradcheck_fast_mode=True,
  12008. supports_forward_ad=False,
  12009. supports_fwgrad_bwgrad=False,
  12010. sample_inputs_func=sample_inputs_ormqr,
  12011. error_inputs_func=error_inputs_ormqr,
  12012. decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
  12013. skips=(
  12014. # Strides are not the same!
  12015. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  12016. )),
  12017. OpInfo('permute',
  12018. ref=np.transpose,
  12019. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  12020. supports_out=False,
  12021. assert_autodiffed=True,
  12022. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  12023. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  12024. assert_jit_shape_analysis=True,
  12025. supports_forward_ad=True,
  12026. supports_fwgrad_bwgrad=True,
  12027. supports_varargs=True,
  12028. sample_inputs_func=sample_inputs_permute,
  12029. reference_inputs_func=reference_inputs_permute),
  12030. BinaryUfuncInfo('pow',
  12031. dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
  12032. dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
  12033. ref=np.power,
# Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
  12035. # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
  12036. # unsupported on CPU.
  12037. backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
  12038. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
  12039. # https://github.com/pytorch/pytorch/issues/80411
  12040. gradcheck_fast_mode=True,
  12041. supports_inplace_autograd=False,
  12042. supports_forward_ad=True,
  12043. supports_fwgrad_bwgrad=True,
  12044. assert_autodiffed=True,
  12045. supports_one_python_scalar=True,
# Integer types do not support negative exponents (see the sketch after this entry)
  12047. rhs_make_tensor_kwargs=dict(low=0),
  12048. # Raising negative real numbers to fractional powers is not supported
  12049. lhs_make_tensor_kwargs=dict(low=0),
  12050. decorators=(
  12051. DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}),
  12052. 'TestBinaryUfuncs', 'test_reference_numerics'),
  12053. DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
  12054. torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
  12055. 'TestBinaryUfuncs', 'test_scalar_support'),
  12056. ),
  12057. skips=(
# Skipping integer dtypes because they are raised to negative powers, causing an error
  12059. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
  12060. dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]),
  12061. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
  12062. dtypes=[torch.int16, torch.int32, torch.int64]),
  12063. # FIXME Complex values error with: Greatest absolute difference: nan at index
  12064. # Ref: https://github.com/pytorch/pytorch/issues/76853
  12065. # For `chalf`, reference computation in `numpy` is computed in `cfloat`.
# Output of `chalf` saturates to `inf` more quickly than the reference due to its small range,
# which leads to failure of this test.
  12068. DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick',
  12069. dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
  12070. DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive',
  12071. dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
  12072. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing',
  12073. dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
  12074. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing',
  12075. dtypes=(torch.complex32,)),
  12076. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig',
  12077. dtypes=(torch.complex32,)),
  12078. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics',
  12079. dtypes=(torch.complex32,)),
  12080. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
  12081. dtypes=(torch.complex32, torch.complex64, torch.complex128)),
  12082. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
  12083. dtypes=(torch.complex32, torch.complex64, torch.complex128)),
  12084. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
  12085. dtypes=(torch.complex32, torch.complex64, torch.complex128)),
  12086. )),
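# Illustrative sketch of why both make_tensor kwargs above use low=0 (relying on the
# module-level torch import):
#
#   torch.pow(torch.tensor(2), -2)      # integer base with a negative integer
#                                       # exponent raises a RuntimeError
#   torch.pow(torch.tensor(-2.0), 0.5)  # negative real base with a fractional
#                                       # exponent yields tensor(nan)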
  12087. BinaryUfuncInfo('float_power',
  12088. ref=np.float_power,
  12089. dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
  12090. promotes_int_to_float=True,
  12091. # https://github.com/pytorch/pytorch/issues/80411
  12092. gradcheck_fast_mode=True,
  12093. supports_forward_ad=True,
  12094. supports_fwgrad_bwgrad=True,
  12095. supports_one_python_scalar=True,
# Integer types do not support negative exponents
  12097. rhs_make_tensor_kwargs=dict(low=0),
  12098. # Raising negative real numbers to fractional powers is not supported
  12099. lhs_make_tensor_kwargs=dict(low=0),
  12100. decorators=(
  12101. DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
  12102. torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
  12103. 'TestBinaryUfuncs', 'test_scalar_support'),
  12104. ),
  12105. skips=(
  12106. # FIXME
  12107. # AssertionError: Object comparison failed: torch.float64 != torch.float32
  12108. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
  12109. # -3.43399e+38 is outside the range of representable values of type 'float'
  12110. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  12111. # Complex values error with: Greatest absolute difference: nan at index
  12112. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
  12113. dtypes=[torch.complex64, torch.complex128]),
  12114. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
  12115. dtypes=[torch.complex64, torch.complex128]),
  12116. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
  12117. dtypes=[torch.complex64, torch.complex128]),
  12118. )),
  12119. OpInfo('qr',
  12120. op=torch.qr,
  12121. dtypes=floating_and_complex_types(),
  12122. sample_inputs_func=sample_inputs_linalg_qr_geqrf,
  12123. supports_forward_ad=True,
  12124. supports_fwgrad_bwgrad=True,
  12125. # In-place ops
  12126. check_batched_gradgrad=False,
  12127. decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]),
  12128. UnaryUfuncInfo('rad2deg',
  12129. ref=np.degrees,
  12130. decorators=(precisionOverride({torch.bfloat16: 7e-1,
  12131. torch.float16: 7e-1}),),
  12132. dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
  12133. skips=(
  12134. # Reference: https://github.com/pytorch/pytorch/pull/51283#issuecomment-770614273
  12135. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12136. dtypes=[torch.bfloat16]),
  12137. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12138. dtypes=[torch.bfloat16]),
  12139. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12140. dtypes=[torch.bfloat16]),
  12141. ),
  12142. supports_forward_ad=True,
  12143. supports_fwgrad_bwgrad=True,
  12144. supports_sparse=True,
  12145. supports_sparse_csr=True,
  12146. supports_sparse_csc=True,
  12147. supports_sparse_bsr=True,
  12148. supports_sparse_bsc=True),
  12149. UnaryUfuncInfo('real',
  12150. ref=np.real,
  12151. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  12152. supports_out=False,
  12153. supports_forward_ad=True,
  12154. supports_fwgrad_bwgrad=True,
  12155. # See https://github.com/pytorch/pytorch/issues/66357
  12156. check_batched_forward_grad=False,
  12157. skips=(
  12158. # Skip since real and imag don't have out variants.
  12159. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
  12160. )),
  12161. OpInfo(
  12162. "roll",
  12163. ref=np.roll,
  12164. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  12165. error_inputs_func=error_inputs_roll,
  12166. supports_out=False,
  12167. supports_forward_ad=True,
  12168. supports_fwgrad_bwgrad=True,
  12169. sample_inputs_func=sample_inputs_roll,
  12170. decorators=(onlyNativeDeviceTypes,),
  12171. ),
  12172. OpInfo(
  12173. "rot90",
  12174. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
  12175. error_inputs_func=error_inputs_rot90,
  12176. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  12177. gradcheck_fast_mode=True,
  12178. supports_out=False,
  12179. supports_forward_ad=True,
  12180. supports_fwgrad_bwgrad=True,
  12181. sample_inputs_func=sample_inputs_rot90,
  12182. ),
# To test reference numerics against multiple values of the `decimals` argument,
# we make multiple OpInfo entries, each corresponding to a different value of `decimals`.
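# Illustrative sketch: each variant's sample_kwargs lambda returns two dicts, supplying
# kwargs to the torch operator and to the numpy reference respectively, so the same
# `decimals` value is used on both sides, e.g.:
#
#   torch.round(torch.tensor([1.23456]), decimals=3)   # ~1.235
#   np.round(np.array([1.23456]), 3)                    # ~1.235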
  12185. UnaryUfuncInfo('round',
  12186. ref=np.round,
  12187. aliases=('special.round',),
  12188. dtypes=all_types_and(torch.bfloat16),
  12189. dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
  12190. supports_forward_ad=True,
  12191. supports_fwgrad_bwgrad=True,
  12192. skips=(
  12193. DecorateInfo(unittest.expectedFailure,
  12194. 'TestNNCOpInfo',
  12195. 'test_nnc_correctness',
  12196. dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
  12197. DecorateInfo(unittest.expectedFailure,
  12198. 'TestCudaFuserOpInfo',
  12199. 'test_nvfuser_correctness',
  12200. dtypes=(torch.int32, torch.int64),
  12201. active_if=not TEST_WITH_ROCM),
  12202. DecorateInfo(unittest.skip("Skipped!"),
  12203. 'TestNNCOpInfo',
  12204. 'test_nnc_correctness',
  12205. dtypes=(torch.bfloat16,)),
  12206. ),
  12207. supports_sparse=True,
  12208. supports_sparse_csr=True,
  12209. supports_sparse_csc=True,
  12210. supports_sparse_bsr=True,
  12211. supports_sparse_bsc=True,
  12212. assert_autodiffed=True,
  12213. ),
  12214. UnaryUfuncInfo('round',
  12215. ref=np.round,
  12216. variant_test_name='decimals_0',
  12217. aliases=('special.round',),
  12218. dtypes=floating_types_and(torch.bfloat16),
  12219. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  12220. sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}),
  12221. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}),
  12222. supports_forward_ad=True,
  12223. supports_fwgrad_bwgrad=True,
  12224. assert_autodiffed=False,
  12225. supports_sparse_csr=False),
  12226. UnaryUfuncInfo('round',
  12227. ref=np.round,
  12228. variant_test_name='decimals_3',
  12229. aliases=('special.round',),
  12230. dtypes=floating_types_and(torch.bfloat16),
  12231. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  12232. sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}),
  12233. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}),
  12234. skips=(
# test_ops already tests this overload via the `decimals_0` OpInfo entry
  12236. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  12237. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  12238. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  12239. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  12240. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
  12241. DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
  12242. "TestUnaryUfuncs", "test_reference_numerics_extremal",
  12243. device_type="cuda"),
  12244. DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
  12245. "TestUnaryUfuncs", "test_reference_numerics_normal",
  12246. device_type="cuda"),
  12247. ),
  12248. supports_forward_ad=True,
  12249. supports_fwgrad_bwgrad=True,
  12250. assert_autodiffed=False,
  12251. supports_sparse_csr=False),
  12252. UnaryUfuncInfo('round',
  12253. ref=np.round,
  12254. variant_test_name='decimals_neg_3',
  12255. aliases=('special.round',),
  12256. dtypes=floating_types_and(torch.bfloat16),
  12257. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  12258. sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}),
  12259. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}),
  12260. skips=(
# test_ops already tests this overload via the `decimals_0` OpInfo entry
  12262. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  12263. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  12264. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  12265. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  12266. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
  12267. ),
  12268. supports_forward_ad=True,
  12269. supports_fwgrad_bwgrad=True,
  12270. assert_autodiffed=False,
  12271. supports_sparse_csr=False),
  12272. UnaryUfuncInfo('sin',
  12273. ref=np.sin,
  12274. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12275. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12276. assert_autodiffed=True,
  12277. handles_large_floats=False,
  12278. supports_sparse=True,
  12279. supports_sparse_csr=True,
  12280. supports_sparse_csc=True,
  12281. supports_sparse_bsr=True,
  12282. supports_sparse_bsc=True,
  12283. supports_forward_ad=True,
  12284. supports_fwgrad_bwgrad=True,
  12285. skips=(
  12286. # Fails on CUDA but passes on ROCm
  12287. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12288. dtypes=(torch.cdouble,), device_type='cuda'),
  12289. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12290. dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
  12291. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12292. dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
  12293. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12294. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12295. ),
  12296. decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
  12297. UnaryUfuncInfo('sinc',
  12298. ref=np_sinc_with_fp16_as_fp32,
  12299. aliases=('special.sinc',),
  12300. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12301. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  12302. handles_large_floats=False,
  12303. supports_forward_ad=True,
  12304. supports_fwgrad_bwgrad=True,
  12305. decorators=(precisionOverride({torch.bfloat16: 1e-2,
  12306. torch.float16: 1e-2}),),
  12307. skips=(
  12308. # Reference: https://github.com/pytorch/pytorch/issues/49133
  12309. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12310. dtypes=[torch.cfloat]),
  12311. )),
  12312. UnaryUfuncInfo('sinh',
  12313. ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
  12314. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12315. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12316. assert_autodiffed=True,
  12317. supports_forward_ad=True,
  12318. supports_fwgrad_bwgrad=True,
  12319. supports_sparse=True,
  12320. supports_sparse_csr=True,
  12321. supports_sparse_csc=True,
  12322. supports_sparse_bsr=True,
  12323. supports_sparse_bsc=True,
  12324. decorators=(precisionOverride({torch.float16: 1e-2}),),
  12325. skips=(
  12326. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12327. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12328. active_if=(IS_MACOS or IS_WINDOWS)),
  12329. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12330. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12331. active_if=(IS_MACOS or IS_WINDOWS)),
  12332. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12333. dtypes=(torch.cdouble,)),
  12334. # Reference: https://github.com/pytorch/pytorch/issues/48641
  12335. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12336. device_type='cpu', dtypes=[torch.int8]),
  12337. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12338. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12339. )),
  12340. UnaryUfuncInfo('sign',
  12341. ref=reference_sign,
  12342. dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
  12343. dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
  12344. supports_forward_ad=True,
  12345. supports_fwgrad_bwgrad=True,
  12346. supports_sparse=True,
  12347. supports_sparse_csr=True,
  12348. supports_sparse_csc=True,
  12349. supports_sparse_bsr=True,
  12350. supports_sparse_bsc=True,
  12351. skips=(
  12352. # Reference: https://github.com/pytorch/pytorch/issues/41245
  12353. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12354. dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
  12355. )),
  12356. UnaryUfuncInfo('sgn',
  12357. ref=reference_sgn,
  12358. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  12359. backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
  12360. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
  12361. supports_forward_ad=True,
  12362. supports_fwgrad_bwgrad=True,
  12363. supports_sparse=True,
  12364. supports_sparse_csr=True,
  12365. supports_sparse_csc=True,
  12366. supports_sparse_bsr=True,
  12367. supports_sparse_bsc=True,
  12368. skips=(
  12369. # Reference: https://github.com/pytorch/pytorch/issues/41245
  12370. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12371. dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
  12372. # Reference: https://github.com/pytorch/pytorch/issues/53958
# The comparison fails on NaN because `equal_nan` is True when
# comparing the CPU tensors.
  12375. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12376. device_type='cpu', dtypes=[torch.complex64, torch.complex128]),
  12377. # Reference: https://github.com/pytorch/pytorch/issues/48486
  12378. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12379. device_type='cpu', dtypes=[torch.complex64]),
  12380. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12381. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12382. )),
  12383. OpInfo('split',
  12384. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
  12385. sample_inputs_func=partial(sample_inputs_split, list_args=False),
  12386. supports_forward_ad=True,
  12387. supports_fwgrad_bwgrad=True,
  12388. supports_out=False,
  12389. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  12390. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  12391. assert_autodiffed=True),
  12392. OpInfo('split',
  12393. # Cannot declare this aten_name because of
  12394. # test_variant_consistency_jit_split_list_args_cpu_float32
  12395. decomp_aten_name='split_with_sizes',
  12396. variant_test_name='list_args',
  12397. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
  12398. sample_inputs_func=partial(sample_inputs_split, list_args=True),
  12399. supports_forward_ad=True,
  12400. supports_fwgrad_bwgrad=True,
  12401. supports_out=False),
# `unsafe_split` supports only `int` for the split_size argument
  12403. OpInfo('unsafe_split',
  12404. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
  12405. sample_inputs_func=partial(sample_inputs_split, list_args=False),
  12406. supports_forward_ad=True,
  12407. supports_fwgrad_bwgrad=True,
  12408. supports_out=False,
  12409. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  12410. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  12411. assert_autodiffed=True,
  12412. check_batched_forward_grad=False),
  12413. OpInfo('split_with_sizes',
  12414. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
  12415. sample_inputs_func=sample_inputs_split_with_sizes,
  12416. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  12417. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  12418. supports_out=False,
  12419. supports_forward_ad=True,
  12420. supports_fwgrad_bwgrad=True,
  12421. assert_autodiffed=True),
  12422. BinaryUfuncInfo('__radd__',
  12423. op=torch.Tensor.__radd__,
  12424. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
  12425. supports_out=False,
  12426. skips=(
  12427. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  12428. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12429. ),
  12430. assert_autodiffed=True,
  12431. supports_forward_ad=True,
  12432. supports_fwgrad_bwgrad=True,
  12433. autodiff_nonfusible_nodes=['aten::add'],),
  12434. BinaryUfuncInfo('__rdiv__',
  12435. op=torch.Tensor.__rdiv__,
  12436. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
  12437. promotes_int_to_float=True,
  12438. lhs_make_tensor_kwargs={'exclude_zero': True},
  12439. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  12440. gradcheck_fast_mode=True,
  12441. supports_out=False,
  12442. skips=(
  12443. # https://github.com/pytorch/pytorch/issues/76806
  12444. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
  12445. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12446. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12447. ),
  12448. supports_forward_ad=True,
  12449. supports_fwgrad_bwgrad=True,
  12450. assert_autodiffed=True,
  12451. autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],),
  12452. BinaryUfuncInfo('__rmul__',
  12453. op=torch.Tensor.__rmul__,
  12454. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
  12455. supports_out=False,
  12456. skips=(
  12457. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12458. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12459. ),
  12460. assert_autodiffed=True,
  12461. supports_forward_ad=True,
  12462. supports_fwgrad_bwgrad=True,
  12463. autodiff_nonfusible_nodes=['aten::mul'],),
  12464. BinaryUfuncInfo('__rand__',
  12465. op=torch.Tensor.__rand__,
  12466. dtypes=integral_types_and(torch.bool),
  12467. supports_out=False,
  12468. supports_autograd=False,
  12469. supports_forward_ad=True,
  12470. skips=(
  12471. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12472. )),
  12473. BinaryUfuncInfo('__ror__',
  12474. op=torch.Tensor.__ror__,
  12475. dtypes=integral_types_and(torch.bool),
  12476. supports_out=False,
  12477. supports_autograd=False,
  12478. supports_forward_ad=True,
  12479. skips=(
  12480. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12481. )),
  12482. BinaryUfuncInfo('__rxor__',
  12483. op=torch.Tensor.__rxor__,
  12484. dtypes=integral_types_and(torch.bool),
  12485. supports_out=False,
  12486. supports_autograd=False,
  12487. supports_forward_ad=True,
  12488. skips=(
  12489. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12490. )),
  12491. OpInfo('__rmatmul__',
  12492. op=torch.Tensor.__rmatmul__,
  12493. dtypes=all_types_and_complex_and(torch.bfloat16),
  12494. dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
  12495. *[torch.bfloat16]
  12496. if SM53OrLater or TEST_WITH_ROCM else []),
  12497. assert_autodiffed=True,
  12498. sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True),
  12499. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  12500. gradcheck_fast_mode=True,
  12501. supports_out=False,
  12502. supports_forward_ad=True,
  12503. supports_fwgrad_bwgrad=True,
  12504. check_batched_forward_grad=False,
  12505. decorators=(
# NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
  12507. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
  12508. DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
  12509. 'TestMathBits', 'test_conj_view'),
  12510. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
  12511. 'TestCommon', 'test_noncontiguous_samples'),
  12512. DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}),
  12513. "TestDecomp", "test_comprehensive", device_type="cuda",
  12514. active_if=TEST_WITH_ROCM),
  12515. ),
  12516. skips=(
  12517. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12518. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12519. # https://github.com/pytorch/pytorch/issues/67470
  12520. DecorateInfo(unittest.skip("67470!"),
  12521. 'TestCommon', 'test_noncontiguous_samples',
  12522. device_type='cpu', dtypes=(torch.long,)),
  12523. # Fails on XLA.
  12524. # AssertionError: False is not true : Tensors failed to compare as equal
  12525. DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
  12526. # https://github.com/pytorch/pytorch/issues/71774
  12527. DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
  12528. device_type='cpu', dtypes=(torch.long,)),
  12529. )),
  12530. BinaryUfuncInfo('__rmod__',
  12531. op=torch.Tensor.__rmod__,
  12532. dtypes=floating_types_and(torch.bfloat16, torch.half,),
  12533. dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half),
  12534. # https://github.com/pytorch/pytorch/issues/80411
  12535. gradcheck_fast_mode=True,
  12536. supports_out=False,
  12537. supports_forward_ad=True,
  12538. supports_fwgrad_bwgrad=True,
  12539. supports_one_python_scalar=True,
  12540. skips=(
  12541. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12542. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12543. ),
# Enable autograd once torch.remainder(Tensor, Tensor) supports
# autograd for the second argument.
  12546. # https://github.com/pytorch/pytorch/pull/58476/files#r637167630
  12547. # supports_autograd=False,
  12548. assert_autodiffed=True,
  12549. autodiff_nonfusible_nodes=['aten::remainder'],),
  12550. BinaryUfuncInfo('__rpow__',
  12551. op=torch.Tensor.__rpow__,
  12552. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
  12553. # Reference: https://github.com/pytorch/pytorch/issues/54774
  12554. # "log2" "_vml_cpu" not implemented for Half
  12555. backward_dtypes=all_types_and_complex_and(torch.bfloat16),
  12556. backward_dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half),
  12557. supports_out=False,
  12558. supports_forward_ad=True,
  12559. supports_fwgrad_bwgrad=True,
  12560. skips=(
  12561. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12562. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12563. # TODO: FIXME tolerance is too high
  12564. DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'),
  12565. DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'),
  12566. ),
  12567. assert_autodiffed=True,
  12568. autodiff_nonfusible_nodes=['aten::pow'],),
  12569. BinaryUfuncInfo('__rsub__',
  12570. op=torch.Tensor.__rsub__,
  12571. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
  12572. supports_forward_ad=True,
  12573. supports_fwgrad_bwgrad=True,
  12574. supports_out=False,
  12575. supports_one_python_scalar=True,
  12576. skips=(
  12577. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  12578. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
  12579. ),
  12580. assert_autodiffed=True,
  12581. autodiff_nonfusible_nodes=['aten::rsub'],),
  12582. BinaryUfuncInfo('rsub',
  12583. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
  12584. supports_forward_ad=True,
  12585. supports_fwgrad_bwgrad=True,
  12586. supports_out=False,
  12587. supports_inplace_autograd=False,
  12588. assert_autodiffed=None,
  12589. sample_inputs_func=sample_inputs_add_sub),
  12590. OpInfo('select',
  12591. aten_backward_name='select_backward',
  12592. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
  12593. sample_inputs_func=sample_inputs_select,
  12594. assert_jit_shape_analysis=True,
  12595. supports_forward_ad=True,
  12596. supports_fwgrad_bwgrad=True,
  12597. supports_out=False),
  12598. OpInfo('select_scatter',
  12599. dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
  12600. sample_inputs_func=sample_inputs_select_scatter,
  12601. supports_forward_ad=True,
  12602. supports_fwgrad_bwgrad=True,
  12603. supports_out=False),
  12604. OpInfo('slice',
  12605. op=torch.ops.aten.slice.Tensor,
  12606. dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
  12607. sample_inputs_func=sample_inputs_slice,
  12608. gradcheck_fast_mode=True,
  12609. supports_forward_ad=True,
  12610. supports_fwgrad_bwgrad=True,
  12611. supports_scripting=False,
  12612. supports_inplace_autograd=False,
  12613. supports_out=False),
  12614. OpInfo('slice_scatter',
  12615. dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
  12616. sample_inputs_func=sample_inputs_slice_scatter,
  12617. # https://github.com/pytorch/pytorch/issues/80411
  12618. gradcheck_fast_mode=True,
  12619. supports_forward_ad=True,
  12620. supports_fwgrad_bwgrad=True,
  12621. supports_out=False),
  12622. UnaryUfuncInfo('signbit',
  12623. ref=np.signbit,
  12624. dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
  12625. supports_sparse=True,
  12626. supports_sparse_csr=True,
  12627. supports_sparse_csc=True,
  12628. supports_sparse_bsr=True,
  12629. supports_sparse_bsc=True,
  12630. supports_autograd=False,),
  12631. UnaryUfuncInfo('tan',
  12632. ref=np.tan,
  12633. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12634. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12635. assert_autodiffed=True,
  12636. supports_forward_ad=True,
  12637. supports_fwgrad_bwgrad=True,
  12638. supports_sparse=True,
  12639. supports_sparse_csr=True,
  12640. supports_sparse_csc=True,
  12641. supports_sparse_bsr=True,
  12642. supports_sparse_bsc=True,
  12643. skips=(
  12644. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12645. device_type='cpu', dtypes=[torch.bfloat16]),
  12646. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12647. device_type='cpu', dtypes=[torch.bfloat16]),
  12648. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12649. device_type='cpu', dtypes=[torch.bfloat16]),
  12650. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12651. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12652. active_if=(IS_MACOS or IS_WINDOWS)),
  12653. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12654. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12655. active_if=(IS_MACOS or IS_WINDOWS)),
  12656. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12657. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12658. ),
# tan(pi/2 * odd_number) is nan (see the sketch after this entry)
  12660. reference_numerics_filter=NumericsFilter(
  12661. condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)),
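# Illustrative sketch: near odd multiples of pi/2 the reference comparison is not
# meaningful because tan overflows there (e.g. math.tan(math.pi / 2) is about 1.6e16
# rather than inf, since pi/2 is not exactly representable). The NumericsFilter above
# is assumed to replace such inputs with safe_val=math.pi before the reference
# numerics tests compare against np.tan.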
  12662. UnaryUfuncInfo('tanh',
  12663. ref=np.tanh,
  12664. aten_backward_name='tanh_backward',
  12665. aliases=('nn.functional.tanh',),
  12666. decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
  12667. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12668. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12669. assert_autodiffed=True,
  12670. assert_jit_shape_analysis=True,
  12671. supports_forward_ad=True,
  12672. supports_fwgrad_bwgrad=True,
  12673. supports_sparse=True,
  12674. supports_sparse_csr=True,
  12675. supports_sparse_csc=True,
  12676. supports_sparse_bsr=True,
  12677. supports_sparse_bsc=True,
  12678. skips=(
  12679. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12680. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12681. active_if=(IS_MACOS or IS_WINDOWS)),
  12682. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12683. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
  12684. active_if=(IS_MACOS or IS_WINDOWS)),
  12685. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12686. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12687. ),
  12688. # tan(j * pi/2 * odd_number) is nan
  12689. reference_numerics_filter=NumericsFilter(
  12690. condition=lambda x: (close_to_int(x / (math.pi * 0.5j))
  12691. if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
  12692. safe_val=0)),
  12693. OpInfo('tensor_split',
  12694. ref=np.array_split,
  12695. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  12696. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  12697. supports_out=False,
  12698. supports_forward_ad=True,
  12699. supports_fwgrad_bwgrad=True,
  12700. skips=(
  12701. # Pre-existing condition; Needs to be fixed
  12702. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
  12703. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
  12704. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
  12705. ),
  12706. sample_inputs_func=sample_inputs_tensor_split,),
  12707. OpInfo('hsplit',
  12708. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
  12709. supports_out=False,
  12710. supports_forward_ad=True,
  12711. supports_fwgrad_bwgrad=True,
  12712. # See https://github.com/pytorch/pytorch/pull/78358
  12713. check_batched_forward_grad=False,
  12714. sample_inputs_func=sample_inputs_hsplit,
  12715. error_inputs_func=error_inputs_hsplit,),
  12716. OpInfo('vsplit',
  12717. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
  12718. supports_out=False,
  12719. supports_forward_ad=True,
  12720. supports_fwgrad_bwgrad=True,
  12721. # See https://github.com/pytorch/pytorch/pull/78358
  12722. check_batched_forward_grad=False,
  12723. sample_inputs_func=sample_inputs_vsplit,
  12724. error_inputs_func=error_inputs_vsplit,),
  12725. OpInfo('dsplit',
  12726. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
  12727. supports_out=False,
  12728. supports_forward_ad=True,
  12729. supports_fwgrad_bwgrad=True,
  12730. # See https://github.com/pytorch/pytorch/pull/78358
  12731. check_batched_forward_grad=False,
  12732. sample_inputs_func=sample_inputs_dsplit,
  12733. error_inputs_func=error_inputs_dsplit,),
  12734. OpInfo('triangular_solve',
  12735. op=torch.triangular_solve,
  12736. dtypes=floating_and_complex_types(),
  12737. sample_inputs_func=sample_inputs_legacy_solve,
  12738. check_batched_gradgrad=False,
  12739. supports_forward_ad=True,
  12740. supports_fwgrad_bwgrad=True,
  12741. gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
  12742. decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
  12743. skips=(
  12744. # AssertionError: Scalars are not equal!
  12745. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  12746. # Gradcheck fails
  12747. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
  12748. dtypes=floating_and_complex_types()),
  12749. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
  12750. device_type='mps', dtypes=[torch.float32]),
  12751. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
  12752. device_type='mps', dtypes=[torch.float32]),
  12753. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
  12754. device_type='mps', dtypes=[torch.float32]),
  12755. )),
  12756. UnaryUfuncInfo('trunc',
  12757. aliases=('fix', ),
  12758. ref=np.trunc,
  12759. dtypes=all_types_and(torch.bfloat16),
  12760. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  12761. supports_forward_ad=True,
  12762. supports_fwgrad_bwgrad=True,
  12763. supports_sparse=True,
  12764. skips=(
  12765. DecorateInfo(unittest.expectedFailure,
  12766. 'TestNNCOpInfo',
  12767. 'test_nnc_correctness',
  12768. dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
  12769. DecorateInfo(unittest.expectedFailure,
  12770. 'TestCudaFuserOpInfo',
  12771. 'test_nvfuser_correctness',
  12772. dtypes=(torch.int32, torch.int64),
  12773. active_if=not TEST_WITH_ROCM),
  12774. ),
  12775. supports_sparse_csr=True,
  12776. supports_sparse_csc=True,
  12777. supports_sparse_bsr=True,
  12778. supports_sparse_bsc=True,
  12779. assert_autodiffed=True),
  12780. UnaryUfuncInfo('exp2',
  12781. aliases=('special.exp2', ),
  12782. ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
  12783. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  12784. supports_forward_ad=True,
  12785. supports_fwgrad_bwgrad=True,
  12786. skips=(
  12787. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12788. dtypes=[torch.cdouble]),
  12789. # Reference: https://github.com/pytorch/pytorch/issues/48010
  12790. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12791. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  12792. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12793. device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
  12794. )),
  12795. UnaryUfuncInfo('expm1',
  12796. aliases=('special.expm1', ),
  12797. ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
  12798. dtypes=all_types_and(torch.bool, torch.bfloat16),
  12799. dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
  12800. supports_forward_ad=True,
  12801. supports_fwgrad_bwgrad=True,
  12802. supports_sparse=True,
  12803. supports_sparse_csr=True,
  12804. supports_sparse_csc=True,
  12805. supports_sparse_bsr=True,
  12806. supports_sparse_bsc=True,
  12807. assert_autodiffed=True,
  12808. skips=(
  12809. # Reference: https://github.com/pytorch/pytorch/pull/48926#issuecomment-739734774
  12810. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12811. device_type='cpu', dtypes=[torch.bfloat16]),
  12812. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12813. device_type='cpu', dtypes=[torch.bfloat16]),
  12814. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12815. device_type='cpu', dtypes=[torch.bfloat16]),
  12816. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12817. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12818. )),
  12819. UnaryUfuncInfo('nan_to_num',
  12820. ref=np.nan_to_num,
  12821. dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16),
  12822. dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16),
  12823. supports_forward_ad=True,
  12824. supports_fwgrad_bwgrad=True,
  12825. supports_sparse=True,
  12826. skips=(
  12827. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12828. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12829. ),
12830. # Passing numpy_kwargs via sample_kwargs: NumPy doesn't support BFloat16, so the reference
12831. # comparison is done in float instead (see the commented sketch after this entry).
  12832. # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556
  12833. sample_kwargs=lambda device, dtype, input: ({},
  12834. {'posinf': torch.finfo(torch.bfloat16).max,
  12835. 'neginf': torch.finfo(torch.bfloat16).min})
  12836. if dtype is torch.bfloat16 else ({}, {})),
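# A commented sketch of the sample_kwargs contract relied on above: the callable returns a pair of
# kwarg dicts, the first for the torch op and the second for the reference, so here only the
# float-based NumPy reference is pinned to bfloat16's finite range (matching what torch's bfloat16
# defaults already do). Roughly:
# >>> t = torch.tensor([float('nan'), float('inf')], dtype=torch.bfloat16)
# >>> torch.nan_to_num(t)   # posinf defaults to bfloat16's largest finite value
# tensor([0.0000e+00, 3.3895e+38], dtype=torch.bfloat16)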
  12837. UnaryUfuncInfo('reciprocal',
  12838. ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal),
  12839. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  12840. assert_autodiffed=True,
  12841. supports_forward_ad=True,
  12842. supports_fwgrad_bwgrad=True,
  12843. skips=(
  12844. # Reference: https://github.com/pytorch/pytorch/issues/45690
  12845. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12846. dtypes=[torch.cfloat, torch.cdouble]),
  12847. # Reference: https://github.com/pytorch/pytorch/pull/49102#issuecomment-744604601
  12848. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12849. dtypes=[torch.bfloat16]),
  12850. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12851. dtypes=[torch.bfloat16]),
  12852. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12853. dtypes=[torch.bfloat16]),
  12854. )),
  12855. UnaryUfuncInfo('rsqrt',
  12856. ref=lambda x: np.reciprocal(np.sqrt(x)),
  12857. domain=(0, None),
  12858. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12859. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12860. decorators=(precisionOverride({torch.half: 5e-2}),),
  12861. assert_autodiffed=True,
  12862. supports_forward_ad=True,
  12863. supports_fwgrad_bwgrad=True,
  12864. skips=(
  12865. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12866. dtypes=(torch.cfloat, torch.cdouble)),
  12867. # AssertionError: Tensor-likes are not close!
  12868. # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
  12869. # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
  12870. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12871. dtypes=(torch.chalf,)),
  12872. )),
  12873. UnaryUfuncInfo('sqrt',
  12874. ref=np.sqrt,
  12875. supports_sparse=True,
  12876. domain=(0, None),
  12877. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  12878. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  12879. assert_autodiffed=True,
  12880. supports_forward_ad=True,
  12881. supports_sparse_csr=True,
  12882. supports_sparse_csc=True,
  12883. supports_sparse_bsr=True,
  12884. supports_sparse_bsc=True,
  12885. supports_fwgrad_bwgrad=True,
  12886. decorators=(
  12887. precisionOverride({torch.bfloat16: 7e-2}),
  12888. DecorateInfo(
  12889. toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
  12890. 'TestUnaryUfuncs', 'test_reference_numerics_large'),
  12891. ),
  12892. skips=(
  12893. # Reference: https://github.com/pytorch/pytorch/issues/47358
  12894. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12895. device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
  12896. active_if=IS_MACOS),
  12897. # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
  12898. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12899. dtypes=(torch.bfloat16,)),
  12900. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  12901. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  12902. )),
  12903. UnaryUfuncInfo('square',
  12904. ref=np.square,
  12905. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  12906. decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),),
  12907. supports_forward_ad=True,
  12908. supports_fwgrad_bwgrad=True,
  12909. skips=(
  12910. # Reference: https://github.com/pytorch/pytorch/issues/52549
  12911. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12912. dtypes=[torch.cfloat, torch.cdouble]),
  12913. # >>> t = torch.tensor(complex(-0.01, float("inf")))
  12914. # >>> np.square(t.numpy())
  12915. # (-inf-infj)
  12916. # >>> t.square()
  12917. # tensor(-inf-infj)
  12918. # >>> t.cuda().square()
  12919. # tensor(inf+nanj, device='cuda:0')
  12920. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  12921. device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
  12922. # Reference: https://github.com/pytorch/pytorch/pull/52551#issuecomment-782596181
  12923. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  12924. dtypes=[torch.bfloat16]),
  12925. ),),
  12926. OpInfo('lerp',
  12927. dtypes=floating_and_complex_types_and(torch.bfloat16),
  12928. dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16),
  12929. dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
  12930. sample_inputs_func=sample_inputs_lerp,
  12931. supports_forward_ad=True,
  12932. supports_fwgrad_bwgrad=True,
  12933. assert_autodiffed=True),
  12934. UnaryUfuncInfo('angle',
  12935. ref=np.angle,
  12936. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  12937. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool),
  12938. decorators=(precisionOverride({torch.float16: 1e-2,
  12939. torch.bfloat16: 1e-2}),),
  12940. backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
  12941. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf),
  12942. supports_forward_ad=True,
  12943. supports_fwgrad_bwgrad=True,
  12944. supports_sparse_csr=True,
  12945. supports_sparse_csc=True,
  12946. supports_sparse_bsr=True,
  12947. supports_sparse_bsc=True,
  12948. supports_complex_to_float=True,
  12949. skips=(
  12950. # Ref: https://github.com/pytorch/pytorch/issues/78413
  12951. DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small',
  12952. dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),),
  12953. )),
  12954. UnaryUfuncInfo('isfinite',
  12955. ref=np.isfinite,
  12956. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  12957. supports_out=False,
  12958. supports_autograd=False),
  12959. UnaryUfuncInfo('isinf',
  12960. ref=np.isinf,
  12961. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  12962. supports_out=False,
  12963. supports_sparse=True,
  12964. supports_sparse_csr=True,
  12965. supports_sparse_csc=True,
  12966. supports_sparse_bsr=True,
  12967. supports_sparse_bsc=True,
  12968. supports_autograd=False),
  12969. UnaryUfuncInfo('isposinf',
  12970. ref=np.isposinf,
  12971. dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
  12972. supports_sparse=True,
  12973. supports_sparse_csr=True,
  12974. supports_sparse_csc=True,
  12975. supports_sparse_bsr=True,
  12976. supports_sparse_bsc=True,
  12977. supports_autograd=False),
  12978. UnaryUfuncInfo('isneginf',
  12979. ref=np.isneginf,
  12980. dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
  12981. supports_sparse=True,
  12982. supports_sparse_csr=True,
  12983. supports_sparse_csc=True,
  12984. supports_sparse_bsr=True,
  12985. supports_sparse_bsc=True,
  12986. supports_autograd=False),
  12987. UnaryUfuncInfo('isreal',
  12988. ref=np.isreal,
  12989. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  12990. supports_out=False,
  12991. supports_autograd=False),
  12992. UnaryUfuncInfo('isnan',
  12993. ref=np.isnan,
  12994. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  12995. supports_out=False,
  12996. supports_sparse=True,
  12997. supports_sparse_csr=True,
  12998. supports_sparse_csc=True,
  12999. supports_sparse_bsr=True,
  13000. supports_sparse_bsc=True,
  13001. supports_autograd=False),
  13002. OpInfo('einsum',
  13003. # we need this lambda because SampleInput expects tensor input as the first argument
  13004. # TODO(@heitorschueroff) update SampleInput to handle such cases
  13005. op=lambda tensors, equation: torch.einsum(equation, tensors),
  13006. dtypes=all_types_and_complex_and(torch.bfloat16),
  13007. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  13008. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
  13009. if (SM60OrLater or
  13010. TEST_WITH_ROCM) else []),
  13011. supports_out=False,
  13012. supports_forward_ad=True,
  13013. supports_fwgrad_bwgrad=True,
  13014. check_batched_forward_grad=False,
  13015. # See https://github.com/pytorch/pytorch/issues/66357
  13016. sample_inputs_func=sample_inputs_einsum,
  13017. skips=(
  13018. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  13019. # test does not work with passing lambda for op
  13020. # there's a test `test_einsum` in `test_jit.py` to handle this case
  13021. # AssertionError: JIT Test does not execute any logic
  13022. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13023. )),
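# A commented sketch of the argument-order mismatch worked around by the lambda in the 'einsum'
# entry above (tensor names here are purely illustrative):
# >>> a, b = torch.randn(2, 3), torch.randn(3, 4)
# >>> torch.einsum('ij,jk->ik', a, b).shape   # the public API takes the equation first
# torch.Size([2, 4])
# The OpInfo machinery passes the tensors first instead, so the lambda swaps the order before
# delegating to torch.einsum(equation, tensors).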
  13024. OpInfo('svd',
  13025. op=torch.svd,
  13026. dtypes=floating_and_complex_types(),
  13027. sample_inputs_func=sample_inputs_svd,
  13028. # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
  13029. gradcheck_fast_mode=True,
  13030. supports_forward_ad=True,
  13031. supports_fwgrad_bwgrad=True,
  13032. check_batched_forward_grad=False,
  13033. # We're using at::allclose, which does not have a batching rule
  13034. check_batched_grad=False,
  13035. check_batched_gradgrad=False,
  13036. decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
  13037. skips=(
  13038. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  13039. DecorateInfo(
  13040. unittest.skip("Skipped!"),
  13041. 'TestSchemaCheckModeOpInfo',
  13042. 'test_schema_correctness',
  13043. dtypes=(torch.complex64, torch.complex128)),
  13044. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
  13045. device_type='mps', dtypes=[torch.float32]),
  13046. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
  13047. device_type='mps', dtypes=[torch.float32]),
  13048. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
  13049. device_type='mps', dtypes=[torch.float32]),
  13050. )),
  13051. OpInfo('svd_lowrank',
  13052. op=lambda *args, **kwargs: wrapper_set_seed(
  13053. lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs),
  13054. *args, **kwargs
  13055. ),
  13056. dtypes=floating_types(),
  13057. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13058. gradcheck_fast_mode=True,
  13059. supports_out=False,
  13060. check_batched_grad=False,
  13061. check_batched_gradgrad=False,
  13062. check_batched_forward_grad=False,
  13063. supports_fwgrad_bwgrad=True,
  13064. supports_forward_ad=True,
  13065. sample_inputs_func=sample_inputs_svd_lowrank,
  13066. decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
  13067. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
  13068. 'TestCommon', 'test_noncontiguous_samples',
  13069. device_type='cuda')],
  13070. skips=(
  13071. # test does not work with passing lambda for op
  13072. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13073. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13074. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13075. DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'),
  13076. )),
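# A commented sketch of the wrapper_set_seed pattern used by the randomized entries in this file
# (svd_lowrank above, pca_lowrank and randn below): the helper reseeds the RNG before invoking the
# op so repeated invocations inside a test see the same random stream. A hypothetical standalone
# equivalent (call_with_fixed_seed is not a real helper) would look roughly like:
# >>> def call_with_fixed_seed(op, *args, **kwargs):
# ...     torch.manual_seed(42)
# ...     return op(*args, **kwargs)
# Cross-device comparisons are still disabled via the 'output is non-deterministic' skip above.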
  13077. OpInfo('pca_lowrank',
  13078. op=lambda *args, **kwargs: wrapper_set_seed(
  13079. lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs),
  13080. *args, **kwargs
  13081. ),
  13082. dtypes=floating_types(),
  13083. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13084. gradcheck_fast_mode=True,
  13085. supports_out=False,
  13086. check_batched_forward_grad=False,
  13087. check_batched_grad=False,
  13088. check_batched_gradgrad=False,
  13089. supports_forward_ad=True,
  13090. supports_fwgrad_bwgrad=True,
  13091. sample_inputs_func=sample_inputs_pca_lowrank,
  13092. decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
  13093. DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
  13094. 'TestCommon', 'test_noncontiguous_samples',
  13095. device_type='cuda')],
  13096. skips=(
  13097. # test does not work with passing lambda for op
  13098. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13099. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13100. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13101. )),
  13102. BinaryUfuncInfo('polar',
  13103. dtypes=floating_types(),
13104. # torch.polar(abs, angle) is undefined for negative 'abs' values, hence low=0 below (see the commented example after this entry)
  13105. supports_forward_ad=True,
  13106. lhs_make_tensor_kwargs=dict(low=0),
  13107. supports_rhs_python_scalar=False,
  13108. skips=(
  13109. # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
  13110. DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'),
  13111. # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0
  13112. # Numerical:
  13113. # tensor([[0.]], dtype=torch.float64)
  13114. # Analytical:
  13115. # tensor([[-0.0047]], dtype=torch.float64, grad_fn=<CopySlices>)
  13116. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
  13117. )),
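# A commented example of the constraint noted inside the 'polar' entry above: the first argument
# is a magnitude, which is why lhs samples are drawn with low=0.
# >>> torch.polar(torch.tensor([2.0]), torch.tensor([0.0]))   # abs=2, angle=0
# tensor([2.+0.j])
# Negative magnitudes fall outside the documented domain, so no samples are generated for them.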
  13118. # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries.
  13119. # To test reference numerics against multiple values of argument `n`,
  13120. # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4).
  13121. # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
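# For reference, the entries below pin different values of `n` through sample_kwargs (the same
# dict is forwarded to torch.polygamma and to the scipy-backed reference), e.g.:
# >>> x = torch.tensor([2.0])
# >>> torch.polygamma(0, x)   # digamma
# tensor([0.4228])
# >>> torch.polygamma(1, x)   # trigamma
# tensor([0.6449])
# The NumericsFilter attached to the n >= 1 variants keeps reference inputs away from the
# singularities at x <= 0.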
  13122. UnaryUfuncInfo('polygamma',
  13123. op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
  13124. variant_test_name='polygamma_n_0',
  13125. ref=reference_polygamma if TEST_SCIPY else None,
  13126. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13127. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  13128. supports_forward_ad=True,
  13129. supports_fwgrad_bwgrad=True,
  13130. sample_inputs_func=sample_inputs_polygamma,
  13131. skips=(
  13132. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13133. ),
  13134. sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
  13135. UnaryUfuncInfo('polygamma',
  13136. op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
  13137. variant_test_name='polygamma_n_1',
  13138. ref=reference_polygamma if TEST_SCIPY else None,
  13139. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13140. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  13141. supports_forward_ad=True,
  13142. supports_fwgrad_bwgrad=True,
  13143. sample_inputs_func=sample_inputs_polygamma,
  13144. skips=(
  13145. # Redundant tests
  13146. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  13147. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  13148. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  13149. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
  13150. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  13151. # Mismatch: https://github.com/pytorch/pytorch/issues/55357
  13152. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  13153. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
  13154. ),
  13155. sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1}),
  13156. # polygamma functions have multiple singularities at x <= 0
  13157. reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
  13158. UnaryUfuncInfo('polygamma',
  13159. op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
  13160. variant_test_name='polygamma_n_2',
  13161. ref=reference_polygamma if TEST_SCIPY else None,
  13162. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13163. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  13164. supports_forward_ad=True,
  13165. supports_fwgrad_bwgrad=True,
  13166. sample_inputs_func=sample_inputs_polygamma,
  13167. skips=(
  13168. # Redundant tests
  13169. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  13170. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  13171. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  13172. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
  13173. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  13174. # Mismatch: https://github.com/pytorch/pytorch/issues/55357
  13175. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
  13176. sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2}),
  13177. # polygamma functions have multiple singularities at x <= 0
  13178. reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
  13179. UnaryUfuncInfo('polygamma',
  13180. op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
  13181. variant_test_name='polygamma_n_3',
  13182. ref=reference_polygamma if TEST_SCIPY else None,
  13183. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13184. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  13185. supports_forward_ad=True,
  13186. supports_fwgrad_bwgrad=True,
  13187. sample_inputs_func=sample_inputs_polygamma,
  13188. skips=(
  13189. # Redundant tests
  13190. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  13191. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  13192. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  13193. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
  13194. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  13195. # Mismatch: https://github.com/pytorch/pytorch/issues/55357
  13196. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
  13197. sample_kwargs=lambda device, dtype, input: ({'n': 3}, {'n': 3}),
  13198. # polygamma functions have multiple singularities at x <= 0
  13199. reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
  13200. UnaryUfuncInfo('polygamma',
  13201. op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
  13202. variant_test_name='polygamma_n_4',
  13203. ref=reference_polygamma if TEST_SCIPY else None,
  13204. decorators=(precisionOverride({torch.float16: 5e-4, torch.float32: 5e-4}),),
  13205. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13206. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  13207. supports_forward_ad=True,
  13208. supports_fwgrad_bwgrad=True,
  13209. sample_inputs_func=sample_inputs_polygamma,
  13210. skips=(
  13211. # Redundant tests
  13212. DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
  13213. DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
  13214. DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
  13215. DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
  13216. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
  13217. # Mismatch: https://github.com/pytorch/pytorch/issues/55357
  13218. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
  13219. sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4}),
  13220. # polygamma functions have multiple singularities at x <= 0
  13221. reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
  13222. OpInfo('ravel',
  13223. ref=np.ravel,
  13224. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13225. supports_out=False,
  13226. supports_forward_ad=True,
  13227. supports_fwgrad_bwgrad=True,
  13228. # See https://github.com/pytorch/pytorch/pull/78358
  13229. check_batched_forward_grad=False,
  13230. sample_inputs_func=sample_inputs_ravel,
  13231. ),
  13232. OpInfo('reshape',
  13233. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13234. sample_inputs_func=sample_inputs_view_reshape,
  13235. reference_inputs_func=reference_inputs_view_reshape,
  13236. error_inputs_func=error_inputs_view_reshape,
  13237. supports_out=False,
  13238. supports_forward_ad=True,
  13239. supports_fwgrad_bwgrad=True,
  13240. ),
  13241. OpInfo('reshape_as',
  13242. op=lambda x, other: x.reshape_as(other),
  13243. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13244. sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
  13245. reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
  13246. error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
  13247. supports_out=False,
  13248. supports_forward_ad=True,
  13249. supports_fwgrad_bwgrad=True,
  13250. skips=(
  13251. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13252. )),
  13253. OpInfo('view',
  13254. op=lambda x, shape: x.view(shape),
  13255. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  13256. supports_out=False,
  13257. supports_forward_ad=True,
  13258. supports_fwgrad_bwgrad=True,
  13259. assert_jit_shape_analysis=True,
  13260. sample_inputs_func=sample_inputs_view_reshape,
  13261. reference_inputs_func=reference_inputs_view_reshape,
  13262. error_inputs_func=error_inputs_view_reshape,
  13263. skips=(
  13264. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13265. )),
  13266. OpInfo('view_as',
  13267. op=lambda x, other: x.view_as(other),
  13268. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  13269. supports_out=False,
  13270. supports_forward_ad=True,
  13271. supports_fwgrad_bwgrad=True,
  13272. sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
  13273. reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
  13274. error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
  13275. skips=(
  13276. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13277. )),
  13278. OpInfo('atleast_1d',
  13279. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  13280. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13281. gradcheck_fast_mode=True,
  13282. supports_out=False,
  13283. supports_forward_ad=True,
  13284. supports_fwgrad_bwgrad=True,
  13285. # See https://github.com/pytorch/pytorch/pull/78358
  13286. check_batched_forward_grad=False,
  13287. sample_inputs_func=sample_inputs_atleast1d2d3d,
  13288. skips=(
  13289. # JIT does not support variadic tensors.
  13290. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  13291. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
  13292. # please report a bug to PyTorch.
  13293. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13294. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
  13295. ),
  13296. ),
  13297. OpInfo('atleast_2d',
  13298. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  13299. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13300. gradcheck_fast_mode=True,
  13301. supports_out=False,
  13302. supports_forward_ad=True,
  13303. supports_fwgrad_bwgrad=True,
  13304. # See https://github.com/pytorch/pytorch/pull/78358
  13305. check_batched_forward_grad=False,
  13306. skips=(
  13307. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13308. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
  13309. ),
  13310. sample_inputs_func=sample_inputs_atleast1d2d3d,
  13311. ),
  13312. OpInfo('atleast_3d',
  13313. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  13314. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13315. gradcheck_fast_mode=True,
  13316. supports_out=False,
  13317. supports_forward_ad=True,
  13318. supports_fwgrad_bwgrad=True,
  13319. # See https://github.com/pytorch/pytorch/pull/78358
  13320. check_batched_forward_grad=False,
  13321. skips=(
  13322. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13323. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
  13324. ),
  13325. sample_inputs_func=sample_inputs_atleast1d2d3d,
  13326. ),
  13327. OpInfo('flatten',
  13328. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13329. ref=reference_flatten,
  13330. supports_out=False,
  13331. supports_forward_ad=True,
  13332. supports_fwgrad_bwgrad=True,
  13333. # See https://github.com/pytorch/pytorch/pull/78358
  13334. check_batched_forward_grad=False,
  13335. sample_inputs_func=sample_inputs_flatten,
  13336. reference_inputs_func=reference_inputs_flatten,
  13337. ),
  13338. OpInfo('unflatten',
  13339. op=torch.unflatten,
  13340. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13341. supports_out=False,
  13342. supports_forward_ad=True,
  13343. supports_fwgrad_bwgrad=True,
  13344. sample_inputs_func=sample_inputs_unflatten,
  13345. ),
  13346. OpInfo('column_stack',
  13347. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13348. supports_forward_ad=True,
  13349. supports_fwgrad_bwgrad=True,
  13350. # See https://github.com/pytorch/pytorch/pull/78358
  13351. check_batched_forward_grad=False,
  13352. skips=(
  13353. # lambda impl
  13354. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),),
  13355. sample_inputs_func=sample_inputs_column_stack,),
  13356. OpInfo('pinverse',
  13357. op=torch.pinverse,
  13358. dtypes=floating_and_complex_types(),
  13359. check_batched_grad=False,
  13360. check_batched_gradgrad=False,
  13361. supports_forward_ad=True,
  13362. supports_fwgrad_bwgrad=True,
  13363. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  13364. supports_out=False,
  13365. sample_inputs_func=sample_inputs_linalg_invertible,
  13366. decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
  13367. skips=(
  13368. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
  13369. device_type='mps', dtypes=[torch.float32]),
  13370. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
  13371. device_type='mps', dtypes=[torch.float32]),
  13372. )),
  13373. OpInfo('gather',
  13374. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  13375. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  13376. sample_inputs_func=sample_inputs_gather,
  13377. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  13378. assert_autodiffed=True,
  13379. supports_forward_ad=True,
  13380. supports_fwgrad_bwgrad=True,
  13381. error_inputs_func=error_inputs_gather,
  13382. ),
  13383. OpInfo('index_fill',
  13384. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
  13385. supports_out=False,
  13386. supports_forward_ad=True,
  13387. supports_fwgrad_bwgrad=True,
  13388. # https://github.com/pytorch/pytorch/issues/66357
  13389. check_batched_forward_grad=False,
  13390. sample_inputs_func=sample_inputs_index,
  13391. reference_inputs_func=partial(sample_inputs_index, reference=True)),
  13392. OpInfo('index_copy',
  13393. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
  13394. supports_out=True,
  13395. supports_forward_ad=True,
  13396. supports_fwgrad_bwgrad=True,
  13397. # https://github.com/pytorch/pytorch/issues/66357
  13398. check_batched_forward_grad=False,
  13399. sample_inputs_func=sample_inputs_index,
  13400. reference_inputs_func=partial(sample_inputs_index, reference=True),
  13401. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
  13402. OpInfo('index_select',
  13403. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13404. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
  13405. sample_inputs_func=sample_inputs_index,
  13406. reference_inputs_func=partial(sample_inputs_index, reference=True),
  13407. error_inputs_func=error_inputs_index_select,
  13408. assert_autodiffed=True,
  13409. supports_forward_ad=True,
  13410. supports_fwgrad_bwgrad=True,
  13411. assert_jit_shape_analysis=True,
  13412. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
  13413. OpInfo('index_add',
  13414. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13415. supports_out=True,
  13416. supports_forward_ad=True,
  13417. supports_fwgrad_bwgrad=True,
  13418. # https://github.com/pytorch/pytorch/issues/66357
  13419. check_batched_forward_grad=False,
  13420. sample_inputs_func=sample_inputs_index,
  13421. reference_inputs_func=partial(sample_inputs_index, reference=True),
  13422. skips=(
  13423. # boolean alpha not handled properly
  13424. DecorateInfo(unittest.expectedFailure,
  13425. 'TestCudaFuserOpInfo',
  13426. 'test_nvfuser_correctness',
  13427. dtypes=(torch.bool,)),
  13428. DecorateInfo(unittest.expectedFailure,
  13429. 'TestNNCOpInfo',
  13430. 'test_nnc_correctness',
  13431. dtypes=(torch.bool,)),
  13432. ),
  13433. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
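# A commented sketch of the `alpha` argument behind the bool-dtype fuser skips above: in eager
# mode `source` is scaled by `alpha` before being accumulated into the indexed rows.
# >>> x = torch.zeros(3)
# >>> x.index_add(0, torch.tensor([0, 2]), torch.ones(2), alpha=2)
# tensor([2., 0., 2.])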
  13434. OpInfo('index_reduce',
  13435. dtypes=all_types_and(torch.float16, torch.bfloat16),
  13436. supports_out=True,
  13437. skips=(
  13438. # Pre-existing condition (calls .item); needs to be fixed
  13439. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
  13440. ),
  13441. sample_inputs_func=sample_inputs_index_reduce),
  13442. OpInfo('__getitem__',
  13443. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13444. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  13445. gradcheck_fast_mode=True,
  13446. supports_out=False,
  13447. supports_forward_ad=True,
  13448. supports_fwgrad_bwgrad=True,
  13449. supports_inplace_autograd=False,
  13450. supports_scripting=False,
  13451. op=torch.Tensor.__getitem__,
  13452. skips=(
  13453. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13454. # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448
  13455. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),),
  13456. sample_inputs_func=sample_inputs_getitem),
  13457. OpInfo('index_put',
  13458. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  13459. supports_out=False,
  13460. supports_inplace_autograd=True,
  13461. supports_forward_ad=True,
  13462. supports_fwgrad_bwgrad=True,
  13463. # https://github.com/pytorch/pytorch/issues/66357
  13464. check_batched_forward_grad=False,
  13465. test_neg_view=False,
  13466. sample_inputs_func=sample_inputs_index_put,
  13467. skips=(
  13468. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  13469. )),
  13470. OpInfo('sort',
  13471. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  13472. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  13473. sample_inputs_func=sample_inputs_sort,
  13474. supports_forward_ad=True,
  13475. supports_fwgrad_bwgrad=True,
  13476. skips=(
  13477. )),
  13478. OpInfo('unique',
  13479. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13480. dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
  13481. sample_inputs_func=sample_inputs_unique,
  13482. supports_out=False,
  13483. supports_autograd=False,
  13484. skips=(
  13485. # lambda impl
  13486. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13487. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13488. # 76571 - CUDA gets expectedFailure, but this test passes for ROCm
  13489. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values',
  13490. dtypes=(torch.float16, torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
  13491. DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'),
  13492. )),
  13493. OpInfo('unique_consecutive',
  13494. dtypes=all_types_and(torch.bool, torch.bfloat16),
  13495. dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
  13496. sample_inputs_func=sample_inputs_unique_consecutive,
  13497. supports_out=False,
  13498. supports_autograd=False,
  13499. skips=(
  13500. # lambda impl
  13501. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13502. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13503. # 76571 - CUDA gets expectedFailure, but this test passes for ROCm
  13504. DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values',
  13505. dtypes=(torch.float16, torch.float32, torch.float64), active_if=not TEST_WITH_ROCM),
  13506. )),
  13507. OpInfo('put',
  13508. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  13509. supports_out=False,
  13510. supports_forward_ad=True,
  13511. supports_fwgrad_bwgrad=True,
  13512. check_batched_forward_grad=False,
  13513. check_batched_gradgrad=False, # vmap complains of the sizes
  13514. sample_inputs_func=sample_inputs_put),
  13515. OpInfo('take',
  13516. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  13517. check_batched_grad=False, # vmap complains of the sizes
  13518. supports_forward_ad=True,
  13519. supports_fwgrad_bwgrad=True,
  13520. sample_inputs_func=sample_inputs_take,
  13521. error_inputs_func=error_inputs_take),
  13522. OpInfo('scatter',
  13523. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13524. supports_forward_ad=True,
  13525. supports_fwgrad_bwgrad=True,
  13526. sample_inputs_func=sample_inputs_scatter,
  13527. error_inputs_func=error_inputs_scatter_and_scatter_add),
  13528. UnaryUfuncInfo(
  13529. 'bfloat16',
  13530. op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
  13531. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13532. supports_out=False,
  13533. sample_inputs_func=sample_inputs_conversion,
  13534. skips=(
  13535. # autograd tests don't handle operators that change dtype
  13536. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
  13537. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
  13538. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13539. # RuntimeError: attribute lookup is not defined on builtin
  13540. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13541. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  13542. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  13543. )),
  13544. UnaryUfuncInfo(
  13545. 'bool',
  13546. op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
  13547. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13548. supports_out=False,
  13549. sample_inputs_func=sample_inputs_conversion,
  13550. supports_autograd=False,
  13551. skips=(
  13552. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13553. # RuntimeError: attribute lookup is not defined on builtin
  13554. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13555. )),
  13556. UnaryUfuncInfo(
  13557. 'byte',
  13558. op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
  13559. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13560. supports_out=False,
  13561. sample_inputs_func=sample_inputs_conversion,
  13562. # The autograd test runner cannot handle functions that change dtype
  13563. supports_autograd=False,
  13564. skips=(
  13565. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13566. # RuntimeError: attribute lookup is not defined on builtin
  13567. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13568. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  13569. )),
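# Note on the 'Overflow when downcasting signed type is undefined' skips here and in the other
# narrowing conversions below ('char', 'int', 'long', 'short'): casting an out-of-range value,
# e.g. a negative float to uint8 via .byte(), is undefined behavior in C++, so backends may
# legitimately disagree, which is why test_compare_cpu is skipped rather than asserting a value.
# >>> torch.tensor([-1.0]).byte()   # result is backend-dependent; do not rely on it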
  13570. UnaryUfuncInfo(
  13571. 'char',
  13572. op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
  13573. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13574. supports_out=False,
  13575. sample_inputs_func=sample_inputs_conversion,
  13576. # The autograd test runner cannot handle functions that change dtype
  13577. supports_autograd=False,
  13578. skips=(
  13579. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13580. # RuntimeError: attribute lookup is not defined on builtin
  13581. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13582. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  13583. )),
  13584. UnaryUfuncInfo(
  13585. 'double',
  13586. op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
  13587. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13588. supports_out=False,
  13589. sample_inputs_func=sample_inputs_conversion,
  13590. supports_forward_ad=True,
  13591. supports_fwgrad_bwgrad=True,
  13592. skips=(
  13593. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13594. # RuntimeError: attribute lookup is not defined on builtin
  13595. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13596. )),
  13597. UnaryUfuncInfo(
  13598. 'float',
  13599. op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
  13600. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13601. supports_out=False,
  13602. sample_inputs_func=sample_inputs_conversion,
  13603. skips=(
  13604. # autograd tests don't handle operators that change dtype
  13605. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
  13606. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
  13607. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13608. # RuntimeError: attribute lookup is not defined on builtin
  13609. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13610. )),
  13611. UnaryUfuncInfo(
  13612. 'half',
  13613. op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
  13614. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13615. supports_out=False,
  13616. sample_inputs_func=sample_inputs_conversion,
  13617. supports_autograd=True,
  13618. skips=(
  13619. # autograd tests don't handle operators that change dtype
  13620. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
  13621. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
  13622. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13623. # RuntimeError: attribute lookup is not defined on builtin
  13624. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13625. )),
  13626. UnaryUfuncInfo(
  13627. 'int',
  13628. op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
  13629. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13630. supports_out=False,
  13631. sample_inputs_func=sample_inputs_conversion,
  13632. supports_autograd=False,
  13633. skips=(
  13634. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13635. # RuntimeError: attribute lookup is not defined on builtin
  13636. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13637. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  13638. )),
  13639. UnaryUfuncInfo(
  13640. 'long',
  13641. op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
  13642. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13643. supports_out=False,
  13644. sample_inputs_func=sample_inputs_conversion,
  13645. supports_autograd=False,
  13646. skips=(
  13647. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13648. # RuntimeError: attribute lookup is not defined on builtin
  13649. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13650. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  13651. )),
  13652. UnaryUfuncInfo(
  13653. 'short',
  13654. op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
  13655. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13656. supports_out=False,
  13657. sample_inputs_func=sample_inputs_conversion,
  13658. supports_autograd=False,
  13659. skips=(
  13660. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13661. # RuntimeError: attribute lookup is not defined on builtin
  13662. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13663. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  13664. )),
  13665. UnaryUfuncInfo(
  13666. 'cdouble',
  13667. op=torch.Tensor.cdouble,
  13668. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13669. supports_out=False,
  13670. sample_inputs_func=sample_inputs_conversion,
  13671. supports_forward_ad=True,
  13672. supports_fwgrad_bwgrad=True,
  13673. skips=(
  13674. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13675. # RuntimeError: attribute lookup is not defined on builtin
  13676. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13677. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  13678. )),
  13679. UnaryUfuncInfo(
  13680. 'cfloat',
  13681. op=torch.Tensor.cfloat,
  13682. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13683. supports_out=False,
  13684. sample_inputs_func=sample_inputs_conversion,
  13685. skips=(
  13686. # autograd tests don't handle operators that change dtype
  13687. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
  13688. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
  13689. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13690. # RuntimeError: attribute lookup is not defined on builtin
  13691. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13692. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  13693. )),
  13694. UnaryUfuncInfo(
  13695. 'chalf',
  13696. op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs),
  13697. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13698. supports_out=False,
  13699. sample_inputs_func=sample_inputs_conversion,
  13700. skips=(
  13701. # autograd tests don't handle operators that change dtype
  13702. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
  13703. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
  13704. # use of lambda doesn't work with test_normalize_operator_exhaustive
  13705. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  13706. # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
  13707. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager',
  13708. device_type='cpu'),
  13709. # TypeError: 'int' object is not iterable
  13710. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13711. # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
  13712. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view',
  13713. device_type='cpu'),
  13714. # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
  13715. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view',
  13716. device_type='cpu'),
  13717. # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
  13718. # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf'
  13719. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13720. )
  13721. ),
  13722. OpInfo('empty_like',
  13723. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13724. supports_out=False,
  13725. sample_inputs_func=sample_inputs_like_fns,
  13726. reference_inputs_func=reference_inputs_like_fns,
  13727. supports_autograd=False,
  13728. skips=(
  13729. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13730. DecorateInfo(unittest.skip("Skipped!"),
  13731. "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13732. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13733. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
  13734. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13735. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13736. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13737. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  13738. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13739. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  13740. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13741. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  13742. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13743. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  13744. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13745. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
  13746. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13747. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'),
  13748. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13749. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
  13750. DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance',
  13751. 'test_operator'),
  13752. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13753. )),
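# The skips above share one root cause: empty_like only allocates storage and never initializes
# it, so the returned values are whatever happened to be in memory.
# >>> torch.empty_like(torch.ones(3))   # contents are arbitrary and can change run to run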
  13754. OpInfo('zeros_like',
  13755. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13756. supports_out=False,
  13757. sample_inputs_func=sample_inputs_like_fns,
  13758. supports_autograd=False,
  13759. skips=(
  13760. )),
  13761. OpInfo('ones_like',
  13762. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13763. supports_out=False,
  13764. sample_inputs_func=sample_inputs_like_fns,
  13765. supports_autograd=False,
  13766. skips=(
  13767. )),
  13768. OpInfo('randn',
  13769. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
  13770. op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs),
  13771. supports_out=True,
  13772. sample_inputs_func=sample_inputs_randn,
  13773. supports_autograd=False,
  13774. skips=(
  13775. # Tests that assume input is a tensor or sequence of tensors
  13776. DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
  13777. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  13778. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
13779. # CPU randn generates different values based on the strides of the out tensor
  13780. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
  13781. # randn fails to warn when resizing its out tensor
  13782. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  13783. # FX failed to normalize op - add the op to the op_skip list.
  13784. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  13785. # Tests that assume input tensor has a meaningful effect on output tensor
  13786. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  13787. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  13788. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  13789. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13790. # AssertionError: JIT Test does not execute any logic
  13791. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13792. DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
  13793. )),
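# A commented reminder of why the 'tensor input' assumptions above don't hold here: torch.randn
# is a factory that takes sizes rather than a tensor, e.g.
# >>> torch.randn(2, 3).shape
# torch.Size([2, 3])
# so there is no input tensor whose layout, conjugation, or negation could influence the output.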
  13794. OpInfo('randn_like',
  13795. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
  13796. op=lambda inp, *args, **kwargs:
  13797. wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
  13798. supports_out=False,
  13799. sample_inputs_func=sample_inputs_like_fns,
  13800. supports_autograd=False,
  13801. supports_sparse_csr=True,
  13802. skips=(
  13803. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13804. # AssertionError: JIT Test does not execute any logic
  13805. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13806. DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"),
  13807. 'TestCommon', 'test_complex_half_reference_testing'),
  13808. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13809. )),
  13810. OpInfo('rand_like',
  13811. dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128),
  13812. op=lambda inp, *args, **kwargs:
13813. wrapper_set_seed(torch.rand_like, inp, *args, **kwargs),
  13814. supports_out=False,
  13815. sample_inputs_func=sample_inputs_like_fns,
  13816. supports_autograd=False,
  13817. skips=(
  13818. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13819. # AssertionError: JIT Test does not execute any logic
  13820. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13821. DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"),
  13822. 'TestCommon', 'test_complex_half_reference_testing'),
  13823. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13824. )),
  13825. OpInfo('randint',
  13826. dtypes=all_types_and(torch.half, torch.bfloat16),
  13827. op=lambda *args, **kwargs:
  13828. wrapper_set_seed(torch.randint, *args, **kwargs),
  13829. supports_out=False,
  13830. sample_inputs_func=sample_inputs_randint,
  13831. supports_autograd=False,
  13832. skips=(
  13833. # Tests that assume input is a tensor or sequence of tensors
  13834. DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
  13835. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
  13836. DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_multiple_devices"),
  13837. DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
13838. # CPU randint generates different values based on the strides of the out tensor
  13839. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  13840. # randint fails to warn when resizing its out tensor
  13841. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  13842. # FX failed to normalize op - add the op to the op_skip list.
  13843. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  13844. # Tests that assume input tensor has a meaningful effect on output tensor
  13845. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  13846. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  13847. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  13848. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13849. # AssertionError: JIT Test does not execute any logic
  13850. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13851. # Might need to skip until ROCm5.5
  13852. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices',
  13853. dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM),
  13854. )),
  13855. OpInfo('randint_like',
  13856. dtypes=all_types_and(torch.half, torch.bfloat16),
  13857. op=lambda inp, *args, **kwargs:
  13858. wrapper_set_seed(torch.randint_like, inp, *args, **kwargs),
  13859. supports_out=False,
  13860. sample_inputs_func=sample_inputs_randint_like,
  13861. supports_autograd=False,
  13862. skips=(
  13863. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13864. # AssertionError: JIT Test does not execute any logic
  13865. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  13866. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13867. )),
  13868. OpInfo('full_like',
  13869. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  13870. supports_out=False,
  13871. sample_inputs_func=sample_inputs_full_like,
  13872. supports_autograd=False,
  13873. skips=(
  13874. )),
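# Editorial note: new_zeros, new_ones, new_empty, new_empty_strided and new_full below wrap the
# Tensor method in a lambda so the OpInfo machinery can call it as a plain function; presumably
# because a lambda has no registered operator schema,
# TestNormalizeOperators.test_normalize_operator_exhaustive is expected to fail for these entries
# (the same "lambda impl" pattern is noted elsewhere in this file, e.g. for T/H/mT/mH).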
  13875. OpInfo('new_zeros',
  13876. op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs),
  13877. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13878. supports_out=False,
  13879. sample_inputs_func=sample_inputs_new_fns,
  13880. skips=(
  13881. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13882. ),
  13883. supports_autograd=False),
  13884. OpInfo('new_ones',
  13885. op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs),
  13886. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13887. supports_out=False,
  13888. sample_inputs_func=sample_inputs_new_fns,
  13889. skips=(
  13890. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13891. ),
  13892. supports_autograd=False),
  13893. OpInfo('ones',
  13894. op=torch.ones,
  13895. supports_autograd=False,
  13896. supports_varargs=True,
  13897. is_factory_function=True,
  13898. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13899. supports_out=True,
  13900. sample_inputs_func=sample_inputs_ones_zeros,
  13901. skips=(
  13902. # Tests that assume input is a tensor or sequence of tensors
  13903. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  13904. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  13905. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  13906. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13907. # Same failure as arange: cannot find linspace in captured graph
  13908. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13909. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  13910. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  13911. )),
  13912. OpInfo('zeros',
  13913. op=torch.zeros,
  13914. supports_autograd=False,
  13915. is_factory_function=True,
  13916. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13917. supports_out=True,
  13918. sample_inputs_func=sample_inputs_ones_zeros,
  13919. skips=(
  13920. # Tests that assume input is a tensor or sequence of tensors
  13921. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  13922. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  13923. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  13924. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13925. # Same failure as arange: cannot find linspace in captured graph
  13926. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13927. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  13928. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  13929. )),
  13930. OpInfo('full',
  13931. op=torch.full,
  13932. supports_autograd=False,
  13933. is_factory_function=True,
  13934. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13935. supports_out=True,
  13936. sample_inputs_func=sample_inputs_full,
  13937. skips=(
  13938. # Tests that assume input is a tensor or sequence of tensors
  13939. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
  13940. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  13941. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  13942. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  13943. # Same failure as arange: cannot find linspace in captured graph
  13944. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13945. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  13946. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  13947. # boolean alpha not handled properly
  13948. DecorateInfo(unittest.expectedFailure,
  13949. 'TestCudaFuserOpInfo',
  13950. 'test_nvfuser_correctness',
  13951. dtypes=(torch.bool,)),
  13952. # RuntimeError: UNSUPPORTED DTYPE: bool
  13953. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)),
  13954. )),
  13955. OpInfo('new_empty',
  13956. op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
  13957. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13958. supports_out=False,
  13959. sample_inputs_func=sample_inputs_new_fns,
  13960. skips=(
  13961. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13962. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13963. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  13964. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13965. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
  13966. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13967. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
  13968. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13969. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  13970. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13971. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  13972. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13973. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  13974. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13975. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  13976. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13977. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
  13978. # Empty tensor data is garbage so it's hard to make comparisons with it.
  13979. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
  13980. DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance',
  13981. 'test_operator'),
  13982. DecorateInfo(unittest.skip("Expected: new_empty is not comparable"),
  13983. 'TestCommon', 'test_complex_half_reference_testing'),
  13984. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  13985. ),
  13986. supports_autograd=False),
  13987. OpInfo('new_empty_strided',
  13988. op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs),
  13989. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  13990. supports_out=False,
  13991. sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True),
  13992. supports_autograd=False,
  13993. skips=(
  13994. # FX failed to normalize op
  13995. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  13996. # Lazy tensor failures
  13997. DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
  13998. DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
  13999. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14000. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14001. 'TestCommon', 'test_variant_consistency_eager'),
  14002. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14003. 'TestCommon', 'test_noncontiguous_samples'),
  14004. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14005. 'TestMathBits', 'test_conj_view'),
  14006. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14007. 'TestMathBits', 'test_neg_view'),
  14008. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14009. 'TestMathBits', 'test_neg_conj_view'),
  14010. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14011. 'TestCommon', 'test_non_standard_bool_values'),
  14012. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14013. 'TestCommon', 'test_complex_half_reference_testing'),
  14014. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14015. 'TestCompositeCompliance', 'test_operator'),
  14016. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14017. 'TestDecomp', 'test_comprehensive'),
  14018. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14019. 'TestDecomp', 'test_quick'),
  14020. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14021. 'TestJit', 'test_variant_consistency_jit'),
  14022. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14023. 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
  14024. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14025. 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
  14026. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14027. 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
  14028. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14029. 'TestNNCOpInfo', 'test_nnc_correctness'),
  14030. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14031. 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  14032. DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
  14033. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  14034. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  14035. )),
  14036. OpInfo('empty',
  14037. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  14038. sample_inputs_func=sample_inputs_empty,
  14039. supports_autograd=False,
  14040. skips=(
  14041. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14042. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14043. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  14044. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14045. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
  14046. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14047. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
  14048. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14049. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  14050. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14051. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  14052. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14053. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  14054. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14055. DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
  14056. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14057. DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
  14058. # Empty tensor data is garbage so it's hard to make comparisons with it.
  14059. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
  14060. DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance',
  14061. 'test_operator'),
  14062. # requires_grad doesn't exist in the jit schema
  14063. DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  14064. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  14065. 'TestCommon',
  14066. 'test_out'),
  14067. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  14068. 'TestCommon',
  14069. 'test_out_warning'),
  14070. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  14071. 'TestLazyOpInfo'),
  14072. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  14073. 'TestCommon', 'test_complex_half_reference_testing'),
  14074. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  14075. )),
  14076. OpInfo('eye',
  14077. dtypes=all_types_and_complex_and(torch.bool, torch.half),
  14078. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14079. sample_inputs_func=sample_inputs_eye,
  14080. error_inputs_func=error_inputs_eye,
  14081. supports_out=True,
  14082. supports_autograd=False,
  14083. skips=(
  14084. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14085. # TODO: same as this?
  14086. # https://github.com/pytorch/pytorch/issues/81774
  14087. # also see: arange, new_full
  14088. # fails to match any schemas despite working in the interpreter
  14089. DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  14090. # fails to match any schemas despite working in the interpreter
  14091. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14092. # skip these tests since we have non tensor input
  14093. DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
  14094. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  14095. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  14096. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  14097. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  14098. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  14099. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  14100. )),
  14101. OpInfo('scalar_tensor',
  14102. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  14103. sample_inputs_func=sample_inputs_scalar_tensor,
  14104. supports_autograd=False,
  14105. supports_out=False,
  14106. skips=(
  14107. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14108. # fails to match any schemas despite working in the interpreter
  14109. DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  14110. # fails to match any schemas despite working in the interpreter
  14111. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14112. # skip these tests since we have non tensor input
  14113. DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
  14114. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  14115. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  14116. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  14117. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  14118. )),
  14119. OpInfo('new_full',
  14120. op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs),
  14121. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
  14122. supports_out=False,
  14123. sample_inputs_func=sample_inputs_new_full,
  14124. skips=(
  14125. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14126. ),
  14127. supports_autograd=False),
  14128. OpInfo('multinomial',
  14129. op=lambda inp, *args, **kwargs:
  14130. wrapper_set_seed(torch.multinomial, inp, *args, **kwargs),
  14131. method_variant=lambda inp, *args, **kwargs:
  14132. wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs),
  14133. dtypes=floating_types_and(torch.bfloat16),
  14134. dtypesIfCUDA=floating_types_and(torch.half),
  14135. supports_out=True,
  14136. sample_inputs_func=sample_inputs_multinomial,
  14137. error_inputs_func=error_inputs_multinomial,
  14138. skips=(
  14139. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14140. # Strides are not the same!
  14141. # This may not be reproducible in CI
  14142. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
  14143. # AssertionError: JIT Test does not execute any logic
  14144. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14145. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  14146. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  14147. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
  14148. supports_autograd=False),
  14149. OpInfo('normal',
  14150. op=lambda inp, *args, **kwargs:
  14151. wrapper_set_seed(torch.normal, inp, *args, **kwargs),
  14152. # The inplace variant (Tensor.normal_) is different from torch.normal
  14153. inplace_variant=None,
  14154. dtypes=floating_types_and(torch.bfloat16, torch.half),
  14155. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
  14156. supports_out=True,
  14157. sample_inputs_func=sample_inputs_normal_tensor_first,
  14158. skips=(
  14159. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14160. # Tensor-likes are not close!
  14161. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  14162. # AssertionError: JIT Test does not execute any logic
  14163. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14164. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  14165. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  14166. # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
  14167. DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
  14168. DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
  14169. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
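# Editorial note: in the 'number_mean' variant below, mean is a Python scalar and std is the
# Tensor, so the lambda takes (std, mean) -- keeping a Tensor as the first OpInfo argument --
# and swaps them back into torch.normal(mean, std, ...) when calling the op.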
  14170. OpInfo('normal',
  14171. # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
  14172. variant_test_name='number_mean',
  14173. op=lambda std, mean, *args, **kwargs:
  14174. wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
  14175. # The inplace variant (Tensor.normal_) is different from torch.normal
  14176. inplace_variant=None,
  14177. dtypes=floating_types_and(torch.bfloat16, torch.half),
  14178. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
  14179. supports_out=True,
  14180. sample_inputs_func=sample_inputs_normal_tensor_second,
  14181. skips=(
  14182. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14183. # AssertionError: JIT Test does not execute any logic
  14184. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14185. # NotImplementedError not raised
  14186. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
  14187. # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
  14188. DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
  14189. DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
  14190. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
  14191. OpInfo('bernoulli',
  14192. op=lambda inp, *args, **kwargs:
  14193. wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
  14194. # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
  14195. inplace_variant=None,
  14196. method_variant=lambda inp, *args, **kwargs:
  14197. wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
  14198. dtypes=floating_types_and(torch.bfloat16),
  14199. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
  14200. supports_out=True,
  14201. supports_forward_ad=True,
  14202. supports_fwgrad_bwgrad=True,
  14203. sample_inputs_func=sample_inputs_bernoulli,
  14204. error_inputs_func=error_inputs_bernoulli,
  14205. skips=(
  14206. # vmap: We do not yet support calling random operations inside of vmap
  14207. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
  14208. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14209. # AssertionError: JIT Test does not execute any logic
  14210. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14211. # Expected RuntimeError when doing an unsafe cast from a result of
14212. # dtype torch.float32 into an out= with dtype torch.long
  14213. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  14214. # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  14215. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  14216. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  14217. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
  14218. OpInfo('scatter_add',
  14219. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14220. sample_inputs_func=sample_inputs_scatter_add,
  14221. error_inputs_func=error_inputs_scatter_and_scatter_add,
  14222. supports_forward_ad=True,
  14223. supports_fwgrad_bwgrad=True,
  14224. ),
  14225. OpInfo('stack',
  14226. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14227. sample_inputs_func=sample_inputs_stack,
  14228. assert_autodiffed=True,
  14229. supports_forward_ad=True,
  14230. supports_fwgrad_bwgrad=True,
  14231. skips=(
  14232. # https://github.com/pytorch/pytorch/issues/77046
  14233. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  14234. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  14235. ),
  14236. ),
  14237. OpInfo('hstack',
  14238. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14239. sample_inputs_func=sample_inputs_hstack_dstack_vstack,
  14240. error_inputs_func=error_inputs_hstack_dstack_vstack,
  14241. supports_forward_ad=True,
  14242. supports_fwgrad_bwgrad=True,
  14243. ),
  14244. BinaryUfuncInfo('hypot',
  14245. dtypes=floating_types_and(torch.bfloat16),
  14246. dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
  14247. supports_forward_ad=True,
  14248. supports_fwgrad_bwgrad=True,
  14249. supports_rhs_python_scalar=False),
  14250. OpInfo('histogram',
  14251. dtypes=floating_types(),
  14252. dtypesIfCUDA=_dispatch_dtypes(), # histogram is only implemented on CPU
  14253. sample_inputs_func=sample_inputs_histogram,
  14254. supports_autograd=False,
  14255. skips=(
  14256. # JIT tests don't work with Tensor keyword arguments
  14257. # https://github.com/pytorch/pytorch/issues/58507
  14258. # RuntimeError:
  14259. # undefined value tensor:
  14260. # File "<string>", line 3
  14261. # def the_method(i0):
  14262. # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False)
  14263. # ~~~~~~ <--- HERE
  14264. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14265. # Not Implemented on XLA.
  14266. DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'),
  14267. )),
  14268. OpInfo('histogramdd',
  14269. dtypes=floating_types(),
  14270. dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU
  14271. sample_inputs_func=sample_inputs_histogramdd,
  14272. supports_autograd=False,
  14273. skips=(
  14274. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14275. # JIT tests don't work with Tensor keyword arguments
  14276. # https://github.com/pytorch/pytorch/issues/58507
  14277. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14278. )),
  14279. OpInfo('histc',
  14280. dtypes=floating_types_and(torch.bfloat16),
  14281. dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64),
  14282. sample_inputs_func=sample_inputs_histc,
  14283. supports_out=True,
  14284. supports_autograd=False,
  14285. skips=(
  14286. # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor
  14287. # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast
  14288. # from a result of dtype torch.float32 into an out= with dtype torch.long"
  14289. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
  14290. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
  14291. )),
  14292. OpInfo('bincount',
  14293. dtypes=integral_types_and(),
  14294. sample_inputs_func=sample_inputs_bincount,
  14295. supports_out=False,
  14296. supports_autograd=False,
  14297. skips=(
  14298. # JIT tests don't work with Tensor keyword arguments
  14299. # https://github.com/pytorch/pytorch/issues/58507
  14300. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  14301. )),
  14302. OpInfo('bucketize',
  14303. dtypes=all_types_and(torch.float16, torch.bfloat16),
  14304. dtypesIfCUDA=all_types_and(torch.float16),
  14305. sample_inputs_func=sample_inputs_bucketize,
  14306. reference_inputs_func=reference_inputs_bucketize,
  14307. supports_autograd=False,
  14308. skips=(
  14309. # JIT tests don't work with Tensor keyword arguments
  14310. DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
  14311. )),
  14312. OpInfo('searchsorted',
  14313. dtypes=all_types_and(torch.bfloat16, torch.float16),
  14314. dtypesIfCUDA=all_types_and(torch.float16),
  14315. sample_inputs_func=sample_inputs_searchsorted,
  14316. supports_autograd=False,
  14317. ref=reference_searchsorted,
  14318. skips=(
  14319. # JIT tests don't work with Tensor keyword arguments
  14320. # https://github.com/pytorch/pytorch/issues/58507
  14321. DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
  14322. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  14323. )),
  14324. OpInfo('cat',
  14325. ref=_cat_np,
  14326. aliases=('concat', 'concatenate'),
  14327. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
  14328. sample_inputs_func=sample_inputs_cat_concat,
  14329. reference_inputs_func=reference_inputs_cat,
  14330. error_inputs_func=error_inputs_cat,
  14331. # https://github.com/pytorch/pytorch/issues/80411
  14332. gradcheck_fast_mode=True,
  14333. supports_forward_ad=True,
  14334. supports_fwgrad_bwgrad=True,
  14335. assert_autodiffed=True,
  14336. skips=(
  14337. # https://github.com/pytorch/pytorch/issues/89353
  14338. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
  14339. # RuntimeError: Arguments for call not valid.
  14340. # Expected a value of type 'List[Tensor]' for argument
  14341. # 'tensors' but instead found type 'Tensor (inferred)'.
  14342. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
  14343. # see https://github.com/pytorch/pytorch/issues/71286
  14344. DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),)),
  14345. OpInfo('unbind',
  14346. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14347. ref=reference_unbind,
  14348. sample_inputs_func=sample_inputs_unbind,
  14349. error_inputs_func=error_inputs_unbind,
  14350. supports_forward_ad=True,
  14351. supports_fwgrad_bwgrad=True,
  14352. supports_gradgrad=True,
  14353. supports_out=False,
  14354. ),
  14355. OpInfo('vstack',
  14356. aliases=('row_stack',),
  14357. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14358. sample_inputs_func=sample_inputs_hstack_dstack_vstack,
  14359. error_inputs_func=error_inputs_hstack_dstack_vstack,
  14360. supports_forward_ad=True,
  14361. supports_fwgrad_bwgrad=True,
  14362. skips=(
  14363. # RuntimeError: _fn() Expected a value of type
  14364. # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'.
  14365. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)),
  14366. OpInfo('dstack',
  14367. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14368. sample_inputs_func=sample_inputs_hstack_dstack_vstack,
  14369. error_inputs_func=error_inputs_hstack_dstack_vstack,
  14370. supports_forward_ad=True,
  14371. supports_fwgrad_bwgrad=True,
  14372. # See https://github.com/pytorch/pytorch/pull/78358
  14373. check_batched_forward_grad=False,
  14374. ),
  14375. OpInfo('unfold',
  14376. op=lambda x, *args: x.unfold(*args),
  14377. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14378. backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14379. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  14380. gradcheck_fast_mode=True,
  14381. supports_out=False,
  14382. supports_forward_ad=True,
  14383. supports_fwgrad_bwgrad=True,
  14384. check_batched_gradgrad=False,
  14385. # See https://github.com/pytorch/pytorch/issues/66357
  14386. check_batched_forward_grad=False,
  14387. skips=(
  14388. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14389. # Skip operator schema test because this is a functional and not an operator
  14390. DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  14391. ),
  14392. sample_inputs_func=sample_inputs_unfold),
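# Editorial note: unfold_copy below is the copying (non-view) counterpart of Tensor.unfold; it
# reuses sample_inputs_unfold and, unlike the view op above, supports out=.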
  14393. OpInfo('unfold_copy',
  14394. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14395. backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14396. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  14397. gradcheck_fast_mode=True,
  14398. supports_out=True,
  14399. supports_forward_ad=True,
  14400. supports_fwgrad_bwgrad=True,
  14401. check_batched_gradgrad=False,
  14402. # See https://github.com/pytorch/pytorch/issues/66357
  14403. check_batched_forward_grad=False,
  14404. sample_inputs_func=sample_inputs_unfold),
  14405. OpInfo('msort',
  14406. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  14407. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  14408. check_batched_gradgrad=False,
  14409. supports_forward_ad=True,
  14410. supports_fwgrad_bwgrad=True,
  14411. sample_inputs_func=sample_inputs_msort,
  14412. skips=(
  14413. )),
  14414. OpInfo('movedim',
  14415. aliases=('moveaxis',),
  14416. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14417. supports_out=False,
  14418. supports_forward_ad=True,
  14419. supports_fwgrad_bwgrad=True,
  14420. # See https://github.com/pytorch/pytorch/pull/78358
  14421. check_batched_forward_grad=False,
  14422. sample_inputs_func=sample_movedim_moveaxis,
  14423. reference_inputs_func=reference_movedim_moveaxis,
  14424. error_inputs_func=error_movedim_moveaxis),
  14425. OpInfo('renorm',
  14426. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14427. sample_inputs_func=sample_inputs_renorm,
  14428. error_inputs_func=error_inputs_renorm,
  14429. skips=(
  14430. # RuntimeError: Difference from float64 is larger with decomposition
  14431. # linalg_vector_norm.default than original on output 0.
  14432. # Original max diff: 2.560596747969157e-07,
  14433. # Decomp max diff: 1.8187482915266173e-06
  14434. DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive',
  14435. device_type='cpu', dtypes=(torch.float16,)),
  14436. )),
  14437. ShapeFuncInfo('repeat',
  14438. op=lambda x, dims: x.repeat(dims),
  14439. ref=np.tile,
  14440. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14441. # https://github.com/pytorch/pytorch/issues/80411
  14442. gradcheck_fast_mode=True,
  14443. supports_out=False,
  14444. supports_forward_ad=True,
  14445. supports_fwgrad_bwgrad=True,
  14446. sample_inputs_func=sample_repeat_tile,
  14447. skips=(
  14448. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14449. )),
  14450. OpInfo('squeeze',
  14451. ref=_squeeze_ref,
  14452. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14453. supports_out=False,
  14454. assert_autodiffed=True,
  14455. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  14456. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  14457. assert_jit_shape_analysis=True,
  14458. supports_forward_ad=True,
  14459. supports_fwgrad_bwgrad=True,
  14460. # vmap does not support inplace views
  14461. check_inplace_batched_forward_grad=False,
  14462. # https://github.com/pytorch/pytorch/issues/66357
  14463. check_batched_forward_grad=False,
  14464. sample_inputs_func=sample_inputs_squeeze),
  14465. OpInfo('squeeze',
  14466. ref=_squeeze_ref,
  14467. variant_test_name="multiple",
  14468. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14469. supports_out=False,
  14470. assert_autodiffed=True,
  14471. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  14472. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  14473. supports_forward_ad=True,
  14474. supports_fwgrad_bwgrad=True,
  14475. # vmap does not support inplace views
  14476. check_inplace_batched_forward_grad=False,
  14477. # https://github.com/pytorch/pytorch/issues/66357
  14478. check_batched_forward_grad=False,
  14479. sample_inputs_func=sample_inputs_squeeze_multiple),
  14480. UnaryUfuncInfo(
  14481. 'fill',
  14482. ref=_fill_np,
  14483. method_variant=None,
  14484. sample_kwargs=_fill_sample_kwargs,
  14485. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}),
  14486. supports_forward_ad=True,
  14487. supports_fwgrad_bwgrad=True,
  14488. # https://github.com/pytorch/pytorch/issues/66357
  14489. check_batched_forward_grad=False,
  14490. dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
  14491. supports_out=False,
  14492. skips=(
  14493. # JIT has issue when op is passed as lambda
  14494. # AssertionError: JIT Test does not execute any logic
  14495. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14496. DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'),
  14497. DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'),
  14498. )),
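# Editorial note: resize_, resize_as_ and zero_ below are in-place ops; their functional `op`
# wrappers clone the input first so that the shared sample inputs are not mutated from one test
# to the next, while `inplace_variant` still points at the real in-place method.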
  14499. OpInfo('resize_',
  14500. op=lambda x, shape: x.clone().resize_(shape),
  14501. method_variant=None,
  14502. inplace_variant=torch.Tensor.resize_,
  14503. # the test fails because resize_ doesn't work with imag views as expected by the test
  14504. # https://github.com/pytorch/pytorch/issues/65945
  14505. test_neg_view=False,
  14506. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14507. supports_out=False,
  14508. supports_autograd=False,
  14509. skips=(
  14510. # Cannot resize variables that require grad
  14511. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
  14512. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14513. DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
  14514. ),
  14515. sample_inputs_func=sample_inputs_resize_ops),
  14516. OpInfo('resize_as_',
  14517. op=lambda x, other: torch.resize_as_(x.clone(), other),
  14518. method_variant=None,
  14519. inplace_variant=torch.Tensor.resize_as_,
  14520. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14521. supports_out=False,
  14522. supports_autograd=False,
  14523. skips=(
  14524. # Cannot resize variables that require grad
  14525. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
14526. DecorateInfo(unittest.skip('Allowed exception'), 'TestCompositeCompliance', 'test_operator'),
  14527. ),
  14528. sample_inputs_func=sample_inputs_resize_ops),
  14529. OpInfo('take_along_dim',
  14530. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14531. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14532. supports_inplace_autograd=False,
  14533. supports_forward_ad=True,
  14534. supports_fwgrad_bwgrad=True,
  14535. # See https://github.com/pytorch/pytorch/pull/78358
  14536. check_batched_forward_grad=False,
  14537. sample_inputs_func=sample_inputs_take_along_dim,
  14538. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
  14539. ShapeFuncInfo('tile',
  14540. ref=np.tile,
  14541. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14542. # https://github.com/pytorch/pytorch/issues/80411
  14543. gradcheck_fast_mode=True,
  14544. supports_out=False,
  14545. supports_forward_ad=True,
  14546. supports_fwgrad_bwgrad=True,
  14547. sample_inputs_func=sample_repeat_tile),
  14548. OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid'
  14549. dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
  14550. supports_out=False,
  14551. supports_forward_ad=True,
  14552. supports_fwgrad_bwgrad=True,
  14553. # See https://github.com/pytorch/pytorch/pull/78358
  14554. check_batched_forward_grad=False,
  14555. sample_inputs_func=sample_trapezoid),
  14556. OpInfo('trapezoid',
  14557. dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
  14558. supports_out=False,
  14559. supports_forward_ad=True,
  14560. supports_fwgrad_bwgrad=True,
  14561. # See https://github.com/pytorch/pytorch/pull/78358
  14562. check_batched_forward_grad=False,
  14563. sample_inputs_func=sample_trapezoid),
  14564. OpInfo('cumulative_trapezoid',
  14565. dtypes=all_types_and_complex_and(torch.bfloat16),
  14566. dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.float16),
  14567. supports_forward_ad=True,
  14568. supports_fwgrad_bwgrad=True,
  14569. # See https://github.com/pytorch/pytorch/pull/78358
  14570. check_batched_forward_grad=False,
  14571. supports_out=False,
  14572. sample_inputs_func=sample_cumulative_trapezoid,),
  14573. OpInfo('unsqueeze',
  14574. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14575. supports_out=False,
  14576. supports_forward_ad=True,
  14577. supports_fwgrad_bwgrad=True,
  14578. # See https://github.com/pytorch/pytorch/pull/78358
  14579. check_batched_forward_grad=False,
  14580. # vmap does not support inplace views
  14581. check_inplace_batched_forward_grad=False,
  14582. assert_jit_shape_analysis=True,
  14583. assert_autodiffed=True,
  14584. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  14585. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  14586. sample_inputs_func=sample_unsqueeze),
  14587. BinaryUfuncInfo('xlogy',
  14588. aliases=('special.xlogy',),
  14589. dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
  14590. promotes_int_to_float=True,
  14591. supports_forward_ad=True,
  14592. supports_fwgrad_bwgrad=True,
  14593. supports_one_python_scalar=True,
  14594. # We don't test 0 as the gradient will be NaN and it'll break
  14595. rhs_make_tensor_kwargs=dict(low=0.01)),
  14596. OpInfo('zero_',
  14597. op=lambda x: torch.zero_(x.clone()),
  14598. method_variant=None,
  14599. inplace_variant=torch.Tensor.zero_,
  14600. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14601. # https://github.com/pytorch/pytorch/issues/80411
  14602. gradcheck_fast_mode=True,
  14603. supports_out=False,
  14604. supports_forward_ad=True,
  14605. supports_fwgrad_bwgrad=True,
  14606. supports_gradgrad=True,
  14607. skips=(
  14608. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  14609. ),
  14610. sample_inputs_func=sample_inputs_zero_),
  14611. OpInfo('logsumexp',
  14612. aliases=('special.logsumexp',),
  14613. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14614. dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
  14615. assert_autodiffed=True,
  14616. supports_forward_ad=True,
  14617. supports_fwgrad_bwgrad=True,
  14618. gradcheck_fast_mode=False,
  14619. sample_inputs_func=sample_inputs_logsumexp,
  14620. reference_inputs_func=reference_inputs_logsumexp),
  14621. OpInfo('trace',
  14622. dtypes=all_types_and_complex(),
  14623. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
  14624. error_inputs_func=error_inputs_trace,
  14625. supports_inplace_autograd=False,
  14626. supports_out=False,
  14627. supports_forward_ad=True,
  14628. supports_fwgrad_bwgrad=True,
  14629. sample_inputs_func=sample_inputs_trace),
  14630. OpInfo('transpose',
  14631. ref=_numpy_ref_transpose,
  14632. aliases=('swapdims', 'swapaxes'),
  14633. assert_jit_shape_analysis=True,
  14634. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  14635. supports_out=False,
  14636. supports_forward_ad=True,
  14637. supports_fwgrad_bwgrad=True,
  14638. # vmap does not support inplace views
  14639. check_inplace_batched_forward_grad=False,
  14640. sample_inputs_func=sample_inputs_transpose_swapdims),
  14641. OpInfo('T',
  14642. op=lambda x: x.T,
  14643. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  14644. supports_out=False,
  14645. supports_forward_ad=True,
  14646. supports_fwgrad_bwgrad=True,
  14647. skips=(
  14648. # lambda impl
  14649. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14650. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
  14651. sample_inputs_func=sample_inputs_T,
  14652. error_inputs_func=error_inputs_T),
  14653. OpInfo('H',
  14654. op=lambda x: x.H,
  14655. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  14656. supports_out=False,
  14657. supports_forward_ad=True,
  14658. supports_fwgrad_bwgrad=True,
  14659. # See https://github.com/pytorch/pytorch/pull/78358
  14660. check_batched_forward_grad=False,
  14661. skips=(
  14662. # lambda impl
  14663. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14664. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
  14665. sample_inputs_func=sample_inputs_T),
  14666. OpInfo('mT',
  14667. op=lambda x: x.mT,
  14668. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  14669. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  14670. gradcheck_fast_mode=True,
  14671. supports_out=False,
  14672. supports_forward_ad=True,
  14673. supports_fwgrad_bwgrad=True,
  14674. skips=(
  14675. # lambda impl
  14676. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14677. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
  14678. sample_inputs_func=sample_inputs_adjoint),
  14679. OpInfo('mH',
  14680. op=lambda x: x.mH,
  14681. aliases=('adjoint',),
  14682. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
  14683. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  14684. gradcheck_fast_mode=True,
  14685. supports_out=False,
  14686. supports_forward_ad=True,
  14687. supports_fwgrad_bwgrad=True,
  14688. # See https://github.com/pytorch/pytorch/pull/78358
  14689. check_batched_forward_grad=False,
  14690. skips=(
  14691. # lambda impl
  14692. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  14693. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
  14694. sample_inputs_func=sample_inputs_adjoint),
  14695. OpInfo('tril',
  14696. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14697. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half),
  14698. supports_forward_ad=True,
  14699. supports_fwgrad_bwgrad=True,
  14700. error_inputs_func=error_inputs_tril_triu,
  14701. sample_inputs_func=sample_inputs_tril_triu),
  14702. OpInfo('triu',
  14703. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14704. dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half),
  14705. supports_forward_ad=True,
  14706. supports_fwgrad_bwgrad=True,
  14707. error_inputs_func=error_inputs_tril_triu,
  14708. sample_inputs_func=sample_inputs_tril_triu),
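# Editorial note: torch.triu_indices / torch.tril_indices take (row, col, offset) while
# np.triu_indices / np.tril_indices take (n, k, m), so the refs below call the NumPy functions
# as np.triu_indices(h, ofs, w) to line the arguments up.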
  14709. OpInfo('triu_indices',
  14710. dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
  14711. sample_inputs_func=sample_inputs_trilu_indices,
  14712. ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype),
  14713. supports_out=False,
  14714. supports_autograd=False,
  14715. skips=(
  14716. # skip these tests since we have non tensor input
  14717. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
  14718. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  14719. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  14720. DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
  14721. )),
  14722. OpInfo('tril_indices',
  14723. dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
  14724. sample_inputs_func=sample_inputs_trilu_indices,
  14725. ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype),
  14726. supports_out=False,
  14727. supports_autograd=False,
  14728. skips=(
  14729. # skip these tests since we have non tensor input
  14730. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
  14731. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  14732. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  14733. DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
  14734. )),
  14735. OpInfo('kron',
  14736. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14737. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
  14738. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  14739. gradcheck_fast_mode=True,
  14740. supports_inplace_autograd=False,
  14741. supports_forward_ad=True,
  14742. supports_fwgrad_bwgrad=True,
  14743. sample_inputs_func=sample_inputs_kron),
  14744. OpInfo('inner',
  14745. dtypes=all_types_and_complex_and(torch.bfloat16),
  14746. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14747. dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
  14748. supports_forward_ad=True,
  14749. supports_fwgrad_bwgrad=True,
  14750. # See https://github.com/pytorch/pytorch/pull/78358
  14751. check_batched_forward_grad=False,
  14752. sample_inputs_func=sample_inputs_inner,
  14753. ),
  14754. OpInfo('tensordot',
  14755. dtypes=all_types_and_complex_and(torch.bfloat16),
  14756. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14757. dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
  14758. supports_forward_ad=True,
  14759. supports_fwgrad_bwgrad=True,
  14760. # See https://github.com/pytorch/pytorch/pull/78358
  14761. check_batched_forward_grad=False,
  14762. sample_inputs_func=sample_inputs_tensordot,
  14763. skips=(
  14764. # Skip operator schema test because this is a functional and not an operator.
  14765. # Reference: https://github.com/pytorch/pytorch/issues/54574
  14766. DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
  14767. )
  14768. ),
  14769. OpInfo('to_sparse',
  14770. op=lambda x, *args: x.to_sparse(*args),
  14771. sample_inputs_func=sample_inputs_to_sparse,
  14772. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  14773. backward_dtypes=floating_types(),
  14774. backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  14775. supports_out=False,
  14776. supports_sparse_csr=True,
  14777. supports_sparse_csc=True,
  14778. check_batched_grad=False,
  14779. check_batched_gradgrad=False,
  14780. skips=(
  14781. # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend
  14782. DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'),
  14783. # TODO: FIXME: complex inputs requiring grad error in forward
  14784. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
  14785. # lambda impl
  14786. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  14787. # Allowed exception: sparse tensors don't have strides
  14788. DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
  14789. DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
  14790. DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
14791. # TODO: implement csr.to_sparse(sample_dim) where sample_dim is 1.
  14792. DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"),
  14793. 'TestSparseCSR', 'test_sparse_csr_consistency'),
  14794. # Compiler issue on ROCm. Might need to skip until ROCm5.5
  14795. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
  14796. dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
  14797. )
  14798. ),
  14799. OpInfo('logcumsumexp',
  14800. dtypes=floating_and_complex_types_and(torch.bfloat16),
  14801. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  14802. backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
  14803. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
  14804. skips=(
  14805. # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
  14806. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'),
  14807. ),
  14808. sample_inputs_func=sample_inputs_logcumsumexp,
  14809. error_inputs_func=error_inputs_logcumsumexp),
  14810. UnaryUfuncInfo('sigmoid',
  14811. aliases=('special.expit', 'nn.functional.sigmoid'),
  14812. aten_backward_name='sigmoid_backward',
  14813. ref=reference_sigmoid if TEST_SCIPY else None,
  14814. decorators=(precisionOverride({torch.float16: 1e-2,
  14815. torch.complex64: 1e-1,
  14816. torch.bfloat16: 1e-2}),),
  14817. skips=(
  14818. # Reference: https://github.com/pytorch/pytorch/issues/56012
  14819. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  14820. dtypes=[torch.complex64, torch.cdouble]),
  14821. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  14822. dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
  14823. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  14824. dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
  14825. supports_forward_ad=True,
  14826. supports_fwgrad_bwgrad=True,
  14827. assert_autodiffed=True,
  14828. # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero
  14829. reference_numerics_filter=NumericsFilter(
  14830. condition=lambda x: (close_to_int(x / (math.pi * 1j))
  14831. if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
  14832. safe_val=0)),
  14833. UnaryUfuncInfo('digamma',
  14834. ref=scipy.special.digamma if TEST_SCIPY else None,
  14835. aliases=('special.psi', 'special.digamma',),
  14836. decorators=(precisionOverride({torch.float16: 5e-1}),),
  14837. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14838. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  14839. supports_forward_ad=True,
  14840. supports_fwgrad_bwgrad=True),
  14841. UnaryUfuncInfo('erf',
  14842. ref=scipy.special.erf if TEST_SCIPY else None,
  14843. aliases=('special.erf', ),
  14844. decorators=(precisionOverride({torch.float16: 1e-2,
  14845. torch.bfloat16: 1e-2}),),
  14846. skips=(
  14847. DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
  14848. 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
  14849. ),
  14850. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14851. dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
  14852. assert_autodiffed=True,
  14853. assert_jit_shape_analysis=True,
  14854. supports_sparse=True,
  14855. supports_sparse_csr=True,
  14856. supports_sparse_csc=True,
  14857. supports_sparse_bsr=True,
  14858. supports_sparse_bsc=True,
  14859. supports_forward_ad=True,
  14860. supports_fwgrad_bwgrad=True),
  14861. UnaryUfuncInfo('erfc',
  14862. ref=scipy.special.erfc if TEST_SCIPY else None,
  14863. aliases=('special.erfc', ),
  14864. decorators=(precisionOverride({torch.float16: 1e-2,
  14865. torch.bfloat16: 1e-2}),),
  14866. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14867. dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
  14868. assert_autodiffed=True,
  14869. supports_forward_ad=True,
  14870. supports_fwgrad_bwgrad=True),
  14871. UnaryUfuncInfo('erfinv',
  14872. ref=scipy.special.erfinv if TEST_SCIPY else None,
  14873. aliases=('special.erfinv', ),
  14874. decorators=(precisionOverride({torch.float16: 1e-2,
  14875. torch.bfloat16: 1e-2,
  14876. torch.float32: 1e-4}),),
  14877. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14878. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  14879. supports_sparse_csr=True,
  14880. supports_sparse_csc=True,
  14881. supports_sparse_bsr=True,
  14882. supports_sparse_bsc=True,
  14883. supports_forward_ad=True,
  14884. supports_fwgrad_bwgrad=True,
  14885. domain=(-1, 1),
  14886. skips=(
  14887. # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
  14888. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  14889. active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
  14890. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  14891. active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
  14892. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  14893. active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
  14894. )),
  14895. OpInfo("nn.functional.smooth_l1_loss",
  14896. ref=reference_smooth_l1_loss,
  14897. sample_inputs_func=sample_inputs_smooth_l1_loss,
  14898. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  14899. backward_dtypes=floating_types_and(torch.bfloat16),
  14900. dtypesIfCUDA=floating_types_and(torch.float16),
  14901. backward_dtypesIfCUDA=floating_types_and(torch.float16),
  14902. supports_out=False,
  14903. supports_forward_ad=True,
  14904. supports_fwgrad_bwgrad=True,
  14905. skips=(
14906. # RuntimeError: input->type()->kind() == TypeKind::OptionalType INTERNAL ASSERT FAILED
  14907. # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
  14908. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)),
  14909. OpInfo(
  14910. "nn.functional.l1_loss",
  14911. ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)),
  14912. sample_inputs_func=sample_inputs_l1_loss,
  14913. error_inputs_func=error_inputs_l1_loss,
  14914. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  14915. supports_out=False,
  14916. supports_forward_ad=True,
  14917. supports_fwgrad_bwgrad=True,
  14918. skips=(
14919. # RuntimeError: input->type()->kind() == TypeKind::OptionalType INTERNAL ASSERT FAILED
  14920. # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
  14921. DecorateInfo(
  14922. unittest.expectedFailure,
  14923. "TestJit",
  14924. "test_variant_consistency_jit",
  14925. dtypes=(torch.float32,),
  14926. ),
  14927. ),
  14928. ),
  14929. UnaryUfuncInfo('lgamma',
  14930. ref=reference_lgamma if TEST_SCIPY else None,
  14931. aliases=('special.gammaln', ),
  14932. decorators=(precisionOverride({torch.float16: 7e-1}),),
  14933. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14934. dtypesIfCUDA=all_types_and(torch.bool, torch.half),
  14935. supports_forward_ad=True,
  14936. supports_fwgrad_bwgrad=True,
  14937. skips=(
  14938. # Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345
  14939. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  14940. dtypes=[torch.bfloat16]),
  14941. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  14942. device_type='cpu', dtypes=[torch.bfloat16]),
  14943. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
  14944. device_type='cpu', dtypes=[torch.bfloat16]),
  14945. # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
  14946. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  14947. dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
  14948. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  14949. dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
  14950. ),
14951. # lgamma has multiple singularities at x <= 0
  14952. reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
  14953. OpInfo(
  14954. 'logdet',
  14955. dtypes=floating_and_complex_types(),
  14956. supports_out=False,
  14957. supports_forward_ad=True,
  14958. supports_fwgrad_bwgrad=True,
  14959. sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
  14960. decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
14961. # `log_softmax` supports different dtypes based on whether the `dtype` argument
14962. # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
  14963. OpInfo(
  14964. 'log_softmax',
  14965. aliases=('special.log_softmax', 'nn.functional.log_softmax'),
  14966. supports_out=True,
  14967. aten_backward_name='_log_softmax_backward_data',
  14968. dtypes=floating_types_and(torch.bfloat16),
  14969. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  14970. sample_inputs_func=sample_inputs_softmax_variant,
  14971. supports_forward_ad=True,
  14972. supports_fwgrad_bwgrad=True,
  14973. assert_autodiffed=True),
  14974. OpInfo(
  14975. 'log_softmax',
  14976. variant_test_name='with_dtype',
  14977. aliases=('special.log_softmax', 'nn.functional.log_softmax'),
  14978. supports_out=True,
  14979. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  14980. sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
  14981. supports_forward_ad=True,
  14982. supports_fwgrad_bwgrad=True,
  14983. assert_autodiffed=True),
  14984. UnaryUfuncInfo('logit',
  14985. aten_backward_name='logit_backward',
  14986. ref=scipy.special.logit if TEST_SCIPY else None,
  14987. domain=(0, 1),
  14988. aliases=('special.logit', ),
  14989. supports_forward_ad=True,
  14990. supports_fwgrad_bwgrad=True,
  14991. decorators=(precisionOverride({torch.bfloat16: 5e-1,
  14992. torch.float16: 5e-1}),),
  14993. dtypes=all_types_and(torch.bool, torch.bfloat16),
  14994. dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
  14995. sample_inputs_func=sample_inputs_logit),
  14996. OpInfo('where',
14997. # Currently only the `input` is tested in gradcheck.
14998. # If we passed `condition` first, none of the inputs that support
14999. # autograd would be tested. Hence the following lambda.
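# (Added note: `condition` is a boolean tensor and can never require grad, while `self`
# and `other` can, so placing them first keeps the differentiable inputs under test.)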
  15000. op=lambda self, condition, other: torch.where(condition, self, other),
  15001. ref=lambda self, condition, other: np.where(condition, self, other),
  15002. sample_inputs_func=sample_inputs_where,
  15003. reference_inputs_func=reference_inputs_where,
  15004. error_inputs_func=error_inputs_where,
  15005. supports_out=False,
  15006. supports_forward_ad=True,
  15007. supports_fwgrad_bwgrad=True,
  15008. decorators=(
  15009. DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),),
  15010. skips=(
  15011. # lambda impl
  15012. DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
  15013. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  15014. ),
  15015. dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)),
  15016. OpInfo('nonzero',
  15017. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
  15018. sample_inputs_func=sample_inputs_nonzero,
  15019. supports_autograd=False,
  15020. skips=(
  15021. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15022. # nonzero(): argument 'out' must be Tensor, not tuple
  15023. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
  15024. # https://github.com/pytorch/pytorch/issues/67458
  15025. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15026. # nonzero does not raise a warning when the out tensor is resized
  15027. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
  15028. # Can't find schemas for this operator for some reason
  15029. DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
15030. # Compiler issue on ROCm. Might need to skip until ROCm 5.5
  15031. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
  15032. dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
  15033. )),
15034. # The following tests are for jiterator's Python interface.
15035. # Jiterator can be used to author elementwise CUDA kernels.
15036. # jiterator._create_jit_fn returns a callable that behaves like a regular PyTorch op
  15037. # See create_jit_fn in jiterator.py for more information
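# A minimal usage sketch (added for illustration; the kernel name `my_square` and the
# tensor below are hypothetical, and a CUDA device is assumed):
#     square = torch.cuda.jiterator._create_jit_fn(
#         "template <typename T> T my_square(T x) { return x * x; }")
#     y = square(torch.rand(4, device='cuda'))  # the callable applies the kernel elementwise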
  15038. UnaryUfuncInfo(
  15039. 'jiterator_unary',
  15040. op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"),
  15041. ref=lambda x: x * x + x,
  15042. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  15043. supports_out=False,
15044. supports_autograd=False, # jiterator ops don't have backward defined
  15045. decorators=[
  15046. onlyCUDA,
  15047. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  15048. 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
  15049. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  15050. 'TestUnaryUfuncs', 'test_reference_numerics_hard'),
  15051. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  15052. 'TestUnaryUfuncs', 'test_reference_numerics_normal'),
  15053. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
  15054. 'TestUnaryUfuncs', 'test_reference_numerics_small'),
  15055. ],
  15056. skips=(
15057. # Jiterator ops don't support neg or conj views
  15058. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  15059. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  15060. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
15061. # Jiterator ops don't support CompositeCompliantTensor
15062. # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
  15063. DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
  15064. # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool
  15065. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
  15066. dtypes=[torch.bool]),
  15067. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
  15068. dtypes=[torch.bool]),
  15069. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
  15070. dtypes=[torch.bool]),
  15071. # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results
  15072. DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
  15073. dtypes=[torch.complex64], active_if=TEST_WITH_ROCM),
  15074. # Expected failure: torch.jiterator_unary is not a valid op
  15075. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15076. # Skip Nvfuser
  15077. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
  15078. )
  15079. ),
  15080. BinaryUfuncInfo(
  15081. 'jiterator_binary',
  15082. op=torch.cuda.jiterator._create_jit_fn(
  15083. "template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
  15084. ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
  15085. else np.add(input, np.multiply(alpha, other)),
  15086. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  15087. sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
  15088. supports_out=False,
15089. supports_autograd=False, # jiterator ops don't have backward defined
  15090. supports_rhs_python_scalar=False,
  15091. decorators=[onlyCUDA],
  15092. skips=(
15093. # Jiterator ops don't support neg or conj views
  15094. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  15095. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  15096. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
15097. # Jiterator ops don't support CompositeCompliantTensor
15098. # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
  15099. DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
  15100. # Expected failure: torch.jiterator_binary is not a valid op
  15101. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15102. # Skip Nvfuser
  15103. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
  15104. )
  15105. ),
  15106. OpInfo(
  15107. 'jiterator_4inputs_with_extra_args',
  15108. op=torch.cuda.jiterator._create_jit_fn(
  15109. "template <typename T> T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }",
  15110. alpha=1, beta=1),
  15111. ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3,
  15112. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  15113. sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20),
  15114. supports_out=False,
15115. supports_autograd=False, # jiterator ops don't have backward defined
  15116. decorators=[onlyCUDA],
  15117. skips=(
15118. # Jiterator ops don't support neg or conj views
  15119. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  15120. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  15121. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
15122. # Jiterator ops don't support CompositeCompliantTensor
15123. # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
  15124. DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
  15125. # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
  15126. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15127. # Skip Nvfuser
  15128. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
  15129. )
  15130. ),
  15131. BinaryUfuncInfo(
  15132. 'jiterator_binary_return_by_ref',
  15133. op=torch.cuda.jiterator._create_multi_output_jit_fn(
  15134. """
  15135. template <typename T>
  15136. void binary_return_by_ref(T i0, T i1, T& out0) {
  15137. out0 = i0 + i1;
  15138. }
  15139. """,
  15140. num_outputs=1),
  15141. ref=lambda i0, i1: i0 + i1,
  15142. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  15143. sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42),
  15144. supports_out=False,
15145. supports_autograd=False, # jiterator ops don't have backward defined
  15146. supports_rhs_python_scalar=False,
  15147. decorators=[onlyCUDA],
  15148. skips=(
15149. # Jiterator ops don't support neg or conj views
  15150. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  15151. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  15152. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
15153. # Jiterator ops don't support CompositeCompliantTensor
15154. # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
  15155. DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
15156. # Expected failure: torch.jiterator_binary_return_by_ref is not a valid op
  15157. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15158. # Skip Nvfuser
  15159. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
  15160. )
  15161. ),
  15162. OpInfo(
  15163. 'jiterator_2inputs_2outputs',
  15164. op=torch.cuda.jiterator._create_multi_output_jit_fn(
  15165. """
  15166. template <typename T>
  15167. void binary_2outputs(T i0, T i1, T& out0, T& out1) {
  15168. out0 = i0 + i1;
  15169. out1 = i0 - i1;
  15170. }
  15171. """,
  15172. num_outputs=2),
  15173. ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1),
  15174. dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
  15175. sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2),
  15176. supports_out=False,
15177. supports_autograd=False, # jiterator ops don't have backward defined
  15178. decorators=[onlyCUDA],
  15179. skips=(
15180. # Jiterator ops don't support neg or conj views
  15181. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  15182. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  15183. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
15184. # Jiterator ops don't support CompositeCompliantTensor
15185. # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
  15186. DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
15187. # Expected failure: torch.jiterator_2inputs_2outputs is not a valid op
  15188. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15189. # Skip Nvfuser
  15190. DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
  15191. )
  15192. ),
15193. # `torch.norm` has multiple code paths depending on the value of `p`.
15194. # These paths have different dtype support. Also, JIT supports
15195. # most variants but not all of them. So we split the OpInfo entries
15196. # for `norm` based on the code paths and JIT support.
  15197. OpInfo(
  15198. "norm",
  15199. sample_inputs_func=sample_inputs_norm,
  15200. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  15201. # TODO Benchmark again with the new implementation
  15202. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  15203. gradcheck_fast_mode=True,
  15204. check_batched_forward_grad=False,
  15205. supports_forward_ad=True,
  15206. supports_fwgrad_bwgrad=True,
  15207. skips=(
  15208. # Dispatches in Python to vector_norm. Not sure how to make this test happy
  15209. # Happens to pass on complex64. Also a mystery
  15210. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
  15211. dtypes=(torch.float32,)),)
  15212. ),
  15213. OpInfo('norm',
  15214. variant_test_name='nuc',
  15215. sample_inputs_func=sample_inputs_norm_nuc,
  15216. decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
  15217. check_batched_gradgrad=False,
  15218. # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
  15219. # got: Could not allocate memory to change Tensor SizesAndStrides!
  15220. check_batched_forward_grad=False,
  15221. supports_forward_ad=True,
  15222. supports_fwgrad_bwgrad=True,
  15223. dtypes=floating_and_complex_types(),
  15224. dtypesIfCUDA=floating_and_complex_types(),
  15225. skips=(
  15226. # Dispatches in Python to matrix_norm. Not sure how to make this test happy
  15227. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
  15228. dtypes=(torch.complex64, torch.float32,)),)
  15229. ),
  15230. OpInfo('norm',
  15231. variant_test_name='fro',
  15232. sample_inputs_func=sample_inputs_norm_fro,
  15233. dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
  15234. dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  15235. supports_forward_ad=True,
  15236. # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
  15237. # got: Could not allocate memory to change Tensor SizesAndStrides!
  15238. check_batched_forward_grad=False,
  15239. supports_fwgrad_bwgrad=True,
  15240. skips=(
  15241. # MPS has some mild accuracy issues for float16. We divide the tolerances by 10
  15242. DecorateInfo(
  15243. toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}),
  15244. 'TestConsistency',
  15245. 'test_output_match',
  15246. ),
  15247. # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
  15248. DecorateInfo(
  15249. unittest.skip("Skipped!"),
  15250. 'TestSchemaCheckModeOpInfo',
  15251. 'test_schema_correctness',
  15252. dtypes=(torch.complex64, torch.complex128)),
  15253. # Dispatches in Python to vector_norm. Not sure how to make this test happy
  15254. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
  15255. dtypes=(torch.complex64, torch.float32,)),)
  15256. ),
  15257. OpInfo(
  15258. "norm",
  15259. variant_test_name="inf",
  15260. sample_inputs_func=sample_inputs_norm_inf,
  15261. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  15262. supports_forward_ad=True,
  15263. check_batched_forward_grad=False,
  15264. supports_fwgrad_bwgrad=True,
  15265. # fast gradcheck produces NaNs
  15266. gradcheck_fast_mode=False,
  15267. skips=(
  15268. DecorateInfo(
  15269. toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
  15270. 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
  15271. ),
  15272. # Dispatches in Python to vector_norm. Not sure how to make this test happy
  15273. # Happens to pass on complex64. Also a mystery
  15274. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
  15275. dtypes=(torch.float32,))
  15276. ),
  15277. ),
  15278. OpInfo('t',
  15279. sample_inputs_func=sample_inputs_t,
  15280. supports_out=False,
  15281. supports_forward_ad=True,
  15282. supports_fwgrad_bwgrad=True,
  15283. # See https://github.com/pytorch/pytorch/pull/78358
  15284. check_batched_forward_grad=False,
  15285. # vmap does not support inplace views
  15286. check_inplace_batched_forward_grad=False,
  15287. autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
  15288. autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
  15289. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15290. assert_autodiffed=True,
  15291. error_inputs_func=error_inputs_t),
  15292. OpInfo(
  15293. "nn.functional.dropout",
  15294. op=lambda input, *args, **kwargs:
  15295. wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs),
  15296. dtypes=floating_types_and(torch.bfloat16),
  15297. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15298. skips=(
  15299. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
15300. # Probably because we used a lambda for the op here
  15301. # AssertionError: JIT Test does not execute any logic
  15302. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15303. # The inplace variant dispatches to the dropout kernel, while on CUDA
15304. # the op dispatches to _fused_dropout (with a few more conditions);
15305. # hence the different values and this skip here
  15306. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
  15307. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
  15308. supports_forward_ad=True,
  15309. supports_fwgrad_bwgrad=True,
  15310. # https://github.com/pytorch/pytorch/issues/66357
  15311. check_batched_forward_grad=False,
  15312. supports_out=False,
  15313. sample_inputs_func=sample_inputs_dropout,
  15314. inplace_variant=lambda input, *args, **kwargs:
  15315. wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)),
  15316. OpInfo(
  15317. "native_dropout_backward",
  15318. op=torch.ops.aten.native_dropout_backward.default,
  15319. aten_name="native_dropout_backward",
  15320. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  15321. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15322. supports_out=False,
  15323. sample_inputs_func=sample_inputs_dropout_backward,
  15324. skips=(
  15325. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  15326. # Lazy tensor failures
  15327. DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
  15328. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
  15329. DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
  15330. ),
  15331. ),
  15332. OpInfo(
  15333. "nn.functional.dropout2d",
  15334. op=lambda input, *args, **kwargs:
  15335. wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs),
  15336. dtypes=floating_types_and(torch.bfloat16),
  15337. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15338. skips=(
  15339. # lambda impl
  15340. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15341. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15342. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
  15343. supports_forward_ad=True,
  15344. supports_fwgrad_bwgrad=True,
  15345. supports_out=False,
  15346. check_batched_forward_grad=False,
  15347. # As per the docs, valid input dims are (3, 4)
  15348. sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)),
  15349. inplace_variant=lambda input, *args, **kwargs:
  15350. wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)),
  15351. OpInfo(
  15352. "nn.functional.dropout3d",
  15353. op=lambda input, *args, **kwargs:
  15354. wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs),
  15355. dtypes=floating_types_and(torch.bfloat16),
  15356. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15357. skips=(
  15358. # lambda impl
  15359. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15360. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15361. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
  15362. supports_forward_ad=True,
  15363. supports_fwgrad_bwgrad=True,
  15364. supports_out=False,
  15365. check_batched_forward_grad=False,
  15366. # As per the docs, valid input dims are (4, 5)
  15367. sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)),
  15368. inplace_variant=lambda input, *args, **kwargs:
  15369. wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)),
  15370. OpInfo(
  15371. "nn.functional.alpha_dropout",
  15372. op=lambda input, *args, **kwargs:
  15373. wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs),
  15374. dtypes=floating_types_and(torch.bfloat16),
  15375. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15376. gradcheck_wrapper=wrapper_set_seed,
  15377. supports_forward_ad=True,
  15378. supports_fwgrad_bwgrad=True,
  15379. supports_out=False,
  15380. sample_inputs_func=sample_inputs_dropout,
  15381. check_batched_forward_grad=False,
  15382. inplace_variant=lambda input, *args, **kwargs:
  15383. wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True),
  15384. skips=(
  15385. # lambda impl
  15386. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15387. # AssertionError: Tensor-likes are not close!
15388. # Fails with CUDA 11.7
  15389. # Error Log: https://github.com/pytorch/pytorch/actions/runs/3440108478/jobs/5738475757
  15390. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
  15391. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
15392. # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype,
15393. # unlike when `train=False`, where complex inputs are supported; hence 2 OpInfos to cover all cases
  15394. OpInfo(
  15395. "nn.functional.feature_alpha_dropout",
  15396. op=lambda input, *args, **kwargs:
  15397. wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
  15398. variant_test_name="with_train",
  15399. dtypes=floating_types_and(torch.bfloat16),
  15400. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15401. skips=(
  15402. # lambda impl
  15403. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15404. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15405. # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
  15406. # vmap: We do not yet support calling random operations inside of vmap.
  15407. # Please perform random operations outside of vmap as a workaround
  15408. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"),
  15409. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"),
  15410. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
  15411. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  15412. gradcheck_fast_mode=True,
  15413. supports_forward_ad=True,
  15414. supports_fwgrad_bwgrad=True,
  15415. supports_out=False,
  15416. # As per the docs, valid input dims are (4, 5)
  15417. sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)),
  15418. inplace_variant=lambda input, *args, **kwargs:
  15419. wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
  15420. OpInfo(
  15421. "nn.functional.feature_alpha_dropout",
  15422. op=lambda input, *args, **kwargs:
  15423. wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
  15424. variant_test_name="without_train",
  15425. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15426. skips=(
  15427. # lambda impl
  15428. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15429. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),
  15430. gradcheck_wrapper=wrapper_set_seed,
  15431. supports_forward_ad=True,
  15432. supports_fwgrad_bwgrad=True,
  15433. supports_out=False,
  15434. sample_inputs_func=partial(sample_inputs_dropout, train=False),
  15435. inplace_variant=lambda input, *args, **kwargs:
  15436. wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
  15437. OpInfo(
  15438. "nn.functional.one_hot",
  15439. ref=reference_one_hot,
  15440. supports_out=False,
  15441. dtypes=_dispatch_dtypes((torch.int64,)),
  15442. sample_inputs_func=sample_inputs_one_hot,
  15443. ),
  15444. OpInfo(
  15445. "nn.functional.embedding",
  15446. aten_backward_name="embedding_dense_backward",
15447. # We use a lambda to reshuffle the positional arguments.
15448. # This is because currently only the `input` field of SampleInput
15449. # is tested in gradient tests.
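# (Added note: with this ordering, SampleInput.input is the `weight` tensor, which is the
# only argument of embedding that can require grad; the integer `idx` never does.)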
  15450. op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs),
  15451. dtypes=floating_types_and(torch.bfloat16, torch.float16),
  15452. sample_inputs_func=sample_inputs_embedding,
  15453. error_inputs_func=error_inputs_embedding,
  15454. supports_forward_ad=True,
  15455. supports_fwgrad_bwgrad=True,
  15456. skips=(
  15457. # lambda impl
  15458. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15459. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15460. # Fails on CI https://github.com/pytorch/pytorch/issues/85377
  15461. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
  15462. # Reference: https://github.com/pytorch/pytorch/issues/67084
  15463. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
  15464. # Not a problem: embedding does weird stuff to its input (it renormalizes)
  15465. DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
  15466. ),
  15467. supports_expanded_weight=True,
  15468. supports_out=False,
  15469. ),
  15470. OpInfo(
  15471. "nn.functional.embedding_bag",
15472. # We use a lambda to reshuffle the positional arguments.
15473. # This is because currently only the `input` field of SampleInput
15474. # is tested in gradient tests.
  15475. op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs),
  15476. dtypes=floating_types_and(torch.bfloat16, torch.float16),
  15477. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  15478. # backward is not supported for mode `max` and dtype `bfloat16`
  15479. backward_dtypesIfCUDA=floating_types_and(torch.float16),
  15480. sample_inputs_func=sample_inputs_embedding_bag,
  15481. skips=(
  15482. # lambda impl
  15483. DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
  15484. DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
  15485. # Not a problem: embedding_bag does weird stuff to its input (it renormalizes)
  15486. DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
  15487. ),
  15488. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  15489. supports_out=False,
  15490. supports_gradgrad=False,
  15491. ),
  15492. UnaryUfuncInfo(
  15493. "nn.functional.softplus",
  15494. aten_backward_name='softplus_backward',
  15495. ref=reference_softplus,
  15496. sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}),
  15497. sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}),
  15498. supports_forward_ad=True,
  15499. supports_fwgrad_bwgrad=True,
  15500. dtypes=floating_types_and(torch.bfloat16),
  15501. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  15502. decorators=(
  15503. DecorateInfo(
  15504. toleranceOverride
  15505. ({
  15506. torch.half: tol(atol=1e-2, rtol=1e-2),
  15507. torch.bfloat16: tol(atol=1e-2, rtol=1e-2),
  15508. }),
  15509. 'TestUnaryUfuncs'),
  15510. ),
  15511. ),
  15512. OpInfo(
  15513. "nn.functional.mse_loss",
  15514. aten_backward_name='mse_loss_backward',
  15515. ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2),
  15516. sample_inputs_func=sample_inputs_loss,
  15517. supports_out=False,
  15518. supports_forward_ad=True,
  15519. supports_fwgrad_bwgrad=True,
  15520. dtypes=floating_types_and(torch.float16),
  15521. backward_dtypes=floating_types(),
  15522. dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  15523. backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
  15524. skips=(
  15525. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  15526. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
  15527. # please report a bug to PyTorch.
  15528. DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
  15529. ),
  15530. ),
  15531. OpInfo(
  15532. "nn.functional.grid_sample",
  15533. dtypes=floating_types(),
  15534. dtypesIfCUDA=floating_types_and(torch.float16),
  15535. supports_out=False,
  15536. sample_inputs_func=sample_inputs_grid_sample,
  15537. supports_gradgrad=False,
  15538. gradcheck_nondet_tol=1e-15),
  15539. # TODO: delete this OpInfo once we add meta support for grid_sampler_3d
  15540. OpInfo(
  15541. "grid_sampler_2d",
  15542. dtypes=floating_types(),
  15543. dtypesIfCUDA=floating_types_and(torch.float16),
  15544. supports_out=False,
  15545. sample_inputs_func=sample_inputs_grid_sampler_2d,
  15546. supports_gradgrad=False,
  15547. gradcheck_nondet_tol=1e-15),
  15548. OpInfo(
  15549. "argwhere",
  15550. ref=np.argwhere,
  15551. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15552. supports_out=False,
  15553. supports_autograd=False,
  15554. sample_inputs_func=sample_inputs_argwhere,
  15555. skips=(
15556. # Compiler issue on ROCm. Might need to skip until ROCm 5.5
  15557. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
  15558. dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
  15559. ),
  15560. ),
  15561. ReductionOpInfo(
  15562. 'all',
  15563. identity=True,
  15564. supports_multiple_dims=False,
  15565. supports_autograd=False,
  15566. result_dtype=torch.bool,
  15567. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15568. ref=reference_reduction_numpy(np.all),
  15569. skips=(
  15570. # FIXME: does not support passing keepdim without dim
  15571. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_default_keepdim'),
  15572. # FIXME: does not support dim=None
  15573. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_none'),
  15574. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_none_keepdim'),
  15575. # FIXME: uint8 input returns uint8 instead of bool
  15576. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
  15577. ),
  15578. ),
  15579. ReductionOpInfo(
  15580. 'any',
  15581. identity=False,
  15582. supports_multiple_dims=False,
  15583. supports_autograd=False,
  15584. result_dtype=torch.bool,
  15585. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15586. ref=reference_reduction_numpy(np.any),
  15587. skips=(
  15588. # FIXME: does not support passing keepdim without dim
  15589. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_default_keepdim'),
  15590. # FIXME: does not support dim=None
  15591. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_none'),
  15592. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_none_keepdim'),
  15593. # FIXME: uint8 input returns uint8 instead of bool
  15594. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
  15595. ),
  15596. ),
  15597. ReductionOpInfo(
  15598. 'amax',
  15599. nan_policy='propagate',
  15600. supports_forward_ad=True,
  15601. check_batched_forward_grad=False,
  15602. supports_fwgrad_bwgrad=True,
  15603. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  15604. ref=reference_reduction_numpy(np.amax),
  15605. skips=(
  15606. # FIXME: reduces all dimensions when dim=[]
  15607. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
  15608. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
  15609. ),
  15610. error_inputs_func=error_inputs_aminmax_amax_amin,
  15611. ),
  15612. ReductionOpInfo(
  15613. 'amin',
  15614. nan_policy='propagate',
  15615. supports_forward_ad=True,
  15616. check_batched_forward_grad=False,
  15617. supports_fwgrad_bwgrad=True,
  15618. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  15619. ref=reference_reduction_numpy(np.amin),
  15620. skips=(
  15621. # FIXME: reduces all dimensions when dim=[]
  15622. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
  15623. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
  15624. ),
  15625. error_inputs_func=error_inputs_aminmax_amax_amin,
  15626. ),
  15627. ReductionOpInfo(
  15628. 'argmax',
  15629. supports_multiple_dims=False,
  15630. supports_autograd=False,
  15631. assert_jit_shape_analysis=True,
  15632. result_dtype=torch.int64,
  15633. dtypes=all_types_and(torch.float16, torch.bfloat16),
  15634. ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
  15635. skips=(
  15636. # FIXME: keepdim parameter is ignored when dim=None
  15637. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15638. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
  15639. ),
  15640. ),
  15641. ReductionOpInfo(
  15642. 'argmin',
  15643. supports_multiple_dims=False,
  15644. supports_autograd=False,
  15645. result_dtype=torch.int64,
  15646. dtypes=all_types_and(torch.float16, torch.bfloat16),
  15647. ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
  15648. skips=(
  15649. # FIXME: keepdim parameter is ignored when dim=None
  15650. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15651. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
  15652. ),
  15653. ),
  15654. ReductionOpInfo(
  15655. 'count_nonzero',
  15656. identity=0,
  15657. supports_out=False,
  15658. supports_autograd=False,
  15659. result_dtype=torch.int64,
  15660. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15661. sample_inputs_func=sample_inputs_reduction_count_nonzero,
  15662. ref=reference_reduction_numpy(np.count_nonzero),
  15663. skips=(
  15664. # FIXME: count_nonzero does not accept keepdim kwarg
  15665. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15666. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
  15667. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'),
  15668. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15669. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'),
  15670. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'),
  15671. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'),
  15672. # FIXME: dim=[] reduces all dimensions
  15673. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15674. ),
  15675. ),
  15676. ReductionOpInfo(
  15677. 'mean',
  15678. nan_policy='propagate',
  15679. supports_forward_ad=True,
  15680. supports_fwgrad_bwgrad=True,
15681. # FIXME: mean needs the 'dim' parameter when using the 'out' overload.
15682. # Adding it with 'generate_args_kwargs' does not work, since these also get passed
15683. # on to the reference implementations.
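# (Illustration of the FIXME above, using hypothetical tensors `t` and `buf`:
#  torch.mean(t, 0, out=buf) resolves to the dim-taking out overload, while, per the
#  note, there is no out= variant of the full reduction that omits dim.)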
  15684. supports_out=False,
  15685. assert_autodiffed=True,
  15686. assert_jit_shape_analysis=True,
  15687. promotes_int_to_float=True,
  15688. dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
  15689. ref=reference_reduction_numpy(np.mean),
  15690. error_inputs_func=error_inputs_mean,
  15691. skips=(
  15692. # FIXME: mean does not support passing keepdim without passing dim
  15693. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15694. # FIXME: mean reduces all dimensions when dim=[]
  15695. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15696. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15697. # FIXME: improve precision
  15698. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15699. dtypes=[torch.float16]),
  15700. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
  15701. device_type='cuda', dtypes=[torch.complex64]),
  15702. ),
  15703. ),
  15704. ReductionOpInfo(
  15705. 'nanmean',
  15706. nan_policy='omit',
  15707. assert_autodiffed=True,
  15708. promotes_int_to_float=True,
  15709. supports_forward_ad=True,
  15710. check_batched_forward_grad=False,
  15711. supports_fwgrad_bwgrad=True,
  15712. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  15713. sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
  15714. ref=reference_reduction_numpy(np.nanmean),
  15715. skips=(
  15716. # AssertionError: False is not true :
  15717. # Failure in testing nodes' autodifferentiation.
  15718. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
15719. # FIXME: nanmean reduces all dimensions when dim=[]
  15720. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15721. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15722. # FIXME: improve precision
  15723. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15724. dtypes=[torch.float16]),
  15725. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
  15726. device_type='cuda', dtypes=[torch.float16]),
  15727. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
  15728. device_type='cuda', dtypes=[torch.complex64]),
  15729. ),
  15730. ),
  15731. ReductionOpInfo(
  15732. 'std',
  15733. nan_policy='propagate',
  15734. supports_out=True,
  15735. complex_to_real=True,
  15736. supports_forward_ad=True,
  15737. supports_fwgrad_bwgrad=True,
  15738. assert_autodiffed=True,
  15739. promotes_int_to_float=True,
  15740. check_batched_forward_grad=False,
  15741. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15742. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15743. sample_inputs_func=sample_inputs_std_var,
  15744. ref=reference_std_var(np.std),
  15745. generate_args_kwargs=generate_std_var_kwargs,
  15746. skips=(
  15747. # FIXME: cannot specify keepdim without dim
  15748. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15749. # FIXME: dim=[] reduces all dimensions
  15750. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15751. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15752. # FIXME: improve precision
  15753. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15754. dtypes=(torch.float16,)),
  15755. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
  15756. dtypes=(torch.float16,)),
  15757. ),
  15758. ),
  15759. ReductionOpInfo(
  15760. 'std',
  15761. variant_test_name='unbiased',
  15762. nan_policy='propagate',
  15763. supports_out=False,
  15764. complex_to_real=True,
  15765. supports_forward_ad=True,
  15766. supports_fwgrad_bwgrad=True,
  15767. assert_autodiffed=True,
  15768. promotes_int_to_float=True,
  15769. check_batched_forward_grad=False,
  15770. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15771. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15772. sample_inputs_func=sample_inputs_std_var_unbiased,
  15773. skips=(
  15774. # FIXME: dim=[] reduces all dimensions
  15775. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15776. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15777. ),
  15778. ),
  15779. ReductionOpInfo(
  15780. 'var',
  15781. nan_policy='propagate',
  15782. supports_out=True,
  15783. assert_autodiffed=True,
  15784. promotes_int_to_float=True,
  15785. complex_to_real=True,
  15786. supports_forward_ad=True,
  15787. supports_fwgrad_bwgrad=True,
  15788. check_batched_forward_grad=False,
  15789. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15790. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15791. sample_inputs_func=sample_inputs_std_var,
  15792. ref=reference_std_var(np.var),
  15793. generate_args_kwargs=generate_std_var_kwargs,
  15794. skips=(
  15795. # FIXME: cannot specify keepdim without dim
  15796. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15797. # FIXME: dim=[] reduces all dimensions
  15798. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15799. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15800. # FIXME: improve precision
  15801. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
  15802. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
  15803. # NumPy is giving NaN for this
  15804. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
  15805. ),
  15806. ),
  15807. ReductionOpInfo(
  15808. 'var',
  15809. variant_test_name='unbiased',
  15810. nan_policy='propagate',
  15811. supports_out=False,
  15812. complex_to_real=True,
  15813. supports_forward_ad=True,
  15814. supports_fwgrad_bwgrad=True,
  15815. assert_autodiffed=True,
  15816. promotes_int_to_float=True,
  15817. check_batched_forward_grad=False,
  15818. dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15819. dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
  15820. sample_inputs_func=sample_inputs_std_var_unbiased,
  15821. skips=(
  15822. # FIXME: dim=[] reduces all dimensions
  15823. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15824. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15825. ),
  15826. ),
  15827. ReductionOpInfo(
  15828. 'prod',
  15829. identity=1,
  15830. nan_policy='propagate',
  15831. supports_multiple_dims=False,
  15832. # https://github.com/pytorch/pytorch/issues/80411
  15833. gradcheck_fast_mode=True,
  15834. supports_out=False,
  15835. supports_forward_ad=True,
  15836. supports_fwgrad_bwgrad=True,
  15837. promotes_int_to_int64=True,
  15838. gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
  15839. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
  15840. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  15841. sample_inputs_func=sample_inputs_prod,
  15842. ref=prod_numpy,
  15843. skips=(
  15844. # FIXME: prod does not support passing keepdim without passing dim
  15845. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15846. # FIXME: prod reduces all dimensions when dim=[]
  15847. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15848. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15849. # FIXME: prod does not support passing None to dim
  15850. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
  15851. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
  15852. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15853. dtypes=[torch.float16, torch.complex64]),
  15854. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
  15855. dtypes=[torch.uint8, torch.float16, torch.complex64]),
  15856. # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match
  15857. DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
  15858. dtypes=[torch.float16]),
  15859. ),
  15860. ),
  15861. ReductionOpInfo(
  15862. 'sum',
  15863. identity=0,
  15864. nan_policy='propagate',
  15865. supports_out=False,
  15866. supports_forward_ad=True,
  15867. supports_fwgrad_bwgrad=True,
  15868. supports_sparse=True,
  15869. promotes_int_to_int64=True,
  15870. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  15871. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  15872. ref=reference_reduction_numpy(np.sum),
  15873. sample_inputs_sparse_coo_func=partial(sample_inputs_reduction_sparse, layout=torch.sparse_coo),
  15874. sample_inputs_sparse_csr_func=partial(sample_inputs_reduction_sparse, layout=torch.sparse_csr),
  15875. sample_inputs_sparse_csc_func=partial(sample_inputs_reduction_sparse, layout=torch.sparse_csc),
  15876. sample_inputs_sparse_bsr_func=partial(sample_inputs_reduction_sparse, layout=torch.sparse_bsr),
  15877. sample_inputs_sparse_bsc_func=partial(sample_inputs_reduction_sparse, layout=torch.sparse_bsc),
  15878. skips=(
  15879. # FIXME: sum does not support passing keepdim without passing dim
  15880. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
  15881. # FIXME: sum reduces all dimensions when dim=[]
  15882. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
  15883. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
  15884. # FIXME: improve precision
  15885. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15886. dtypes=[torch.float16]),
  15887. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
  15888. dtypes=[torch.float16]),
  15889. DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
  15890. dtypes=[torch.float32]),
  15891. ),
  15892. ),
  15893. ReductionOpInfo(
  15894. 'nansum',
  15895. identity=0,
  15896. nan_policy='omit',
  15897. supports_out=True,
  15898. promotes_int_to_int64=True,
  15899. supports_forward_ad=True,
  15900. check_batched_forward_grad=False,
  15901. supports_fwgrad_bwgrad=True,
  15902. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  15903. sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
  15904. ref=reference_reduction_numpy(np.nansum),
  15905. skips=(
  15906. # please report a bug to PyTorch.
  15907. DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
  15908. # FIXME: nansum reduces all dimensions when dim=[]
  15909. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
  15910. DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
  15911. # FIXME: flaky test so skipped instead of xfailed
  15912. # possibly bad low precision reference in numpy
  15913. DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
  15914. dtypes=[torch.float16]),
  15915. ),
  15916. ),
  15917. OpInfo(
  15918. "nn.functional.ctc_loss",
  15919. dtypes=floating_types(),
  15920. supports_out=False,
  15921. sample_inputs_func=sample_inputs_ctc_loss,
  15922. skips=(
  15923. # https://github.com/pytorch/pytorch/issues/67462
  15924. # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0
  15925. DecorateInfo(
  15926. unittest.expectedFailure,
  15927. "TestBwdGradients",
  15928. "test_fn_grad",
  15929. dtypes=(torch.float64,),
  15930. ),
  15931. # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
  15932. DecorateInfo(
  15933. unittest.expectedFailure,
  15934. "TestBwdGradients",
  15935. "test_fn_gradgrad",
  15936. dtypes=(torch.float64,),
  15937. ),
  15938. # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
  15939. DecorateInfo(
  15940. unittest.skip("Skipped!"),
  15941. "TestJit",
  15942. "test_variant_consistency_jit",
  15943. dtypes=(torch.float32,),
  15944. ),
  15945. # Ref: https://github.com/pytorch/pytorch/issues/85231
  15946. DecorateInfo(unittest.skip("Fails with ASAN"),
  15947. 'TestProxyTensorOpInfo',
  15948. 'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN),
  15949. ),
  15950. ),
  15951. OpInfo(
  15952. "nn.functional.cosine_embedding_loss",
  15953. dtypes=all_types_and(torch.bfloat16, torch.bool),
  15954. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  15955. supports_out=False,
  15956. supports_forward_ad=True,
  15957. supports_fwgrad_bwgrad=True,
  15958. sample_inputs_func=sample_inputs_cosine_embedding_loss,
  15959. ),
  15960. OpInfo(
  15961. "nn.functional.nll_loss",
  15962. dtypes=floating_types_and(torch.bfloat16),
  15963. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15964. supports_out=False,
  15965. sample_inputs_func=sample_inputs_nll_loss,
  15966. supports_forward_ad=True,
  15967. supports_fwgrad_bwgrad=True,
  15968. assert_jit_shape_analysis=True,
  15969. skips=(
  15970. # RuntimeError:
  15971. # undefined value tensor:
  15972. # File "<string>", line 3
  15973. # def the_method(i0, i1):
  15974. # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32))
  15975. # ~~~~~~ <--- HERE
  15976. DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
  15977. ),
  15978. ),
  15979. OpInfo(
  15980. "nn.functional.gaussian_nll_loss",
  15981. dtypes=floating_types_and(torch.bfloat16),
  15982. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  15983. # Runs very slowly on slow gradcheck - alternatively reduce input sizes
  15984. gradcheck_fast_mode=True,
  15985. supports_out=False,
  15986. supports_forward_ad=True,
  15987. supports_fwgrad_bwgrad=True,
  15988. sample_inputs_func=sample_inputs_gaussian_nll_loss,
  15989. error_inputs_func=error_inputs_gaussian_nll_loss,
  15990. skips=(
  15991. # Pre-existing condition (calls .item); needs to be fixed
  15992. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
  15993. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
  15994. # Pre-existing condition (calls .item); needs to be fixed
  15995. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
  15996. # JIT does not support variadic tensors.
  15997. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  15998. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
  15999. # please report a bug to PyTorch.
  16000. DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
  16001. ),
  16002. decorators=(
  16003. DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02),
  16004. torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
  16005. 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
  16006. )
  16007. ),
  16008. OpInfo(
  16009. "nn.functional.hinge_embedding_loss",
  16010. dtypes=floating_types_and(torch.bfloat16),
  16011. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  16012. supports_out=False,
  16013. supports_forward_ad=True,
  16014. supports_fwgrad_bwgrad=True,
  16015. sample_inputs_func=sample_inputs_hinge_embedding_loss,
  16016. error_inputs_func=error_inputs_hinge_embedding_loss,
  16017. reference_inputs_func=reference_inputs_hinge_embedding_loss,
  16018. ),
  16019. OpInfo(
  16020. "nn.functional.huber_loss",
  16021. aten_backward_name='huber_loss_backward',
  16022. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  16023. supports_out=False,
  16024. supports_forward_ad=True,
  16025. sample_inputs_func=sample_inputs_huber_loss,
  16026. error_inputs_func=error_inputs_huber_loss,
  16027. skips=(
  16028. # JIT does not support variadic tensors.
  16029. # RuntimeError: input->type()->kind() == TypeKind::OptionalType
  16030. # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
  16031. # please report a bug to PyTorch.
  16032. DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
  16033. )
  16034. ),
  16035. OpInfo(
  16036. "nn.functional.pdist",
  16037. ref=reference_pdist,
  16038. sample_inputs_func=sample_inputs_pdist,
  16039. dtypes=floating_types(),
  16040. supports_out=False,
  16041. supports_gradgrad=False,
  16042. skips=(
  16043. DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
  16044. )
  16045. ),
  16046. OpInfo(
  16047. "nn.functional.poisson_nll_loss",
  16048. dtypes=all_types_and(torch.bfloat16),
  16049. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16050. supports_out=False,
  16051. supports_forward_ad=True,
  16052. supports_fwgrad_bwgrad=True,
  16053. sample_inputs_func=sample_inputs_poisson_nll_loss,
  16054. error_inputs_func=error_inputs_poisson_nll_loss,
  16055. ),
  16056. OpInfo(
  16057. "argsort",
  16058. dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
  16059. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16060. sample_inputs_func=sample_inputs_argsort,
  16061. supports_out=False,
  16062. supports_autograd=False,
  16063. skips=(
  16064. DecorateInfo(
  16065. unittest.skip("Skipped!"),
  16066. "TestJit",
  16067. "test_variant_consistency_jit",
  16068. dtypes=(torch.float32,),
  16069. ),
  16070. ),
  16071. ),
  16072. OpInfo(
  16073. "repeat_interleave",
  16074. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
  16075. backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
  16076. sample_inputs_func=sample_inputs_repeat_interleave,
  16077. supports_out=False,
  16078. supports_forward_ad=True,
  16079. supports_fwgrad_bwgrad=True,
  16080. # See https://github.com/pytorch/pytorch/pull/78358
  16081. check_batched_forward_grad=False,
  16082. skips=(
  16083. DecorateInfo(
  16084. unittest.skip("Skipped!"),
  16085. "TestJit",
  16086. "test_variant_consistency_jit",
  16087. dtypes=(torch.float32, torch.complex64),
  16088. ),
  16089. ),
  16090. ),
  16091. OpInfo(
  16092. "nn.functional.pairwise_distance",
  16093. ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: (
  16094. np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p)
  16095. ),
  16096. sample_inputs_func=sample_inputs_pairwise_distance,
  16097. dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
  16098. supports_out=False,
  16099. supports_forward_ad=True,
  16100. supports_fwgrad_bwgrad=True,
  16101. skips=(
  16102. DecorateInfo(
  16103. unittest.skip("Skipped!"),
  16104. "TestJit",
  16105. "test_variant_consistency_jit",
  16106. dtypes=(torch.float32, torch.complex64),
  16107. ),
  16108. ),
  16109. ),
  16110. OpInfo(
  16111. "nn.functional.pixel_shuffle",
  16112. sample_inputs_func=sample_inputs_pixel_shuffle,
  16113. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  16114. supports_out=False,
  16115. supports_forward_ad=True,
  16116. supports_fwgrad_bwgrad=True,
  16117. skips=(
  16118. DecorateInfo(
  16119. unittest.skip("Skipped!"),
  16120. "TestJit",
  16121. "test_variant_consistency_jit",
  16122. dtypes=(torch.float32, torch.complex64),
  16123. ),
  16124. ),
  16125. ),
  16126. OpInfo(
  16127. "nn.functional.pixel_unshuffle",
  16128. sample_inputs_func=sample_inputs_pixel_unshuffle,
  16129. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  16130. supports_out=False,
  16131. supports_forward_ad=True,
  16132. supports_fwgrad_bwgrad=True,
  16133. skips=(
  16134. DecorateInfo(
  16135. unittest.skip("Skipped!"),
  16136. "TestJit",
  16137. "test_variant_consistency_jit",
  16138. dtypes=(torch.float32, torch.complex64),
  16139. ),
  16140. ),
  16141. ),
  16142. OpInfo(
  16143. "nn.functional.kl_div",
  16144. sample_inputs_func=sample_inputs_kl_div,
  16145. dtypes=floating_types_and(torch.bfloat16),
  16146. dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
  16147. supports_out=False,
  16148. supports_forward_ad=True,
  16149. supports_fwgrad_bwgrad=True,
  16150. ),
  16151. OpInfo(
  16152. "diagflat",
  16153. ref=lambda input, offset=0: np.diagflat(input, k=offset),
  16154. sample_inputs_func=sample_inputs_diagflat,
  16155. dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
  16156. dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  16157. supports_out=False,
  16158. supports_forward_ad=True,
  16159. supports_fwgrad_bwgrad=True,
  16160. # See https://github.com/pytorch/pytorch/pull/78358
  16161. check_batched_forward_grad=False,
  16162. ),
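# The next five entries all exercise scatter_reduce; variant_test_name selects the reduction
# mode under test (sum, prod, mean, amin, amax), and each entry records the dtype and
# autograd support specific to that mode.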
  16163. OpInfo(
  16164. 'scatter_reduce',
  16165. variant_test_name='sum',
  16166. # complex not added to dtypes as complex gradients are not properly handled
  16167. # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
  16168. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  16169. supports_forward_ad=True,
  16170. supports_fwgrad_bwgrad=True,
  16171. sample_inputs_func=sample_inputs_scatter_reduce,
  16172. ),
  16173. OpInfo(
  16174. 'scatter_reduce',
  16175. variant_test_name='prod',
  16176. # complex not added to dtypes as complex gradients are not properly handled
  16177. # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
  16178. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  16179. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16180. sample_inputs_func=sample_inputs_scatter_reduce,
  16181. skips=(
  16182. # Pre-existing condition (calls .item); needs to be fixed
  16183. DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
  16184. # Not implemented
  16185. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
  16186. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
  16187. DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
  16188. ),
  16189. ),
  16190. OpInfo(
  16191. 'scatter_reduce',
  16192. variant_test_name='mean',
  16193. # complex not added to dtypes as complex gradients are not properly handled
  16194. # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
  16195. dtypes=all_types_and(torch.float16, torch.bfloat16),
  16196. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16197. supports_forward_ad=True,
  16198. supports_fwgrad_bwgrad=True,
  16199. sample_inputs_func=sample_inputs_scatter_reduce,
  16200. ),
  16201. OpInfo(
  16202. 'scatter_reduce',
  16203. variant_test_name='amin',
  16204. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  16205. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16206. supports_forward_ad=True,
  16207. check_batched_forward_grad=False,
  16208. supports_fwgrad_bwgrad=True,
  16209. sample_inputs_func=sample_inputs_scatter_reduce,
  16210. ),
  16211. OpInfo(
  16212. 'scatter_reduce',
  16213. variant_test_name='amax',
  16214. dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
  16215. dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
  16216. supports_forward_ad=True,
  16217. check_batched_forward_grad=False,
  16218. supports_fwgrad_bwgrad=True,
  16219. sample_inputs_func=sample_inputs_scatter_reduce,
  16220. ),
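# _segment_reduce is covered by two variants that differ only in how segments are described:
# 'lengths' uses the default sample inputs, while 'offsets' passes mode='offsets' to the same
# sample-input function.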
  16221. OpInfo(
  16222. '_segment_reduce',
  16223. aten_name='segment_reduce',
  16224. variant_test_name='lengths',
  16225. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  16226. supports_out=False,
  16227. # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
  16228. supports_gradgrad=False,
  16229. sample_inputs_func=sample_inputs_segment_reduce,
  16230. skips=(
  16231. # FIXME: CUDA driver API confirmed a leak in
  16232. # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
  16233. DecorateInfo(
  16234. unittest.skip("Skipped!"),
  16235. "TestJit",
  16236. "test_variant_consistency_jit",
  16237. device_type="cuda",
  16238. ),
  16239. ),
  16240. ),
  16241. OpInfo(
  16242. '_segment_reduce',
  16243. aten_name='segment_reduce',
  16244. variant_test_name='offsets',
  16245. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  16246. supports_out=False,
  16247. # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
  16248. supports_gradgrad=False,
  16249. sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
  16250. skips=(
  16251. # FIXME: CUDA driver API confirmed a leak in
  16252. # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
  16253. DecorateInfo(
  16254. unittest.skip("Skipped!"),
  16255. "TestJit",
  16256. "test_variant_consistency_jit",
  16257. device_type="cuda",
  16258. ),
  16259. ),
  16260. ),
  16261. ]
  16262. op_db += opinfo.definitions.op_db
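# Illustrative sketch (not part of this module): op_db is consumed by the OpInfo test suites,
# which parametrize a test over every entry, device, and supported dtype via the @ops
# decorator from torch.testing._internal.common_device_type, roughly like:
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests, ops)
#     from torch.testing._internal.common_methods_invocations import op_db
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestExample(TestCase):
#         @ops(op_db)
#         def test_example(self, device, dtype, op):
#             # op.op is the function variant; sample_inputs yields SampleInput objects
#             for sample in op.sample_inputs(device, dtype):
#                 op.op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestExample, globals())
#
#     if __name__ == "__main__":
#         run_tests()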
  16263. # Separate registry for experimental Python Reference OpInfos.
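# Each entry below references its eager counterpart through torch_opinfo_name (and, where
# needed, torch_opinfo_variant_name), so dtype coverage and sample inputs are taken from the
# corresponding OpInfo above, with ref-specific flags and skips layered on top.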
  16264. python_ref_db = [
  16265. #
  16266. # Elementwise Unary OpInfos
  16267. #
  16268. ElementwiseUnaryPythonRefInfo(
  16269. "_refs.abs",
  16270. torch_opinfo_name="abs",
  16271. skips=(
  16272. # Reference result was farther (0.0) from the precise computation
  16273. # than the torch result was (nan)!
  16274. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  16275. dtypes=(torch.chalf,), device_type='cpu', active_if=not (IS_MACOS or IS_WINDOWS)),
  16276. # Reference result was farther (0.0) from the precise computation
  16277. # than the torch result was (nan)!
  16278. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  16279. dtypes=(torch.chalf,), device_type='cpu', active_if=not (IS_MACOS or IS_WINDOWS)),
  16280. )
  16281. ),
  16282. ElementwiseUnaryPythonRefInfo(
  16283. "_refs.acos",
  16284. torch_opinfo_name="acos",
  16285. ),
  16286. ElementwiseUnaryPythonRefInfo(
  16287. "_refs.acosh",
  16288. torch_opinfo_name="acosh",
  16289. supports_nvfuser=False,
  16290. ),
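# supports_nvfuser=False (used on this and many entries below) excludes the ref from the
# nvFuser executor path of test_python_ref_executor; the remaining reference tests still run.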
  16291. ElementwiseUnaryPythonRefInfo(
  16292. "_refs.asin",
  16293. torch_opinfo_name="asin",
  16294. ),
  16295. ElementwiseUnaryPythonRefInfo(
  16296. "_refs.asinh",
  16297. torch_opinfo_name="asinh",
  16298. supports_nvfuser=False,
  16299. ),
  16300. PythonRefInfo(
  16301. "_refs.lerp",
  16302. torch_opinfo_name="lerp",
  16303. supports_nvfuser=False,
  16304. ),
  16305. PythonRefInfo(
  16306. "_refs.ones",
  16307. torch_opinfo_name="ones",
  16308. skips=(
  16309. # Tests that assume input is a tensor or sequence of tensors
  16310. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16311. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  16312. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  16313. ),
  16314. supports_nvfuser=False,
  16315. ),
  16316. PythonRefInfo(
  16317. "_refs.zeros",
  16318. torch_opinfo_name="zeros",
  16319. skips=(
  16320. # Tests that assume input is a tensor or sequence of tensors
  16321. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16322. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  16323. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  16324. ),
  16325. supports_nvfuser=False,
  16326. ),
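# The random-sampling refs below (cauchy, exponential, geometric, log_normal) produce
# non-deterministic output, so their decorators skip or xfail the tests that compare values
# against eager, out=, or CPU results.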
  16327. PythonRefInfo(
  16328. "_refs.cauchy",
  16329. torch_opinfo_name="cauchy",
  16330. decorators=(
  16331. # TODO: RuntimeError: no _refs support for torch.rand_like
  16332. DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
  16333. 'TestCommon',
  16334. 'test_python_ref'),
  16335. # AssertionError: Tensor-likes are not close!
  16336. DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
  16337. 'TestCommon',
  16338. 'test_out'),
  16339. DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
  16340. 'TestCommon',
  16341. 'test_out_warning'),
  16342. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
  16343. DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
  16344. 'TestCommon',
  16345. 'test_python_ref_torch_fallback'),
  16346. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  16347. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16348. )
  16349. ),
  16350. PythonRefInfo(
  16351. "_refs.exponential",
  16352. torch_opinfo_name="exponential",
  16353. supports_out=True,
  16354. decorators=(
  16355. # dtypes that do not support check_uniform_bounds of rand_like
  16356. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
  16357. dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
  16358. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
  16359. # TODO: RuntimeError: no _refs support for torch.rand_like
  16360. DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
  16361. 'TestCommon',
  16362. 'test_python_ref'),
  16363. # AssertionError: Tensor-likes are not close!
  16364. DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
  16365. 'TestCommon',
  16366. 'test_out'),
  16367. DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
  16368. 'TestCommon',
  16369. 'test_out_warning'),
  16370. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
  16371. DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
  16372. 'TestCommon',
  16373. 'test_python_ref_torch_fallback'),
  16374. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  16375. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16376. )
  16377. ),
  16378. PythonRefInfo(
  16379. "_refs.geometric",
  16380. torch_opinfo_name="geometric",
  16381. supports_out=True,
  16382. decorators=(
  16383. # dtypes that do not support check_uniform_bounds of rand_like
  16384. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
  16385. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
  16386. dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
  16387. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  16388. dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
  16389. # TODO: RuntimeError: no _refs support for torch.rand_like
  16390. DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
  16391. 'TestCommon',
  16392. 'test_python_ref'),
  16393. DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
  16394. 'TestCommon',
  16395. 'test_python_ref_executor', device_type='cuda'),
  16396. # AssertionError: Tensor-likes are not close!
  16397. DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
  16398. 'TestCommon',
  16399. 'test_out'),
  16400. DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
  16401. 'TestCommon',
  16402. 'test_out_warning'),
  16403. DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
  16404. 'TestCommon',
  16405. 'test_python_ref_torch_fallback'),
  16406. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  16407. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16408. )
  16409. ),
  16410. PythonRefInfo(
  16411. "_refs.log_normal",
  16412. torch_opinfo_name="log_normal",
  16413. supports_out=True,
  16414. decorators=(
  16415. # TODO: RuntimeError: no _refs support for torch.rand_like
  16416. DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
  16417. 'TestCommon',
  16418. 'test_python_ref'),
  16419. DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
  16420. 'TestCommon',
  16421. 'test_python_ref_executor', device_type='cuda'),
  16422. # AssertionError: Tensor-likes are not close!
  16423. DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
  16424. 'TestCommon',
  16425. 'test_out'),
  16426. DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
  16427. 'TestCommon',
  16428. 'test_out_warning'),
  16429. DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
  16430. 'TestCommon',
  16431. 'test_python_ref_torch_fallback'),
  16432. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  16433. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16434. )
  16435. ),
  16436. PythonRefInfo(
  16437. "_refs.arange",
  16438. torch_opinfo_name="arange",
  16439. skips=(
  16440. # Tests that assume input is a tensor or sequence of tensors
  16441. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16442. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  16443. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  16444. ),
  16445. supports_nvfuser=False,
  16446. ),
  16447. PythonRefInfo(
  16448. "_refs.linspace",
  16449. torch_opinfo_name="linspace",
  16450. skips=(
  16451. # Tests that assume input is a tensor or sequence of tensors
  16452. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16453. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  16454. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  16455. # cpu implementation is wrong on some integral types
  16456. # https://github.com/pytorch/pytorch/issues/81996
  16457. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  16458. dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
  16459. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  16460. dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
  16461. # cuda implementation is off-by-one on some inputs due to precision issues
  16462. # https://github.com/pytorch/pytorch/issues/82230
  16463. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  16464. dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
  16465. device_type="cuda"),
  16466. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  16467. dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
  16468. device_type="cuda"),
  16469. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
  16470. dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
  16471. device_type="cuda"),
  16472. ),
  16473. supports_nvfuser=False,
  16474. ),
  16475. PythonRefInfo(
  16476. "_refs.logspace",
  16477. torch_opinfo_name="logspace",
  16478. skips=(
  16479. # Tests that assume input is a tensor or sequence of tensors
  16480. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
  16481. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
  16482. DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
  16483. # Off-by-one issue when casting floats to ints
  16484. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  16485. dtypes=(torch.int16, torch.int32, torch.int64),
  16486. device_type="cuda"),
  16487. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  16488. dtypes=(torch.int16, torch.int32, torch.int64),
  16489. device_type="cuda"),
  16490. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
  16491. dtypes=(torch.int16, torch.int32, torch.int64),
  16492. device_type="cuda"),
  16493. ),
  16494. supports_nvfuser=False,
  16495. ),
  16496. PythonRefInfo(
  16497. "_refs.meshgrid",
  16498. torch_opinfo_name="meshgrid",
  16499. torch_opinfo_variant_name="variadic_tensors",
  16500. supports_nvfuser=False,
  16501. ),
  16502. PythonRefInfo(
  16503. "_refs.to",
  16504. torch_opinfo_name="to",
  16505. supports_nvfuser=False,
  16506. ),
  16507. PythonRefInfo(
  16508. "_refs.triu",
  16509. torch_opinfo_name="triu",
  16510. supports_nvfuser=False,
  16511. ),
  16512. PythonRefInfo(
  16513. "_refs.tril",
  16514. torch_opinfo_name="tril",
  16515. supports_nvfuser=False,
  16516. ),
  16517. PythonRefInfo(
  16518. "_refs.triu_indices",
  16519. torch_opinfo_name="triu_indices",
  16520. supports_nvfuser=False,
# the implementation uses torch.stack, which violates view consistency
validate_view_consistency=False,
skips=(
# skip these tests since the inputs are not tensors
  16525. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
  16526. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  16527. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  16528. DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
  16529. )),
  16530. PythonRefInfo(
  16531. "_refs.tril_indices",
  16532. torch_opinfo_name="tril_indices",
  16533. supports_nvfuser=False,
# the implementation uses torch.stack, which violates view consistency
validate_view_consistency=False,
skips=(
# skip these tests since the inputs are not tensors
  16538. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
  16539. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
  16540. DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
  16541. DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
  16542. )),
  16543. PythonRefInfo(
  16544. "_refs.meshgrid",
  16545. torch_opinfo_name="meshgrid",
  16546. torch_opinfo_variant_name="list_of_tensors",
  16547. supports_nvfuser=False,
  16548. ),
  16549. PythonRefInfo(
  16550. "_refs.movedim",
  16551. aliases=('moveaxis',),
  16552. torch_opinfo_name="movedim",
  16553. supports_nvfuser=False,
  16554. ),
  16555. PythonRefInfo(
  16556. "_refs.bucketize",
  16557. torch_opinfo_name="bucketize",
  16558. skips=(
  16559. # RuntimeError: It appears that you're trying to get value out of a tracing tensor with
  16560. # aten._local_scalar_dense.default - erroring out! [...]
  16561. # triggered by mid_val = boundaries[mid]
  16562. DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"),
  16563. )
  16564. ),
  16565. ElementwiseUnaryPythonRefInfo(
  16566. "_refs.atan",
  16567. torch_opinfo_name="atan",
  16568. ),
  16569. ElementwiseUnaryPythonRefInfo(
  16570. "_refs.atanh",
  16571. torch_opinfo_name="atanh",
  16572. ),
  16573. ElementwiseUnaryPythonRefInfo(
  16574. "_refs.bitwise_not",
  16575. torch_opinfo_name="bitwise_not",
  16576. ),
  16577. ElementwiseUnaryPythonRefInfo(
  16578. "_refs.ceil",
  16579. torch_opinfo_name="ceil",
  16580. # Fails on int32
  16581. # https://github.com/pytorch/pytorch/issues/85258
  16582. supports_nvfuser=False,
  16583. ),
  16584. ElementwiseUnaryPythonRefInfo(
  16585. "_refs.conj_physical",
  16586. torch_opinfo_name="conj_physical",
  16587. supports_nvfuser=False,
  16588. ),
  16589. ElementwiseUnaryPythonRefInfo(
  16590. "_refs.cos",
  16591. torch_opinfo_name="cos",
  16592. ),
  16593. ElementwiseUnaryPythonRefInfo(
  16594. "_refs.cosh",
  16595. torch_opinfo_name="cosh",
  16596. ),
  16597. ElementwiseUnaryPythonRefInfo(
  16598. "_refs.digamma",
  16599. torch_opinfo_name="digamma",
  16600. supports_nvfuser=False,
  16601. ),
  16602. ElementwiseUnaryPythonRefInfo(
  16603. "_refs.erf",
  16604. torch_opinfo_name="erf",
  16605. ),
  16606. ElementwiseUnaryPythonRefInfo(
  16607. "_refs.erfinv",
  16608. torch_opinfo_name="erfinv",
  16609. supports_nvfuser=False,
  16610. ),
  16611. ElementwiseUnaryPythonRefInfo(
  16612. "_refs.erfc",
  16613. torch_opinfo_name="erfc",
  16614. ),
  16615. ElementwiseUnaryPythonRefInfo(
  16616. "_refs.exp",
  16617. torch_opinfo_name="exp",
  16618. ),
  16619. ElementwiseUnaryPythonRefInfo(
  16620. "_refs.expm1",
  16621. torch_opinfo_name="expm1",
  16622. ),
  16623. ElementwiseUnaryPythonRefInfo(
  16624. "_refs.exp2",
  16625. torch_opinfo_name="exp2",
  16626. supports_nvfuser=False,
  16627. ),
  16628. ElementwiseUnaryPythonRefInfo(
  16629. "_refs.fill",
  16630. torch_opinfo_name="fill",
  16631. supports_out=True,
  16632. supports_nvfuser=False,
  16633. ),
  16634. ElementwiseUnaryPythonRefInfo(
  16635. "_refs.floor",
  16636. torch_opinfo_name="floor",
  16637. # Fails on int32
  16638. # https://github.com/pytorch/pytorch/issues/85258
  16639. supports_nvfuser=False,
  16640. ),
  16641. ElementwiseUnaryPythonRefInfo(
  16642. "_refs.frac",
  16643. torch_opinfo_name="frac",
  16644. supports_nvfuser=False,
  16645. ),
  16646. ElementwiseUnaryPythonRefInfo(
  16647. "_refs.imag",
  16648. torch_opinfo_name="imag",
  16649. supports_nvfuser=False,
  16650. ),
  16651. ElementwiseUnaryPythonRefInfo(
  16652. "_refs.isfinite",
  16653. torch_opinfo_name="isfinite",
  16654. supports_out=True,
  16655. supports_nvfuser=False,
  16656. ),
  16657. ElementwiseUnaryPythonRefInfo(
  16658. "_refs.isinf",
  16659. torch_opinfo_name="isinf",
  16660. supports_out=True,
  16661. supports_nvfuser=False,
  16662. ),
  16663. ElementwiseUnaryPythonRefInfo(
  16664. "_refs.isposinf",
  16665. torch_opinfo_name="isposinf",
  16666. supports_out=True,
  16667. supports_nvfuser=False,
  16668. ),
  16669. ElementwiseUnaryPythonRefInfo(
  16670. "_refs.isneginf",
  16671. torch_opinfo_name="isneginf",
  16672. supports_out=True,
  16673. supports_nvfuser=False,
  16674. ),
  16675. ElementwiseUnaryPythonRefInfo(
  16676. "_refs.isnan",
  16677. torch_opinfo_name="isnan",
  16678. supports_out=True,
  16679. ),
  16680. ElementwiseUnaryPythonRefInfo(
  16681. "_refs.isreal",
  16682. torch_opinfo_name="isreal",
  16683. supports_out=True,
  16684. supports_nvfuser=False,
  16685. ),
  16686. ElementwiseUnaryPythonRefInfo(
  16687. "_refs.i0",
  16688. torch_opinfo_name="i0",
  16689. supports_nvfuser=False,
  16690. ),
  16691. ElementwiseUnaryPythonRefInfo(
  16692. "_refs.lgamma",
  16693. torch_opinfo_name="lgamma",
  16694. ),
  16695. ElementwiseUnaryPythonRefInfo(
  16696. "_refs.special.multigammaln",
  16697. torch_opinfo_name="mvlgamma",
  16698. torch_opinfo_variant_name="mvlgamma_p_1",
  16699. supports_nvfuser=False,
  16700. ),
  16701. ElementwiseUnaryPythonRefInfo(
  16702. "_refs.special.multigammaln",
  16703. torch_opinfo_name="mvlgamma",
  16704. torch_opinfo_variant_name="mvlgamma_p_3",
  16705. supports_nvfuser=False,
  16706. ),
  16707. ElementwiseUnaryPythonRefInfo(
  16708. "_refs.special.multigammaln",
  16709. torch_opinfo_name="mvlgamma",
  16710. torch_opinfo_variant_name="mvlgamma_p_5",
  16711. supports_nvfuser=False,
  16712. ),
  16713. ElementwiseUnaryPythonRefInfo(
  16714. "_refs.log",
  16715. torch_opinfo_name="log",
  16716. ),
  16717. ElementwiseUnaryPythonRefInfo(
  16718. "_refs.log1p",
  16719. torch_opinfo_name="log1p",
  16720. ),
  16721. ElementwiseUnaryPythonRefInfo(
  16722. "_refs.log10",
  16723. torch_opinfo_name="log10",
  16724. ),
  16725. ElementwiseUnaryPythonRefInfo(
  16726. "_refs.log2",
  16727. torch_opinfo_name="log2",
  16728. ),
  16729. PythonRefInfo(
  16730. "_refs.logsumexp",
  16731. torch_opinfo_name="logsumexp",
# When keepdim=False, logsumexp uses a squeeze operation
# that is not yet exposed in nvFuser's Python API.
  16734. supports_nvfuser=False,
  16735. ),
  16736. PythonRefInfo(
  16737. "_refs.log_softmax",
  16738. torch_opinfo_name="log_softmax",
  16739. torch_opinfo_variant_name="with_dtype",
  16740. ),
  16741. ElementwiseUnaryPythonRefInfo(
  16742. "_refs.nan_to_num",
  16743. torch_opinfo_name="nan_to_num",
  16744. supports_nvfuser=False,
  16745. ),
  16746. ElementwiseUnaryPythonRefInfo(
  16747. "_refs.neg",
  16748. torch_opinfo_name="neg",
  16749. ),
  16750. ElementwiseUnaryPythonRefInfo(
  16751. "_refs.positive",
  16752. torch_opinfo_name="positive",
  16753. supports_nvfuser=False,
  16754. ),
  16755. ElementwiseUnaryPythonRefInfo(
  16756. "_refs.real",
  16757. torch_opinfo_name="real",
  16758. supports_nvfuser=False,
  16759. ),
  16760. ElementwiseUnaryPythonRefInfo(
  16761. "_refs.reciprocal",
  16762. torch_opinfo_name="reciprocal",
  16763. ),
  16764. ElementwiseUnaryPythonRefInfo(
  16765. "_refs.round",
  16766. torch_opinfo_name="round",
  16767. # Fails on int32
  16768. # https://github.com/pytorch/pytorch/issues/85258
  16769. supports_nvfuser=False,
  16770. ),
  16771. ElementwiseUnaryPythonRefInfo(
  16772. "_refs.rsqrt",
  16773. torch_opinfo_name="rsqrt",
  16774. ),
  16775. ElementwiseUnaryPythonRefInfo(
  16776. "_refs.sigmoid",
  16777. torch_opinfo_name="sigmoid",
  16778. aliases=('_refs.special.expit',),
  16779. # Reference: https://github.com/pytorch/pytorch/issues/56012
  16780. handles_complex_extremal_values=False,
  16781. handles_large_floats=False,
  16782. ),
  16783. ElementwiseUnaryPythonRefInfo(
  16784. "_refs.sign",
  16785. torch_opinfo_name="sign",
  16786. ),
  16787. ElementwiseUnaryPythonRefInfo(
  16788. "_refs.sgn",
  16789. torch_opinfo_name="sgn",
  16790. # This is an issue with the vectorised abs on CPU
  16791. handles_complex_extremal_values=False,
  16792. handles_large_floats=False,
  16793. ),
  16794. ElementwiseUnaryPythonRefInfo(
  16795. "_refs.signbit",
  16796. torch_opinfo_name="signbit",
  16797. supports_nvfuser=False,
  16798. ),
  16799. ElementwiseUnaryPythonRefInfo(
  16800. "_refs.sin",
  16801. torch_opinfo_name="sin",
  16802. ),
  16803. ElementwiseUnaryPythonRefInfo(
  16804. "_refs.sinc",
  16805. torch_opinfo_name="sinc",
  16806. ),
  16807. ElementwiseUnaryPythonRefInfo(
  16808. "_refs.sinh",
  16809. torch_opinfo_name="sinh",
  16810. ),
  16811. PythonRefInfo(
  16812. "_refs.softmax",
  16813. torch_opinfo_name="softmax",
  16814. torch_opinfo_variant_name="with_dtype",
  16815. ),
  16816. ElementwiseUnaryPythonRefInfo(
  16817. "_refs.sqrt",
  16818. torch_opinfo_name="sqrt",
  16819. ),
  16820. ElementwiseUnaryPythonRefInfo(
  16821. "_refs.square",
  16822. torch_opinfo_name="square",
  16823. skips=(
# AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation than the torch result
  16825. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)),
  16826. ),
  16827. ),
  16828. ElementwiseUnaryPythonRefInfo(
  16829. "_refs.tan",
  16830. torch_opinfo_name="tan",
  16831. ),
  16832. ElementwiseUnaryPythonRefInfo(
  16833. "_refs.tanh",
  16834. torch_opinfo_name="tanh",
  16835. ),
  16836. ElementwiseUnaryPythonRefInfo(
  16837. "_refs.trunc",
  16838. torch_opinfo_name="trunc",
  16839. # Fails on int32
  16840. # https://github.com/pytorch/pytorch/issues/85258
  16841. supports_nvfuser=False,
  16842. ),
  16843. PythonRefInfo(
  16844. "_refs.special.log_softmax",
  16845. torch_opinfo_name="log_softmax", # alias
  16846. torch_opinfo_variant_name="with_dtype",
  16847. supports_out=False,
  16848. ),
  16849. PythonRefInfo(
  16850. "_refs.special.softmax",
  16851. torch_opinfo_name="softmax", # alias
  16852. torch_opinfo_variant_name="with_dtype",
  16853. supports_out=False,
  16854. ),
  16855. #
  16856. # Elementwise Unary Special OpInfos
  16857. #
  16858. ElementwiseUnaryPythonRefInfo(
  16859. "_refs.special.logit",
  16860. torch_opinfo_name="logit",
  16861. supports_nvfuser=False,
  16862. ),
  16863. #
  16864. # Elementwise Unary nn.functional OpInfos
  16865. #
  16866. PythonRefInfo(
  16867. "_refs.nn.functional.alpha_dropout",
  16868. torch_opinfo_name="nn.functional.alpha_dropout",
  16869. supports_nvfuser=False,
  16870. decorators=(
  16871. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16872. 'TestCommon',
  16873. 'test_python_ref'),
  16874. # AssertionError: Tensor-likes are not close!
  16875. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16876. 'TestCommon',
  16877. 'test_python_ref_torch_fallback'),
  16878. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16879. 'TestCommon',
  16880. 'test_python_ref_executor', device_type='cuda'),
  16881. # AssertionError: Tensor-likes are not close!
  16882. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16883. 'TestMathBits',
  16884. 'test_neg_view'),
  16885. # AssertionError: Tensor-likes are not close!
  16886. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16887. 'TestCommon',
  16888. 'test_compare_cpu'),
  16889. )
  16890. ),
  16891. ElementwiseUnaryPythonRefInfo(
  16892. "_refs.nn.functional.celu",
  16893. torch_opinfo_name="nn.functional.celu",
  16894. supports_out=True,
  16895. ),
  16896. ElementwiseUnaryPythonRefInfo(
  16897. "_refs.nn.functional.threshold",
  16898. torch_opinfo_name="nn.functional.threshold",
  16899. supports_nvfuser=False,
  16900. supports_out=True,
  16901. ),
  16902. PythonRefInfo(
  16903. "_refs.nn.functional.dropout",
  16904. torch_opinfo_name="nn.functional.dropout",
  16905. decorators=(
  16906. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16907. 'TestCommon',
  16908. 'test_python_ref'),
  16909. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16910. 'TestCommon',
  16911. 'test_python_ref_torch_fallback'),
  16912. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16913. 'TestCommon',
  16914. 'test_out'),
  16915. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16916. 'TestCommon',
  16917. 'test_out_warning'),
  16918. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16919. 'TestMathBits',
  16920. 'test_conj_view'),
  16921. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16922. 'TestMathBits',
  16923. 'test_neg_conj_view'),
  16924. DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
  16925. 'TestMathBits',
  16926. 'test_neg_view'),
  16927. # dropout is not comparable
  16928. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
  16929. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  16930. )
  16931. ),
  16932. ElementwiseUnaryPythonRefInfo(
  16933. "_refs.nn.functional.elu",
  16934. torch_opinfo_name="nn.functional.elu",
  16935. supports_out=True,
  16936. ),
  16937. ElementwiseUnaryPythonRefInfo(
  16938. "_refs.nn.functional.hardtanh",
  16939. torch_opinfo_name="nn.functional.hardtanh",
  16940. supports_nvfuser=False,
  16941. supports_out=True,
  16942. ),
  16943. PythonRefInfo( # TODO: Port this to an UnaryOpInfo
  16944. "_refs.nn.functional.gelu",
  16945. torch_opinfo_name="nn.functional.gelu",
  16946. ),
  16947. PythonRefInfo(
  16948. "_refs.nn.functional.layer_norm",
  16949. torch_opinfo_name="nn.functional.layer_norm",
  16950. skips=(
  16951. # Reference result was farther (3.5762786809723224e-07) from the precise computation
  16952. # than the torch result was (2.5068410824946596e-07)!
  16953. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
  16954. dtypes=(torch.float32,), device_type='cpu'),
  16955. ),
  16956. ),
  16957. PythonRefInfo(
  16958. "_refs.nn.functional.glu",
  16959. torch_opinfo_name="nn.functional.glu",
  16960. supports_nvfuser=False,
  16961. supports_out=True,
  16962. ),
  16963. PythonRefInfo(
  16964. "_refs.nn.functional.pairwise_distance",
  16965. torch_opinfo_name="nn.functional.pairwise_distance",
  16966. supports_out=True,
  16967. ),
  16968. PythonRefInfo(
  16969. "_refs.nn.functional.pdist",
  16970. torch_opinfo_name="nn.functional.pdist",
  16971. supports_out=True,
  16972. supports_nvfuser=False,
  16973. skips=(
# RuntimeError: no _refs support for torch.Tensor.index_select
  16975. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
  16976. )),
  16977. PythonRefInfo(
  16978. "_refs.nn.functional.leaky_relu",
  16979. torch_opinfo_name="nn.functional.leaky_relu",
  16980. supports_out=True,
  16981. ),
  16982. PythonRefInfo(
  16983. "_refs.nn.functional.log_softmax",
  16984. torch_opinfo_name="log_softmax", # alias
  16985. torch_opinfo_variant_name="with_dtype",
  16986. supports_out=False,
  16987. ),
  16988. PythonRefInfo(
  16989. "_refs.nn.functional.poisson_nll_loss",
  16990. torch_opinfo_name="nn.functional.poisson_nll_loss",
  16991. ),
  16992. ElementwiseUnaryPythonRefInfo(
  16993. "_refs.nn.functional.prelu",
  16994. torch_opinfo_name="nn.functional.prelu",
  16995. supports_nvfuser=False,
  16996. ),
  16997. ElementwiseUnaryPythonRefInfo(
  16998. "_refs.nn.functional.relu",
  16999. torch_opinfo_name="nn.functional.relu",
  17000. supports_nvfuser=False,
  17001. supports_out=True,
  17002. ),
  17003. ElementwiseUnaryPythonRefInfo(
  17004. "_refs.nn.functional.relu6",
  17005. torch_opinfo_name="nn.functional.relu6",
  17006. supports_out=True,
  17007. ),
  17008. ElementwiseUnaryPythonRefInfo(
  17009. "_refs.nn.functional.mish",
  17010. torch_opinfo_name="nn.functional.mish",
  17011. supports_out=True,
  17012. ),
  17013. ElementwiseUnaryPythonRefInfo(
  17014. "_refs.nn.functional.selu",
  17015. torch_opinfo_name="nn.functional.selu",
  17016. supports_out=True,
  17017. ),
  17018. PythonRefInfo(
  17019. "_refs.nn.functional.softmax",
  17020. torch_opinfo_name="softmax", # alias
  17021. torch_opinfo_variant_name="with_dtype",
  17022. supports_out=False,
  17023. ),
  17024. PythonRefInfo(
  17025. "_refs.nn.functional.softmin",
  17026. torch_opinfo_name="nn.functional.softmin",
  17027. torch_opinfo_variant_name="with_dtype",
  17028. supports_out=False,
  17029. ),
  17030. ElementwiseUnaryPythonRefInfo(
  17031. "_refs.nn.functional.softplus",
  17032. torch_opinfo_name="nn.functional.softplus",
  17033. ),
  17034. PythonRefInfo(
  17035. "_refs.nn.functional.l1_loss",
  17036. torch_opinfo_name="nn.functional.l1_loss",
  17037. # TestCommonCUDA::test_python_ref_executor__refs_nn_functional_l1_loss_executor_nvfuser_cuda_float32
  17038. # - RuntimeError: No reduction axis specified
  17039. supports_nvfuser=False,
  17040. ),
  17041. PythonRefInfo(
  17042. "_refs.nn.functional.margin_ranking_loss",
  17043. torch_opinfo_name="nn.functional.margin_ranking_loss",
  17044. supports_nvfuser=False,
  17045. ),
  17046. PythonRefInfo(
  17047. "_refs.nn.functional.mse_loss",
  17048. torch_opinfo_name="nn.functional.mse_loss",
  17049. supports_nvfuser=False,
  17050. ),
  17051. PythonRefInfo(
  17052. "_refs.nn.functional.hinge_embedding_loss",
  17053. torch_opinfo_name="nn.functional.hinge_embedding_loss",
  17054. supports_nvfuser=False,
  17055. ),
  17056. PythonRefInfo(
  17057. "_refs.nn.functional.nll_loss",
  17058. torch_opinfo_name="nn.functional.nll_loss",
  17059. # The corresponding PyTorch op doesn't support out. But the ref is
  17060. # registered as a decomp and ATen has an out variant.
  17061. supports_out=True,
  17062. supports_nvfuser=False,
# For simpler indexing, we flatten target indices, then reshape the result tensor.
# This creates view state that is inconsistent with the reference implementation.
  17065. validate_view_consistency=False,
  17066. skips=(
  17067. # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
  17068. DecorateInfo(
  17069. unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
  17070. ),
  17071. ),
  17072. ),
  17073. PythonRefInfo(
  17074. "_refs.nn.functional.huber_loss",
  17075. torch_opinfo_name="nn.functional.huber_loss",
  17076. # The corresponding PyTorch op doesn't support out. But the ref is
  17077. # registered as a decomp and ATen has an out variant.
  17078. supports_out=True,
  17079. ),
  17080. ElementwiseUnaryPythonRefInfo(
  17081. "_refs.nn.functional.tanhshrink",
  17082. torch_opinfo_name="nn.functional.tanhshrink",
  17083. ),
  17084. ElementwiseUnaryPythonRefInfo(
  17085. "_refs.nn.functional.hardshrink",
  17086. torch_opinfo_name="nn.functional.hardshrink",
  17087. supports_nvfuser=False,
  17088. ),
  17089. ElementwiseUnaryPythonRefInfo(
  17090. "_refs.nn.functional.softshrink",
  17091. torch_opinfo_name="nn.functional.softshrink",
  17092. supports_nvfuser=False,
  17093. ),
  17094. #
  17095. # Elementwise Binary Reference OpInfos
  17096. #
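# Several binary refs below set supports_one_python_scalar / supports_two_python_scalars so
# that the sample generators also exercise Python-number operands (see the linked issue
# 76944); supports_rhs_python_scalar=False marks ops whose second operand must be a tensor.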
  17097. ElementwiseBinaryPythonRefInfo(
  17098. "_refs.add",
  17099. torch_opinfo_name="add",
  17100. # https://github.com/pytorch/pytorch/issues/76944
  17101. supports_two_python_scalars=True,
  17102. supports_one_python_scalar=True,
  17103. ),
  17104. ElementwiseBinaryPythonRefInfo(
  17105. "_refs.atan2",
  17106. torch_opinfo_name="atan2",
  17107. ),
  17108. ElementwiseBinaryPythonRefInfo(
  17109. "_refs.bitwise_and",
  17110. torch_opinfo_name="bitwise_and",
  17111. ),
  17112. ElementwiseBinaryPythonRefInfo(
  17113. "_refs.bitwise_left_shift",
  17114. torch_opinfo_name="bitwise_left_shift",
  17115. supports_nvfuser=False,
  17116. skips=(
  17117. # https://github.com/pytorch/pytorch/issues/70904
  17118. DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
  17119. ),
  17120. ),
  17121. ElementwiseBinaryPythonRefInfo(
  17122. "_refs.bitwise_right_shift",
  17123. torch_opinfo_name="bitwise_right_shift",
  17124. supports_nvfuser=False,
  17125. skips=(
# https://github.com/pytorch/pytorch/issues/70904
  17127. DecorateInfo(unittest.skip("Skipped some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
  17128. ),
  17129. ),
  17130. ElementwiseBinaryPythonRefInfo(
  17131. "_refs.bitwise_or",
  17132. torch_opinfo_name="bitwise_or",
  17133. ),
  17134. ElementwiseBinaryPythonRefInfo(
  17135. "_refs.bitwise_xor",
  17136. torch_opinfo_name="bitwise_xor",
  17137. ),
  17138. ElementwiseBinaryPythonRefInfo(
  17139. "_refs.copysign",
  17140. torch_opinfo_name="copysign",
  17141. supports_nvfuser=False,
  17142. skips=(
  17143. # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu!
  17144. DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
  17145. )
  17146. ),
  17147. ElementwiseBinaryPythonRefInfo(
  17148. "_refs.div",
  17149. torch_opinfo_name="div",
  17150. torch_opinfo_variant_name="no_rounding_mode",
  17151. # https://github.com/pytorch/pytorch/issues/76944
  17152. supports_two_python_scalars=True,
  17153. supports_one_python_scalar=True,
  17154. supports_nvfuser=False,
  17155. skips=(
  17156. # NotImplementedError: argument of type: <class 'complex'>
  17157. DecorateInfo(
  17158. unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor',
  17159. dtypes=(torch.complex32, torch.complex64, torch.complex128,)
  17160. ),
  17161. # Reference result was farther (0.7433461727239705) from the precise
  17162. # computation than the torch result was (nan)!
  17163. DecorateInfo(
  17164. unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  17165. dtypes=(torch.complex32,), device_type="cuda"
  17166. ),
  17167. # Reference result was farther (0.7433461727239705) from the precise
  17168. # computation than the torch result was (nan)!
  17169. DecorateInfo(
  17170. unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  17171. dtypes=(torch.complex32,), device_type="cuda"
  17172. ),
  17173. ),
  17174. ),
  17175. ElementwiseBinaryPythonRefInfo(
  17176. "_refs.div",
  17177. torch_opinfo_name="div",
  17178. torch_opinfo_variant_name="trunc_rounding",
  17179. # https://github.com/pytorch/pytorch/issues/76944
  17180. supports_two_python_scalars=True,
  17181. supports_one_python_scalar=True,
  17182. supports_nvfuser=False,
  17183. ),
  17184. ElementwiseBinaryPythonRefInfo(
  17185. "_refs.div",
  17186. torch_opinfo_name="div",
  17187. torch_opinfo_variant_name="floor_rounding",
  17188. # https://github.com/pytorch/pytorch/issues/76944
  17189. supports_two_python_scalars=True,
  17190. supports_one_python_scalar=True,
  17191. supports_nvfuser=False,
  17192. ),
  17193. ElementwiseBinaryPythonRefInfo(
  17194. "_refs.eq",
  17195. torch_opinfo_name="eq",
  17196. ),
  17197. ElementwiseBinaryPythonRefInfo(
  17198. "_refs.float_power",
  17199. torch_opinfo_name="float_power",
  17200. supports_nvfuser=False,
  17201. skips=(
  17202. # Test doesn't account for float -> double type promotion
  17203. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
  17204. )
  17205. ),
  17206. ElementwiseBinaryPythonRefInfo(
  17207. "_refs.logaddexp",
  17208. torch_opinfo_name="logaddexp",
  17209. supports_nvfuser=False,
  17210. ),
  17211. ElementwiseBinaryPythonRefInfo(
  17212. "_refs.floor_divide",
  17213. torch_opinfo_name="floor_divide",
  17214. rhs_make_tensor_kwargs=dict(exclude_zero=True),
  17215. # https://github.com/pytorch/pytorch/issues/76944
  17216. supports_two_python_scalars=True,
  17217. supports_one_python_scalar=True,
  17218. supports_nvfuser=False,
  17219. # bfloat16 floor_divide compared with a float32 reference works inconsistently
  17220. skips=(
  17221. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
  17222. dtypes=(torch.bfloat16,)),
  17223. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
  17224. dtypes=(torch.bfloat16,)),
  17225. ),
  17226. ),
  17227. ElementwiseBinaryPythonRefInfo(
  17228. "_refs.fmax",
  17229. torch_opinfo_name="fmax",
  17230. supports_rhs_python_scalar=False,
  17231. supports_nvfuser=False,
  17232. ),
  17233. ElementwiseBinaryPythonRefInfo(
  17234. "_refs.fmin",
  17235. torch_opinfo_name="fmin",
  17236. supports_rhs_python_scalar=False,
  17237. supports_nvfuser=False,
  17238. ),
  17239. ElementwiseBinaryPythonRefInfo(
  17240. "_refs.fmod",
  17241. torch_opinfo_name="fmod",
  17242. rhs_make_tensor_kwargs={'exclude_zero': True},
  17243. supports_rhs_python_scalar=True,
  17244. skips=(
  17245. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
  17246. dtypes=(torch.bfloat16,), device_type='cpu'),
  17247. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
  17248. dtypes=(torch.bfloat16,), device_type='cpu'),
  17249. ),
  17250. ),
  17251. ElementwiseBinaryPythonRefInfo(
  17252. "_refs.gcd",
  17253. torch_opinfo_name="gcd",
  17254. supports_nvfuser=False,
  17255. ),
  17256. ElementwiseBinaryPythonRefInfo(
  17257. "_refs.ge",
  17258. torch_opinfo_name="ge",
  17259. ),
  17260. ElementwiseBinaryPythonRefInfo(
  17261. "_refs.gt",
  17262. torch_opinfo_name="gt",
  17263. ),
  17264. ElementwiseBinaryPythonRefInfo(
  17265. "_refs.heaviside",
  17266. torch_opinfo_name="heaviside",
  17267. supports_rhs_python_scalar=False,
  17268. supports_nvfuser=False,
  17269. ),
  17270. ElementwiseBinaryPythonRefInfo(
  17271. "_refs.hypot",
  17272. torch_opinfo_name="hypot",
  17273. supports_rhs_python_scalar=False,
  17274. supports_nvfuser=False,
  17275. ),
  17276. ElementwiseBinaryPythonRefInfo(
  17277. "_refs.igamma",
  17278. torch_opinfo_name="igamma",
  17279. supports_nvfuser=False,
  17280. ),
  17281. ElementwiseBinaryPythonRefInfo(
  17282. "_refs.igammac",
  17283. torch_opinfo_name="igammac",
  17284. supports_nvfuser=False,
  17285. ),
  17286. ElementwiseBinaryPythonRefInfo(
  17287. "_refs.isclose",
  17288. torch_opinfo_name="isclose",
  17289. supports_nvfuser=False,
  17290. skips=(
  17291. # Intentional xfail -- isclose does not type promote
  17292. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
  17293. ),
  17294. ),
  17295. ElementwiseBinaryPythonRefInfo(
  17296. "_refs.lcm",
  17297. torch_opinfo_name="lcm",
  17298. supports_nvfuser=False,
  17299. ),
  17300. ElementwiseBinaryPythonRefInfo(
  17301. "_refs.le",
  17302. torch_opinfo_name="le",
  17303. ),
  17304. ElementwiseBinaryPythonRefInfo(
  17305. "_refs.logical_and",
  17306. torch_opinfo_name="logical_and",
  17307. ),
  17308. ElementwiseUnaryPythonRefInfo(
  17309. "_refs.logical_not",
  17310. torch_opinfo_name="logical_not",
  17311. ),
  17312. ElementwiseBinaryPythonRefInfo(
  17313. "_refs.logical_or",
  17314. torch_opinfo_name="logical_or",
  17315. ),
  17316. ElementwiseBinaryPythonRefInfo(
  17317. "_refs.logical_xor",
  17318. torch_opinfo_name="logical_xor",
  17319. ),
  17320. ElementwiseBinaryPythonRefInfo(
  17321. "_refs.lt",
  17322. torch_opinfo_name="lt",
  17323. ),
  17324. ElementwiseBinaryPythonRefInfo(
  17325. "_refs.maximum",
  17326. torch_opinfo_name="maximum",
  17327. supports_nvfuser=False,
  17328. skips=(
  17329. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17330. ),
  17331. ),
  17332. ElementwiseBinaryPythonRefInfo(
  17333. "_refs.minimum",
  17334. torch_opinfo_name="minimum",
  17335. supports_nvfuser=False,
  17336. skips=(
  17337. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17338. ),
  17339. ),
  17340. ElementwiseBinaryPythonRefInfo(
  17341. "_refs.mul",
  17342. torch_opinfo_name="mul",
  17343. # https://github.com/pytorch/pytorch/issues/76944
  17344. supports_two_python_scalars=True,
  17345. supports_one_python_scalar=True,
  17346. skips=(
  17347. # Reference result was farther (0.0) from the precise computation
  17348. # than the torch result was (nan)!
  17349. DecorateInfo(
  17350. unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
  17351. dtypes=(torch.complex32,),
  17352. ),
  17353. # Reference result was farther (0.0) from the precise computation
  17354. # than the torch result was (nan)!
  17355. DecorateInfo(
  17356. unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  17357. dtypes=(torch.complex32,), device_type='cuda'
  17358. ),
  17359. # Reference result was farther (0.0) from the precise computation
  17360. # than the torch result was (nan)!
  17361. DecorateInfo(
  17362. unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  17363. dtypes=(torch.complex32,), device_type='cuda'
  17364. ),
  17365. )
  17366. ),
  17367. ElementwiseBinaryPythonRefInfo(
  17368. "_refs.ne",
  17369. torch_opinfo_name="ne",
  17370. ),
  17371. ElementwiseBinaryPythonRefInfo(
  17372. "_refs.nextafter",
  17373. torch_opinfo_name="nextafter",
  17374. supports_nvfuser=False,
  17375. ),
  17376. ElementwiseBinaryPythonRefInfo(
  17377. "_refs.pow",
  17378. torch_opinfo_name="pow",
  17379. supports_nvfuser=False, # clone default
  17380. skips=(
  17381. # Reference result was farther (inf) from the precise
  17382. # computation than the torch result was (nan)!
  17383. DecorateInfo(
  17384. unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
  17385. dtypes=(torch.complex32,),
  17386. ),
  17387. # Reference result was farther (inf) from the precise
  17388. # computation than the torch result was (nan)!
  17389. DecorateInfo(
  17390. unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  17391. dtypes=(torch.complex32,), device_type="cuda"
  17392. ),
  17393. # Reference result was farther (inf) from the precise
  17394. # computation than the torch result was (nan)!
  17395. DecorateInfo(
  17396. unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  17397. dtypes=(torch.complex32,), device_type="cuda"
  17398. ),
  17399. ),
  17400. ),
  17401. ElementwiseBinaryPythonRefInfo(
  17402. "_refs.remainder",
  17403. torch_opinfo_name="remainder",
  17404. skips=(
  17405. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
  17406. dtypes=(torch.bfloat16,), device_type='cpu'),
  17407. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
  17408. dtypes=(torch.bfloat16,), device_type='cpu'),
  17409. ),
  17410. ),
  17411. ElementwiseBinaryPythonRefInfo(
  17412. "_refs.rsub",
  17413. torch_opinfo_name="rsub",
  17414. # https://github.com/pytorch/pytorch/issues/76944
  17415. skips=(
  17416. # Reference result was farther (nan) from the precise computation than
  17417. # the torch result was (nan)!
  17418. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  17419. dtypes=(torch.chalf,), device_type='cpu'),
  17420. # Reference result was farther (nan) from the precise computation than
  17421. # the torch result was (nan)!
  17422. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  17423. dtypes=(torch.chalf,), device_type='cpu'),
  17424. ),
  17425. ),
  17426. ElementwiseBinaryPythonRefInfo(
  17427. "_refs.sub",
  17428. torch_opinfo_name="sub",
  17429. # https://github.com/pytorch/pytorch/issues/76944
  17430. supports_two_python_scalars=True,
  17431. supports_one_python_scalar=True,
  17432. ),
  17433. ElementwiseBinaryPythonRefInfo(
  17434. "_refs.true_divide",
  17435. torch_opinfo_name="true_divide",
  17436. # https://github.com/pytorch/pytorch/issues/76944
  17437. supports_two_python_scalars=True,
  17438. supports_one_python_scalar=True,
  17439. skips=(
  17440. # Reference result was farther (0.7433461727239705) from the precise
  17441. # computation than the torch result was (nan)!
  17442. DecorateInfo(
  17443. unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
  17444. dtypes=(torch.complex32,),
  17445. ),
  17446. # Reference result was farther (0.7433461727239705) from the precise
  17447. # computation than the torch result was (nan)!
  17448. DecorateInfo(
  17449. unittest.expectedFailure, 'TestCommon', 'test_python_ref',
  17450. dtypes=(torch.complex32,), device_type="cuda"
  17451. ),
  17452. # Reference result was farther (0.7433461727239705) from the precise
  17453. # computation than the torch result was (nan)!
  17454. DecorateInfo(
  17455. unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
  17456. dtypes=(torch.complex32,), device_type="cuda"
  17457. ),
  17458. ),
  17459. ),
  17460. #
  17461. # Elementwise Ternary Reference OpInfos
  17462. #
  17463. PythonRefInfo(
  17464. "_refs.addcdiv",
  17465. torch_opinfo_name="addcdiv",
  17466. ),
  17467. PythonRefInfo(
  17468. "_refs.addcmul",
  17469. torch_opinfo_name="addcmul",
  17470. ),
  17471. ElementwiseBinaryPythonRefInfo(
  17472. "_refs.clamp_min",
  17473. torch_opinfo_name="clamp_min",
  17474. supports_nvfuser=False,
  17475. skips=(
  17476. # test error disabled since rhs non-tensor python scalar is supported
  17477. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17478. ),
  17479. ),
  17480. ElementwiseBinaryPythonRefInfo(
  17481. "_refs.clamp_max",
  17482. torch_opinfo_name="clamp_max",
  17483. supports_nvfuser=False,
  17484. skips=(
  17485. # test error disabled since rhs non-tensor python scalar is supported
  17486. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17487. ),
  17488. ),
  17489. PythonRefInfo(
  17490. "_refs.clamp",
  17491. torch_opinfo_name="clamp",
  17492. supports_nvfuser=False,
  17493. ),
  17494. PythonRefInfo(
  17495. "_refs.nn.functional.triplet_margin_loss",
  17496. torch_opinfo_name="nn.functional.triplet_margin_loss",
  17497. supports_out=False,
  17498. # TODO: Uses minimum and clamp, which don't support nvfuser.
  17499. supports_nvfuser=False,
  17500. skips=(
  17501. # AssertionError: Tensor-likes are not close!
  17502. # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
  17503. # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
  17504. DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
  17505. dtypes=(torch.uint8,), device_type="cpu"),
  17506. )
  17507. ),
  17508. ElementwiseBinaryPythonRefInfo(
  17509. "_refs.xlogy",
  17510. torch_opinfo_name="xlogy",
  17511. supports_one_python_scalar=True,
  17512. supports_nvfuser=False,
  17513. ),
  17514. #
  17515. # Elementwise Binary Special OpInfos
  17516. #
  17517. ElementwiseBinaryPythonRefInfo(
  17518. "_refs.special.xlog1py",
  17519. torch_opinfo_name="special.xlog1py",
  17520. supports_one_python_scalar=True,
  17521. supports_nvfuser=False,
  17522. ),
  17523. #
  17524. # Data Conversion & Data Movement Opinfos
  17525. #
  17526. ElementwiseUnaryPythonRefInfo(
  17527. "_refs._conversions.bfloat16",
  17528. torch_opinfo_name="bfloat16",
  17529. # TODO: If self already has the correct dtype and device, then self is
  17530. # returned ignoring memory_format.
  17531. # https://github.com/pytorch/pytorch/issues/86558
  17532. validate_view_consistency=False,
  17533. supports_nvfuser=False,
  17534. ),
  17535. ElementwiseUnaryPythonRefInfo(
  17536. "_refs._conversions.bool",
  17537. torch_opinfo_name="bool",
  17538. # TODO: If self already has the correct dtype and device, then self is
  17539. # returned ignoring memory_format.
  17540. # https://github.com/pytorch/pytorch/issues/86558
  17541. validate_view_consistency=False,
  17542. supports_nvfuser=False,
  17543. ),
  17544. ElementwiseUnaryPythonRefInfo(
  17545. "_refs._conversions.byte",
  17546. torch_opinfo_name="byte",
  17547. # TODO: If self already has the correct dtype and device, then self is
  17548. # returned ignoring memory_format.
  17549. # https://github.com/pytorch/pytorch/issues/86558
  17550. validate_view_consistency=False,
  17551. supports_nvfuser=False,
  17552. skips=(
  17553. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  17554. )
  17555. ),
  17556. ElementwiseUnaryPythonRefInfo(
  17557. "_refs._conversions.char",
  17558. torch_opinfo_name="char",
  17559. # TODO: If self already has the correct dtype and device, then self is
  17560. # returned ignoring memory_format.
  17561. # https://github.com/pytorch/pytorch/issues/86558
  17562. validate_view_consistency=False,
  17563. supports_nvfuser=False,
  17564. skips=(
  17565. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  17566. )
  17567. ),
  17568. ElementwiseBinaryPythonRefInfo(
  17569. "_refs._conversions.complex",
  17570. torch_opinfo_name="complex",
  17571. error_inputs_func=partial(error_inputs_complex, is_ref=True),
  17572. # prims.empty_strided.default does not support nvfuser
  17573. supports_nvfuser=False,
  17574. skips=(
  17575. # Test doesn't account for complex's type promotion semantics
  17576. DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
  17577. )
  17578. ),
  17579. ElementwiseUnaryPythonRefInfo(
  17580. "_refs._conversions.double",
  17581. torch_opinfo_name="double",
  17582. # TODO: If self already has the correct dtype and device, then self is
  17583. # returned ignoring memory_format.
  17584. # https://github.com/pytorch/pytorch/issues/86558
  17585. validate_view_consistency=False,
  17586. supports_nvfuser=False,
  17587. ),
  17588. ElementwiseUnaryPythonRefInfo(
  17589. "_refs._conversions.float",
  17590. torch_opinfo_name="float",
  17591. # TODO: If self already has the correct dtype and device, then self is
  17592. # returned ignoring memory_format.
  17593. # https://github.com/pytorch/pytorch/issues/86558
  17594. validate_view_consistency=False,
  17595. supports_nvfuser=False,
  17596. ),
  17597. ElementwiseUnaryPythonRefInfo(
  17598. "_refs._conversions.half",
  17599. torch_opinfo_name="half",
  17600. # TODO: If self already has the correct dtype and device, then self is
  17601. # returned ignoring memory_format.
  17602. # https://github.com/pytorch/pytorch/issues/86558
  17603. validate_view_consistency=False,
  17604. supports_nvfuser=False,
  17605. ),
  17606. ElementwiseUnaryPythonRefInfo(
  17607. "_refs._conversions.int",
  17608. torch_opinfo_name="int",
  17609. # TODO: If self already has the correct dtype and device, then self is
  17610. # returned ignoring memory_format.
  17611. # https://github.com/pytorch/pytorch/issues/86558
  17612. validate_view_consistency=False,
  17613. supports_nvfuser=False,
  17614. skips=(
  17615. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  17616. )
  17617. ),
  17618. ElementwiseUnaryPythonRefInfo(
  17619. "_refs._conversions.long",
  17620. torch_opinfo_name="long",
  17621. # TODO: If self already has the correct dtype and device, then self is
  17622. # returned ignoring memory_format.
  17623. # https://github.com/pytorch/pytorch/issues/86558
  17624. validate_view_consistency=False,
  17625. supports_nvfuser=False,
  17626. skips=(
  17627. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  17628. )
  17629. ),
  17630. ElementwiseUnaryPythonRefInfo(
  17631. "_refs._conversions.short",
  17632. torch_opinfo_name="short",
  17633. # TODO: If self already has the correct dtype and device, then self is
  17634. # returned ignoring memory_format.
  17635. # https://github.com/pytorch/pytorch/issues/86558
  17636. validate_view_consistency=False,
  17637. supports_nvfuser=False,
  17638. skips=(
  17639. DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
  17640. )
  17641. ),
  17642. ElementwiseUnaryPythonRefInfo(
  17643. "_refs._conversions.chalf",
  17644. torch_opinfo_name="chalf",
  17645. # TODO: If self already has the correct dtype and device, then self is
  17646. # returned ignoring memory_format.
  17647. # https://github.com/pytorch/pytorch/issues/86558
  17648. validate_view_consistency=False,
  17649. supports_nvfuser=False,
  17650. ),
  17651. ElementwiseUnaryPythonRefInfo(
  17652. "_refs._conversions.cfloat",
  17653. torch_opinfo_name="cfloat",
  17654. # TODO: If self already has the correct dtype and device, then self is
  17655. # returned ignoring memory_format.
  17656. # https://github.com/pytorch/pytorch/issues/86558
  17657. validate_view_consistency=False,
  17658. supports_nvfuser=False,
  17659. ),
  17660. ElementwiseUnaryPythonRefInfo(
  17661. "_refs._conversions.cdouble",
  17662. torch_opinfo_name="cdouble",
  17663. # TODO: If self already has the correct dtype and device, then self is
  17664. # returned ignoring memory_format.
  17665. # https://github.com/pytorch/pytorch/issues/86558
  17666. validate_view_consistency=False,
  17667. supports_nvfuser=False,
  17668. ),
  17669. PythonRefInfo(
  17670. "_refs.clone",
  17671. torch_opinfo_name="clone",
  17672. ),
  17673. #
  17674. # View & Shape OpInfos
  17675. #
  17676. PythonRefInfo(
  17677. "_refs.atleast_1d",
  17678. torch_opinfo_name="atleast_1d",
  17679. validate_view_consistency=False,
  17680. supports_nvfuser=False
  17681. ),
  17682. PythonRefInfo(
  17683. "_refs.atleast_2d",
  17684. torch_opinfo_name="atleast_2d",
  17685. validate_view_consistency=False,
  17686. supports_nvfuser=False
  17687. ),
  17688. PythonRefInfo(
  17689. "_refs.atleast_3d",
  17690. torch_opinfo_name="atleast_3d",
  17691. validate_view_consistency=False,
  17692. supports_nvfuser=False
  17693. ),
  17694. PythonRefInfo(
  17695. "_refs.as_strided",
  17696. torch_opinfo_name="as_strided",
  17697. # FIXME: doesn't support chalf
  17698. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  17699. supports_nvfuser=False,
  17700. skips=(
  17701. # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
  17702. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
  17703. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
  17704. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
  17705. ),
  17706. ),
  17707. PythonRefInfo(
  17708. "_refs.as_strided",
  17709. torch_opinfo_name="as_strided",
  17710. torch_opinfo_variant_name="partial_views",
  17711. # FIXME: doesn't support chalf
  17712. dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
  17713. supports_nvfuser=False,
  17714. skips=(
  17715. # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
  17716. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
  17717. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
  17718. DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
  17719. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
  17720. ),
  17721. ),
  17722. PythonRefInfo(
  17723. "_refs.as_strided_scatter",
  17724. torch_opinfo_name="as_strided_scatter",
  17725. supports_nvfuser=False,
  17726. # returns a view of an intermediate tensor (as_strided)
  17727. validate_view_consistency=False,
  17728. ),
  17729. PythonRefInfo(
  17730. "_refs.broadcast_shapes",
  17731. torch_opinfo_name="broadcast_shapes",
  17732. supports_nvfuser=False,
  17733. ),
  17734. PythonRefInfo(
  17735. "_refs.broadcast_tensors",
  17736. torch_opinfo_name="broadcast_tensors",
  17737. ),
  17738. PythonRefInfo(
  17739. "_refs.broadcast_to",
  17740. torch_opinfo_name="broadcast_to",
  17741. ),
  17742. PythonRefInfo(
  17743. "_refs.cat",
  17744. torch_opinfo_name="cat",
  17745. supports_nvfuser=False,
  17746. skips=(
  17747. # FIXME: AssertionError: RuntimeError not raised
  17748. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17749. ),
  17750. ),
  17751. PythonRefInfo(
  17752. "_refs.chunk",
  17753. torch_opinfo_name="chunk",
  17754. supports_nvfuser=False,
  17755. ),
  17756. PythonRefInfo(
  17757. "_refs.column_stack",
  17758. torch_opinfo_name="column_stack",
  17759. supports_nvfuser=False,
  17760. ),
  17761. ElementwiseUnaryPythonRefInfo(
  17762. "_refs.conj",
  17763. torch_opinfo_name="conj",
  17764. supports_nvfuser=False,
  17765. ),
  17766. PythonRefInfo(
  17767. "_refs.constant_pad_nd",
  17768. torch_opinfo_name="constant_pad_nd",
  17769. supports_nvfuser=False,
  17770. ),
  17771. PythonRefInfo(
  17772. "_refs.contiguous",
  17773. torch_opinfo_name="contiguous",
  17774. supports_nvfuser=False,
  17775. ),
  17776. PythonRefInfo(
  17777. "_refs.dsplit",
  17778. torch_opinfo_name="dsplit",
  17779. supports_nvfuser=False,
  17780. ),
  17781. PythonRefInfo(
  17782. "_refs.diag",
  17783. torch_opinfo_name="diag",
  17784. supports_nvfuser=False,
  17785. ),
  17786. PythonRefInfo(
  17787. "_refs.diagonal",
  17788. torch_opinfo_name="diagonal",
  17789. supports_nvfuser=False,
  17790. ),
  17791. PythonRefInfo(
  17792. "_refs.diagonal_copy",
  17793. torch_opinfo_name="diagonal_copy",
  17794. supports_nvfuser=False,
  17795. ),
  17796. PythonRefInfo(
  17797. "_refs.diagonal_scatter",
  17798. torch_opinfo_name="diagonal_scatter",
  17799. supports_out=True,
  17800. supports_nvfuser=False,
  17801. # returns a view of an intermediate tensor (as_strided)
  17802. validate_view_consistency=False,
  17803. ),
  17804. PythonRefInfo(
  17805. "_refs.diag_embed",
  17806. torch_opinfo_name="diag_embed",
  17807. supports_out=True,
  17808. supports_nvfuser=False,
  17809. ),
  17810. PythonRefInfo(
  17811. "_refs.dstack",
  17812. torch_opinfo_name="dstack",
  17813. supports_nvfuser=False,
  17814. skips=(
  17815. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17816. ),
  17817. ),
  17818. PythonRefInfo(
  17819. "_refs.expand",
  17820. torch_opinfo_name="expand",
  17821. supports_nvfuser=False,
  17822. ),
  17823. PythonRefInfo(
  17824. "_refs.expand_as",
  17825. torch_opinfo_name="expand_as",
  17826. supports_nvfuser=False,
  17827. ),
  17828. PythonRefInfo(
  17829. "_refs.flatten",
  17830. torch_opinfo_name="flatten",
  17831. supports_nvfuser=False,
  17832. ),
  17833. PythonRefInfo(
  17834. "_refs.flip",
  17835. torch_opinfo_name="flip",
  17836. supports_nvfuser=False,
  17837. ),
  17838. PythonRefInfo(
  17839. "_refs.fliplr",
  17840. torch_opinfo_name="fliplr",
  17841. supports_nvfuser=False,
  17842. ),
  17843. PythonRefInfo(
  17844. "_refs.flipud",
  17845. torch_opinfo_name="flipud",
  17846. supports_nvfuser=False,
  17847. ),
  17848. PythonRefInfo(
  17849. "_refs.hstack",
  17850. torch_opinfo_name="hstack",
  17851. supports_nvfuser=False,
  17852. skips=(
  17853. # https://github.com/pytorch/pytorch/issues/78613
  17854. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  17855. ),
  17856. ),
  17857. PythonRefInfo(
  17858. "_refs.narrow",
  17859. torch_opinfo_name="narrow",
  17860. supports_nvfuser=False,
  17861. error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True),
  17862. ),
  17863. PythonRefInfo(
  17864. "_refs.narrow_copy",
  17865. torch_opinfo_name="narrow_copy",
  17866. supports_out=True,
  17867. supports_nvfuser=False,
  17868. error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True),
  17869. ),
  17870. PythonRefInfo(
  17871. "_refs.nn.functional.group_norm",
  17872. torch_opinfo_name="nn.functional.group_norm",
  17873. supports_nvfuser=False,
  17874. validate_view_consistency=False,
  17875. ),
  17876. PythonRefInfo(
  17877. "_refs.native_layer_norm",
  17878. torch_opinfo_name="native_layer_norm",
  17879. skips=(
  17880. DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref",
  17881. device_type="cpu", dtypes=(torch.float32,)),
  17882. DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback",
  17883. device_type="cpu", dtypes=(torch.float32,)),
  17884. ),
  17885. ),
  17886. PythonRefInfo(
  17887. "_refs.permute",
  17888. torch_opinfo_name="permute",
  17889. ),
  17890. PythonRefInfo(
  17891. "_refs.ravel",
  17892. torch_opinfo_name="ravel",
  17893. supports_nvfuser=False,
  17894. ),
  17895. PythonRefInfo(
  17896. "_refs.repeat",
  17897. torch_opinfo_name="repeat",
  17898. supports_nvfuser=False,
  17899. validate_view_consistency=False,
  17900. ),
  17901. PythonRefInfo(
  17902. "_refs.reshape",
  17903. torch_opinfo_name="reshape",
  17904. supports_nvfuser=False,
  17905. ),
  17906. PythonRefInfo(
  17907. "_refs.reshape_as",
  17908. torch_opinfo_name="reshape_as",
  17909. supports_nvfuser=False,
  17910. ),
  17911. PythonRefInfo(
  17912. "_refs.roll",
  17913. torch_opinfo_name="roll",
  17914. validate_view_consistency=False,
  17915. supports_nvfuser=False,
  17916. ),
  17917. PythonRefInfo(
  17918. "_refs.rot90",
  17919. torch_opinfo_name="rot90",
  17920. validate_view_consistency=False,
  17921. supports_nvfuser=False,
  17922. ),
  17923. PythonRefInfo(
  17924. "_refs.stack",
  17925. torch_opinfo_name="stack",
  17926. supports_nvfuser=False,
  17927. validate_view_consistency=False,
  17928. ),
  17929. PythonRefInfo(
  17930. "_refs.squeeze",
  17931. torch_opinfo_name="squeeze",
  17932. ),
  17933. PythonRefInfo(
  17934. "_refs.squeeze",
  17935. torch_opinfo_name="squeeze",
  17936. torch_opinfo_variant_name="multiple",
  17937. ),
  17938. PythonRefInfo(
  17939. "_refs.tensor_split",
  17940. torch_opinfo_name="tensor_split",
  17941. skips=(
  17942. # TensorMeta doesn't support tolist
  17943. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
  17944. # RuntimeError: no _refs support for torch.Tensor.tolist
  17945. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
  17946. ),
  17947. supports_nvfuser=False,
  17948. ),
  17949. PythonRefInfo(
  17950. "_refs.hsplit",
  17951. torch_opinfo_name="hsplit",
  17952. supports_nvfuser=False,
  17953. ),
  17954. PythonRefInfo(
  17955. "_refs.vsplit",
  17956. torch_opinfo_name="vsplit",
  17957. supports_nvfuser=False,
  17958. ),
  17959. PythonRefInfo(
  17960. "_refs.transpose",
  17961. torch_opinfo_name="transpose",
  17962. ),
  17963. PythonRefInfo(
  17964. "_refs.t",
  17965. torch_opinfo_name="t",
  17966. ),
  17967. PythonRefInfo(
  17968. "_refs.T",
  17969. torch_opinfo_name="T",
  17970. error_inputs_func=partial(error_inputs_T, has_ndims_error=True),
  17971. ),
  17972. PythonRefInfo(
  17973. "_refs.unfold",
  17974. torch_opinfo_name="unfold",
  17975. supports_nvfuser=False,
  17976. ),
  17977. PythonRefInfo(
  17978. "_refs.unfold_copy",
  17979. torch_opinfo_name="unfold_copy",
  17980. supports_nvfuser=False,
  17981. supports_out=True,
  17982. ),
  17983. PythonRefInfo(
  17984. "_refs.unsqueeze",
  17985. torch_opinfo_name="unsqueeze",
  17986. ),
  17987. PythonRefInfo(
  17988. "_refs.view",
  17989. torch_opinfo_name="view",
  17990. supports_nvfuser=False,
  17991. ),
  17992. PythonRefInfo(
  17993. "_refs.view_as",
  17994. torch_opinfo_name="view_as",
  17995. supports_nvfuser=False,
  17996. ),
  17997. PythonRefInfo(
  17998. "_refs.vstack",
  17999. torch_opinfo_name="vstack",
  18000. supports_nvfuser=False,
  18001. skips=(
  18002. # https://github.com/pytorch/pytorch/issues/78613
  18003. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  18004. ),
  18005. ),
  18006. PythonRefInfo(
  18007. "_refs.unflatten",
  18008. torch_opinfo_name="unflatten",
  18009. supports_nvfuser=False,
  18010. ),
  18011. PythonRefInfo(
  18012. "_refs.unbind",
  18013. torch_opinfo_name="unbind",
  18014. supports_nvfuser=False,
  18015. ),
  18016. #
  18017. # Reduction Reference OpInfos
  18018. #
  18019. ReductionPythonRefInfo(
  18020. "_refs.all",
  18021. torch_opinfo_name="all",
  18022. ),
  18023. ReductionPythonRefInfo(
  18024. "_refs.amax",
  18025. torch_opinfo_name="amax",
  18026. error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
  18027. ),
  18028. ReductionPythonRefInfo(
  18029. "_refs.amin",
  18030. torch_opinfo_name="amin",
  18031. error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
  18032. ),
  18033. ReductionPythonRefInfo(
  18034. "_refs.any",
  18035. torch_opinfo_name="any",
  18036. ),
  18037. ReductionPythonRefInfo(
  18038. "_refs.mean",
  18039. torch_opinfo_name="mean",
  18040. supports_out=True,
  18041. error_inputs_func=partial(error_inputs_mean, is_ref=True),
  18042. ),
  18043. ReductionPythonRefInfo(
  18044. "_refs.std",
  18045. torch_opinfo_name="std",
  18046. supports_out=True,
  18047. ),
  18048. # std_mean and var_mean are not ReductionInfos
  18049. PythonRefInfo(
  18050. "_refs.std_mean",
  18051. torch_opinfo_name="std_mean",
  18052. ),
  18053. ReductionPythonRefInfo(
  18054. "_refs.sum",
  18055. torch_opinfo_name="sum",
  18056. supports_out=True,
  18057. ),
  18058. PythonRefInfo(
  18059. "_refs.cumsum",
  18060. torch_opinfo_name="cumsum",
  18061. supports_out=True,
  18062. supports_nvfuser=False, # arange not supported
  18063. ),
  18064. PythonRefInfo(
  18065. "_refs.sum_to_size",
  18066. torch_opinfo_name="sum_to_size",
  18067. validate_view_consistency=False,
  18068. ),
  18069. ReductionPythonRefInfo(
  18070. "_refs.prod",
  18071. torch_opinfo_name="prod",
  18072. supports_out=True,
  18073. supports_nvfuser=False,
  18074. ),
  18075. ReductionPythonRefInfo(
  18076. "_refs.var",
  18077. torch_opinfo_name="var",
  18078. supports_out=True,
  18079. ),
  18080. PythonRefInfo(
  18081. "_refs.var_mean",
  18082. torch_opinfo_name="var_mean",
  18083. validate_view_consistency=False,
  18084. ),
  18085. PythonRefInfo(
  18086. "ops.nvprims.var_mean",
  18087. torch_opinfo_name="var_mean",
  18088. validate_view_consistency=False,
  18089. # Complex types are currently disabled
  18090. dtypes=floating_types_and(torch.float16, torch.bfloat16),
  18091. # This function is expected not to work with TorchRefsMode(strict=True)
  18092. decorators=(
  18093. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
  18094. ),
  18095. ),
  18096. PythonRefInfo(
  18097. "ops.nvprims.native_batch_norm",
  18098. torch_opinfo_name="native_batch_norm",
  18099. # Complex types are currently disabled
  18100. dtypes=floating_types(),
  18101. supports_out=False,
  18102. # This function is expected not to work with TorchRefsMode(strict=True)
  18103. decorators=(
  18104. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
  18105. # There's a discrepancy in returned shape between CPU and other devices
  18106. # AssertionError: Shapes torch.Size([0]) and torch.Size([2]) are not equal!
  18107. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type="cpu"),
  18108. ),
  18109. skips=(
  18110. # https://github.com/pytorch/pytorch/issues/85960
  18111. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
  18112. ),
  18113. ),
  18114. PythonRefInfo(
  18115. "ops.nvprims.view",
  18116. torch_opinfo_name="view",
  18117. validate_view_consistency=False,
  18118. # This function is expected not to work with TorchRefsMode(strict=True)
  18119. decorators=(
  18120. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
  18121. ),
  18122. ),
  18123. #
  18124. # Linear Algebra Operators
  18125. #
  18126. PythonRefInfo(
  18127. "_refs.addr",
  18128. torch_opinfo_name="addr",
  18129. supports_nvfuser=False,
  18130. decorators=(
  18131. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
  18132. ),
  18133. ),
  18134. PythonRefInfo(
  18135. "_refs.trace",
  18136. torch_opinfo_name="trace",
  18137. supports_nvfuser=False,
  18138. ),
  18139. PythonRefInfo(
  18140. "_refs.norm",
  18141. torch_opinfo_name="norm",
  18142. supports_out=True,
  18143. # Uses svdvals which does not support nvfuser
  18144. supports_nvfuser=False,
  18145. # Uses vector_norm inside and vector_norm is affected by
  18146. # https://github.com/pytorch/pytorch/issues/77216
  18147. validate_view_consistency=False,
  18148. ),
  18149. #
  18150. # Tensor Creation Reference OpInfos
  18151. #
  18152. PythonRefInfo(
  18153. "_refs.empty",
  18154. torch_opinfo_name="empty",
  18155. skips=(
  18156. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18157. 'TestCommon',
  18158. 'test_python_ref'),
  18159. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18160. 'TestCommon',
  18161. 'test_python_ref_torch_fallback'),
  18162. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18163. 'TestCommon',
  18164. 'test_out'),
  18165. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18166. 'TestCommon',
  18167. 'test_out_warning'),
  18168. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18169. 'TestMathBits',
  18170. 'test_conj_view'),
  18171. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18172. 'TestMathBits',
  18173. 'test_neg_conj_view'),
  18174. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18175. 'TestMathBits',
  18176. 'test_neg_view'),
  18177. # FIXME: shouldn't check empty results
  18178. DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'),
  18179. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  18180. ),
  18181. ),
  18182. PythonRefInfo(
  18183. "_refs.empty_like",
  18184. torch_opinfo_name="empty_like",
  18185. supports_nvfuser=False,
  18186. skips=(
  18187. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18188. 'TestCommon',
  18189. 'test_python_ref'),
  18190. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18191. 'TestCommon',
  18192. 'test_python_ref_torch_fallback'),
  18193. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18194. 'TestCommon',
  18195. 'test_out'),
  18196. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18197. 'TestCommon',
  18198. 'test_out_warning'),
  18199. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18200. 'TestMathBits',
  18201. 'test_conj_view'),
  18202. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18203. 'TestMathBits',
  18204. 'test_neg_conj_view'),
  18205. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18206. 'TestMathBits',
  18207. 'test_neg_view'),
  18208. # FIXME: should not compare results of empty_like
  18209. DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'),
  18210. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  18211. ),
  18212. ),
  18213. PythonRefInfo(
  18214. "_refs.randn",
  18215. torch_opinfo_name="randn",
  18216. op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs),
  18217. supports_nvfuser=False,
  18218. skips=(
  18219. # see https://github.com/pytorch/pytorch/issues/85121
  18220. DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"),
  18221. 'TestCommon',
  18222. 'test_python_ref_executor'),
  18223. # These tests expect the input to be a tensor or a sequence of tensors
  18224. DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
  18225. DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'),
  18226. DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'),
  18227. DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'),
  18228. ),
  18229. ),
  18230. PythonRefInfo(
  18231. "_refs.eye",
  18232. torch_opinfo_name="eye",
  18233. supports_nvfuser=False,
  18234. skips=(
  18235. # skip these tests since we have non tensor input
  18236. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
  18237. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
  18238. DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
  18239. ),
  18240. ),
  18241. PythonRefInfo(
  18242. "_refs.new_empty",
  18243. torch_opinfo_name="new_empty",
  18244. supports_nvfuser=False,
  18245. skips=(
  18246. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18247. 'TestCommon',
  18248. 'test_python_ref'),
  18249. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18250. 'TestCommon',
  18251. 'test_python_ref_torch_fallback'),
  18252. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18253. 'TestCommon',
  18254. 'test_out'),
  18255. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18256. 'TestCommon',
  18257. 'test_out_warning'),
  18258. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18259. 'TestMathBits',
  18260. 'test_conj_view'),
  18261. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18262. 'TestMathBits',
  18263. 'test_neg_conj_view'),
  18264. DecorateInfo(unittest.skip("Expected: empty is not comparable"),
  18265. 'TestMathBits',
  18266. 'test_neg_view'),
  18267. # FIXME: should not compare results of empty_like
  18268. DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'),
  18269. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  18270. ),
  18271. ),
  18272. PythonRefInfo(
  18273. "_refs.new_empty_strided",
  18274. torch_opinfo_name="new_empty_strided",
  18275. skips=(
  18276. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18277. 'TestCommon',
  18278. 'test_python_ref'),
  18279. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18280. 'TestCommon',
  18281. 'test_python_ref_torch_fallback'),
  18282. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18283. 'TestMathBits',
  18284. 'test_conj_view'),
  18285. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18286. 'TestMathBits',
  18287. 'test_neg_conj_view'),
  18288. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18289. 'TestMathBits',
  18290. 'test_neg_view'),
  18291. DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
  18292. 'TestCommon',
  18293. 'test_python_ref_executor'),
  18294. DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
  18295. ),
  18296. ),
  18297. PythonRefInfo(
  18298. "_refs.new_full",
  18299. torch_opinfo_name="new_full",
  18300. supports_nvfuser=False,
  18301. ),
  18302. PythonRefInfo(
  18303. "_refs.new_ones",
  18304. torch_opinfo_name="new_ones",
  18305. supports_nvfuser=False,
  18306. ),
  18307. PythonRefInfo(
  18308. "_refs.new_zeros",
  18309. torch_opinfo_name="new_zeros",
  18310. supports_nvfuser=False,
  18311. ),
  18312. #
  18313. # Conditional Reference OpInfos
  18314. #
  18315. PythonRefInfo(
  18316. "_refs.masked_fill",
  18317. torch_opinfo_name="masked_fill",
  18318. supports_nvfuser=False,
  18319. skips=(
  18320. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
  18321. ),
  18322. ),
  18323. PythonRefInfo(
  18324. "_refs.where",
  18325. torch_opinfo_name="where",
  18326. op=lambda self, condition, other: refs.where(condition, self, other),
  18327. supports_nvfuser=False,
  18328. skips=(
  18329. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'),
  18330. ),
  18331. ),
  18332. PythonRefInfo(
  18333. "_refs.index_select",
  18334. torch_opinfo_name="index_select",
  18335. # empty_strided
  18336. supports_nvfuser=False,
  18337. skips=(
  18338. # no _refs support for Tensor.__setitem__
  18339. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
  18340. # Sample out= with a stride of zero. This _out operation checks that the input has no
  18341. # inner overlap
  18342. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),)
  18343. ),
  18344. PythonRefInfo(
  18345. "_refs.index_copy",
  18346. torch_opinfo_name="index_copy",
  18347. # empty_strided
  18348. supports_nvfuser=False,
  18349. skips=(
  18350. # no _refs support for Tensor.__setitem__
  18351. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
  18352. ),
  18353. PythonRefInfo(
  18354. "_refs.index_add",
  18355. torch_opinfo_name="index_add",
  18356. # empty_strided
  18357. supports_nvfuser=False,
  18358. skips=(
  18359. # no _refs support for Tensor.__setitem__
  18360. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
  18361. ),
  18362. PythonRefInfo(
  18363. "_refs.index_fill",
  18364. torch_opinfo_name="index_fill",
  18365. # empty_strided
  18366. supports_nvfuser=False,
  18367. skips=(
  18368. # no _refs support for Tensor.__setitem__
  18369. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
  18370. ),
  18371. #
  18372. # Test-related functions
  18373. #
  18374. PythonRefInfo(
  18375. "_refs.allclose",
  18376. torch_opinfo_name="allclose",
  18377. supports_nvfuser=False,
  18378. ),
  18379. ]
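
# Note on the entries above (an illustrative summary, not an exhaustive spec):
# each *PythonRefInfo pairs a reference implementation in torch._refs with the
# OpInfo of the eager op it mirrors via ``torch_opinfo_name``, and attaches
# expected failures or skips to specific (test class, test name, dtype, device)
# combinations through DecorateInfo. A hypothetical entry would look like:
#
#     PythonRefInfo(
#         "_refs.some_op",              # hypothetical reference under test
#         torch_opinfo_name="some_op",  # eager OpInfo whose samples are reused
#         skips=(
#             DecorateInfo(unittest.expectedFailure, 'TestCommon',
#                          'test_python_ref', dtypes=(torch.complex32,)),
#         ),
#     ),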
python_ref_db += opinfo.definitions.python_ref_db

# Common operator groupings
ops_and_refs = op_db + python_ref_db
unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)]
binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)]
binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo)]
reference_filtered_ops = [op for op in reduction_ops if op.ref is not None]
reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')]
sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')]
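
# Illustrative sketch (not part of this module's API): downstream test files
# typically consume the groupings above with the ``ops`` decorator from
# torch.testing._internal.common_device_type, which parametrizes a test over
# (device, dtype, op). The test class and test name below are hypothetical.
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests, ops)
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestExampleBinaryUfuncs(TestCase):
#         @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
#         def test_sample_inputs_run(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestExampleBinaryUfuncs, globals())
#
#     if __name__ == '__main__':
#         run_tests()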

# TODO: review porting these to make_tensor
def index_variable(shape, max_indices, device=torch.device('cpu')):
    if not isinstance(shape, tuple):
        shape = (shape,)
    index = torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().long()
    return index
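
# Example usage (illustrative): build a (3, 4) batch of random indices that are
# valid for indexing a dimension of size 10.
#
#     idx = index_variable((3, 4), max_indices=10)
#     assert idx.dtype == torch.int64 and int(idx.max()) < 10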

def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')):
    assert len(shape) == 2
    assert index_dim < 2
    batch_dim = 1 - index_dim
    index = torch.zeros(*shape, dtype=torch.long, device=device)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices, device=device)[:shape[batch_dim]])
    if duplicate:
        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
    return index
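
# Example usage (illustrative): a (4, 5) LongTensor of indices in [0, 10),
# suitable for torch.gather along dim 1 of a tensor of size 10 in that dim.
#
#     src = torch.randn(4, 10)
#     idx = gather_variable((4, 5), index_dim=1, max_indices=10)
#     gathered = src.gather(1, idx)   # shape (4, 5)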

def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.bool).bernoulli_()
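
# Example (illustrative): a 0-dim bool tensor that is True with probability 0.5.
#
#     flip = bernoulli_scalar()   # tensor(True) or tensor(False)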

def mask_not_all_zeros(shape):
    assert len(shape) > 0
    while True:
        result = torch.randn(shape).gt(0)
        if result.sum() > 0:
            return result
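
# Example usage (illustrative): a boolean mask of the requested shape that is
# guaranteed to select at least one element.
#
#     mask = mask_not_all_zeros((2, 3))
#     assert mask.dtype == torch.bool and bool(mask.any())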