from abc import abstractmethod
import math
import tempfile
import unittest
from copy import deepcopy
from functools import reduce, partial, wraps
from itertools import product
from operator import mul
from math import pi

import torch
import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import _reduction as _Reduction
from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
    TEST_WITH_ROCM, gradcheck, gradgradcheck
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn
from typing import Dict, Callable, Tuple, List, Sequence, Union, Any

TemporaryFile = tempfile.TemporaryFile
PRECISION = 1e-5

# Returns the reduction mode of module `m`, falling back to the legacy
# `sizeAverage` attribute for modules that predate the `reduction` argument.
def get_reduction(m):
    result = getattr(m, 'reduction', None)
    if result is None:
        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
    assert result is not None
    return result

# Returns the weight tensor of module `m`, checking the `weight` attribute
# first and the plural `weights` attribute second.
def get_weight(m):
    result = getattr(m, 'weight', None)
    if result is not None:
        return result
    return getattr(m, 'weights', None)

# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
#
# The way to check API parity is to add parity tests for the NN module / functional of interest.
# Here are the detailed steps:
#
# For NN module:
# 1. Make sure you already have a test dict with the module configuration you want to test.
# 2. Add a `cpp_constructor_args` entry to the test dict, with its value exactly matching
#    the Python module constructor arguments. For example, if in the test dict we pass
#    `(10, 8)` to the `torch.nn.Linear` constructor, then we should pass
#    `torch::nn::LinearOptions(10, 8)` as the corresponding C++ constructor argument to `torch::nn::Linear`.
# 3. If in the process of performing the above step you referenced any variables
#    in the `cpp_constructor_args` entry, you must add a `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, if the Python constructor call is
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
#    the corresponding C++ constructor argument is
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
#    and the `cpp_var_map` entry must be
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
#    then you must add a `cpp_options_args` entry to the test dict, with its value exactly matching the
#    Python functional optional arguments. For example, if the test dict's `constructor` entry is
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
#    then the `cpp_options_args` entry should be
#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)".
# 3. Otherwise, if the test dict's `constructor` entry looks like
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
#    then you must add a `cpp_function_call` entry to the test dict, with its value exactly matching the
#    Python functional function call. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
# 4. If in the process of performing the above two steps you referenced any variables
#    in the `cpp_options_args` or `cpp_function_call` entry, you must add a `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, in the `F.poisson_nll_loss` test dict above, there are two variables `i` and `t`
#    that need to have their values provided, and the way to do so is to add a `cpp_var_map` entry:
#    `cpp_var_map={'i': '_get_input()', 't': t}`.
#    (Note that for `i`, since we want it to take the Python input value, we pass the '_get_input()'
#    string as the value, and the C++ parity test mechanism will populate `i` with the Python input
#    value correctly.)
#
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
#
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: `True`.
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: `True`.

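# For illustration only (a sketch; this dict is hypothetical and is not consumed by the
# harness below): a test dict following the module steps above, combining
# `cpp_constructor_args` with a `cpp_var_map` entry for the variable it references.
_example_random_samples = torch.rand(1, 3, 2)  # hypothetical tensor referenced from the C++ string
_example_parity_test_dict = dict(
    fullname='FractionalMaxPool2d_parity_example',
    constructor=lambda: nn.FractionalMaxPool2d(
        2, output_ratio=0.5, _random_samples=_example_random_samples),
    cpp_constructor_args='torch::nn::FractionalMaxPool2dOptions(2)'
                         '.output_ratio(0.5)._random_samples(random_samples)',
    # `random_samples` appears in the C++ argument string, so it must be mapped here:
    cpp_var_map={'random_samples': _example_random_samples},
    input_size=(1, 3, 5, 7),
)
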
module_tests = [
    dict(
        module_name='Linear',
        constructor_args=(10, 8),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
        input_size=(4, 10),
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Linear',
        constructor_args=(10, 8, False),
        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
        input_size=(4, 10),
        desc='no_bias',
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Threshold',
        constructor_args=(2., 1.),
        cpp_constructor_args='torch::nn::ThresholdOptions(2., 1.)',
        input_size=(2, 3, 4, 5),
        check_inplace=True,
        desc='threshold_value',
    ),
    dict(
        module_name='Threshold',
        constructor_args=(2., 10.),
        cpp_constructor_args='torch::nn::ThresholdOptions(2., 10.)',
        input_size=(2, 3, 4, 5),
        desc='large_value',
    ),
    dict(
        module_name='ReLU',
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name='ReLU6',
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name='RReLU',
        input_size=(1, 2, 2),
        test_cuda=False,
    ),
    dict(
        module_name='RReLU',
        constructor_args=(0.1, 0.9),
        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
        input_size=(4, 4, 5),
        desc='with_up_down',
        test_cuda=False,
    ),
    dict(
        module_name='Hardtanh',
        input_size=(3, 2, 5),
        reference_fn=lambda i, *_: i.clamp(-1, 1),
    ),
    dict(
        module_name='Sigmoid',
        input_size=(2, 3, 4, 5),
    ),
    dict(
        module_name='Tanh',
        input_size=(2, 3, 4, 5),
    ),
    dict(
        module_name='Flatten',
        input_size=(2, 3, 4, 5),
        reference_fn=lambda i, *_: torch.flatten(i, 1),
    ),
    dict(
        module_name='Softmax',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::SoftmaxOptions(1)',
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
    ),
    dict(
        module_name='Softmax2d',
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(1, False)),
    ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::LogSoftmaxOptions(1)',
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
    ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::LogSoftmaxOptions(1)',
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
        desc='multiparam',
    ),
    dict(
        module_name='ELU',
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::ELUOptions().alpha(2.)',
        input_size=(3, 2, 5),
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
    ),
    # TODO: reference function
    dict(
        module_name='Hardshrink',
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::HardshrinkOptions(2.)',
        input_size=(4, 3, 2, 4),
    ),
    dict(
        module_name='LeakyReLU',
        input_size=(3, 2, 5),
        check_inplace=True,
    ),
    dict(
        module_name='LeakyReLU',
        constructor_args=(0.5,),
        cpp_constructor_args='torch::nn::LeakyReLUOptions().negative_slope(0.5)',
        input_size=(3, 2, 5),
        check_inplace=True,
        desc='with_negval',
    ),
    dict(
        module_name='LeakyReLU',
        constructor_args=(0.0,),
        cpp_constructor_args='torch::nn::LeakyReLUOptions().negative_slope(0.0)',
        input_fn=lambda: torch.randn(10, 10),
        check_inplace=True,
        desc='with_zero_negval',
    ),
    dict(
        module_name='LogSigmoid',
        input_size=(2, 3, 4),
        reference_fn=lambda i, *_: i.sigmoid().log(),
    ),
    dict(
        module_name='Softplus',
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.log(1 + torch.exp(i)),
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2,),
        cpp_constructor_args='torch::nn::SoftplusOptions().beta(2)',
        input_size=(10, 20),
        reference_fn=lambda i, *_: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
        desc='beta',
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2, -100),
        cpp_constructor_args='torch::nn::SoftplusOptions().beta(2).threshold(-100)',
        input_size=(10, 20),
        reference_fn=(
            lambda i, *_: ((i * 2) > -100).type_as(i) * i
            + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))
        ),
        desc='beta_threshold',
    ),
    dict(
        module_name='Softshrink',
        input_size=(3, 2, 5),
    ),
    dict(
        module_name='Softshrink',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::SoftshrinkOptions(1)',
        input_size=(3, 2, 5),
        desc='lambda',
    ),
    dict(
        module_name='CrossMapLRN2d',
        constructor_args=(5, 5e-3, 1e-3, 2),
        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
        input_size=(2, 3, 6, 6),
        check_gradgrad=False,
        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
        check_batched_grad=False,
    ),
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
        desc='1d',
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::PReLUOptions().num_parameters(3)',
        input_size=(2, 3, 4),
        desc='1d_multiparam',
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4, 5),
        desc='2d',
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::PReLUOptions().num_parameters(3)',
        input_size=(2, 3, 4, 5),
        desc='2d_multiparam',
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='PReLU',
        input_size=(2, 3, 4, 5, 6),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
        desc='3d',
    ),
    dict(
        module_name='PReLU',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::PReLUOptions().num_parameters(3)',
        input_size=(2, 3, 4, 5, 6),
        desc='3d_multiparam',
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name='Softsign',
        input_size=(3, 2, 5),
        reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
    ),
    dict(
        module_name='Softmin',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::SoftminOptions(1)',
        input_size=(10, 20),
    ),
    dict(
        module_name='Softmin',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::SoftminOptions(1)',
        input_size=(2, 3, 5, 10),
        desc='multidim',
    ),
    dict(
        module_name='Tanhshrink',
        input_size=(2, 3, 4, 5),
    ),
]

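# A minimal sketch (an assumption for illustration; the real harness elsewhere in this
# file is far more general) of how a `module_tests` entry that provides `input_size` and
# `reference_fn` is consumed: build the module from `constructor_args`, run it on a
# random input, and compare against `reference_fn(input, list(parameters), module)`.
def _sketch_check_module_against_reference(test_dict):
    module = getattr(nn, test_dict['module_name'])(*test_dict.get('constructor_args', ()))
    module.double()
    i = torch.randn(*test_dict['input_size'], dtype=torch.double)
    p = [param.detach() for param in module.parameters()]
    expected = test_dict['reference_fn'](i, p, module)
    assert torch.allclose(module(i), expected)
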
# Generates a rand tensor with non-equal values. This ensures that duplicate
# values won't cause test failures for modules like MaxPooling.
# size should be small, otherwise randperm fails / long overflows.
def _rand_tensor_non_equal(*size):
    total = reduce(mul, size, 1)
    return torch.randperm(total).view(*size).double()

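# Illustrative usage (an assumption about intent, kept out of the import path by
# wrapping it in an uncalled function): every value is distinct, so modules such as
# MaxPool have a unique argmax.
def _demo_rand_tensor_non_equal():
    t = _rand_tensor_non_equal(2, 3)
    assert t.unique().numel() == t.numel()
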
def wrap_functional(fn, **kwargs):
    class FunctionalModule(nn.Module):
        def forward(self, *args):
            return fn(*args, **kwargs)
    return FunctionalModule

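# Usage sketch (illustrative only): `wrap_functional` returns a throwaway nn.Module
# *class* whose forward applies the functional, which is how the `constructor` entries
# in the test dicts below are built.
def _demo_wrap_functional():
    relu_module_cls = wrap_functional(F.relu, inplace=False)
    m = relu_module_cls()
    return m(torch.randn(2, 3))  # same as F.relu(torch.randn(2, 3), inplace=False)
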
def poissonnllloss_no_reduce_test():
    t = torch.randn(10, 10)
    return dict(
        fullname='PoissonNLLLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::poisson_nll_loss('
                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: i.exp() - t.mul(i),
        pickle=False)

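# Sanity sketch (illustrative; relies on the default log_input=True): with
# reduction='none', F.poisson_nll_loss computes exp(i) - t * i elementwise, which is
# exactly the reference_fn above.
def _demo_poisson_nll_reference():
    i, t = torch.rand(10, 10), torch.rand(10, 10)
    assert torch.allclose(F.poisson_nll_loss(i, t, reduction='none'), i.exp() - t * i)
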
def bceloss_no_reduce_test():
    t = Variable(torch.randn(15, 10).gt(0).double())
    return dict(
        fullname='BCELoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False,
        precision=7e-4)

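# Sanity sketch (illustrative): with reduction='none', binary cross entropy is the
# elementwise -(t * log(i) + (1 - t) * log(1 - i)) used by the reference_fn above;
# the input is clamped away from 0 and 1 to keep the logs finite.
def _demo_bce_reference():
    i = torch.rand(15, 10, dtype=torch.double).clamp_(2.8e-2, 1 - 2.8e-2)
    t = torch.rand(15, 10, dtype=torch.double).gt(0.5).double()
    expected = -(t * i.log() + (1 - t) * (1 - i).log())
    assert torch.allclose(F.binary_cross_entropy(i, t, reduction='none'), expected)
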
def bceloss_no_reduce_scalar_test():
    t = torch.randn(()).gt(0).double()
    return dict(
        fullname='BCELoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False)


def bceloss_weights_no_reduce_test():
    t = Variable(torch.randn(15, 10).gt(0).double())
    weights = torch.rand(10)
    return dict(
        fullname='BCELoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i),
                                             weight=weights.type_as(i), reduction='none')),
        cpp_function_call='F::binary_cross_entropy('
                          'i, t.to(i.options()), '
                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
        precision=3e-4,
    )


def bceloss_weights_no_reduce_scalar_test():
    t = torch.randn(()).double()
    weights = torch.rand(())
    return dict(
        fullname='BCELoss_weights_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i),
                                             weight=weights.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy(
            i, t.to(i.options()),
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
    )


def bce_with_logistic_legacy_enum_test():
    t = Variable(torch.randn(15, 10).gt(0).double())
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_legacy_enum',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
    )

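# Note (a sketch of why the C++ call above uses torch::kNone): the deprecated
# `reduce=False` argument is translated to the modern reduction='none' via
# torch.nn._reduction's legacy-string handling, so the two spellings are equivalent.
def _demo_legacy_reduce_flag():
    i, t = torch.randn(3, 4), torch.rand(3, 4).gt(0.5).float()
    legacy = F.binary_cross_entropy_with_logits(i, t, reduce=False)  # emits a deprecation warning
    modern = F.binary_cross_entropy_with_logits(i, t, reduction='none')
    assert torch.equal(legacy, modern)
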
def bce_with_logistic_no_reduce_test():
    t = Variable(torch.randn(15, 10).gt(0).double())
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
    )


def bce_with_logistic_no_reduce_scalar_test():
    t = torch.randn(()).gt(0).double()
    sigmoid = nn.Sigmoid()
    return dict(
        fullname='BCEWithLogitsLoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
        check_gradgrad=False,
        pickle=False,
    )


def kldivloss_with_target_no_reduce_test():
    t = torch.rand(10, 10)
    return dict(
        fullname='KLDivLoss_with_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)

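# Sanity sketch (illustrative): with reduction='none' and the default log_target=False,
# F.kl_div computes t * (log(t) - i) pointwise, where i is already a log-probability;
# t is kept strictly positive so the naive formula below stays finite.
def _demo_kl_div_reference():
    i = torch.rand(10, 10).log()
    t = torch.rand(10, 10).add_(1e-2)
    assert torch.allclose(F.kl_div(i, t, reduction='none'), t * (t.log() - i))
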
def kldivloss_no_reduce_test():
    t = torch.rand(10, 10)
    return dict(
        fullname='KLDivLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
    )


def kldivloss_no_reduce_scalar_test():
    t = torch.rand(())
    return dict(
        fullname='KLDivLoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.rand(()).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)


def kldivloss_with_log_target_no_reduce_test():
    t = torch.rand(10, 10).log()
    return dict(
        fullname='KLDivLoss_with_log_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)


def kldivloss_no_reduce_log_target_test():
    t = torch.rand(10, 10).log()
    return dict(
        fullname='KLDivLoss_no_reduce_log_target',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False,
    )


def kldivloss_no_reduce_scalar_log_target_test():
    t = torch.rand(()).log()
    return dict(
        fullname='KLDivLoss_no_reduce_scalar_log_target',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
        input_fn=lambda: torch.rand(()).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)

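# Sanity sketch (illustrative): with log_target=True the target is itself in log-space,
# and the pointwise loss becomes exp(t) * (t - i), avoiding an exp/log round-trip.
def _demo_kl_div_log_target_reference():
    i = torch.rand(10, 10).log()
    t = torch.rand(10, 10).log()
    assert torch.allclose(F.kl_div(i, t, reduction='none', log_target=True),
                          t.exp() * (t - i))
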
def l1loss_no_reduce_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='L1Loss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False)


def l1loss_no_reduce_complex_test():
    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
    return dict(
        fullname='L1Loss_no_reduce_complex',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False)


def l1loss_no_reduce_scalar_test():
    t = torch.randn(())
    return dict(
        fullname='L1Loss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        supports_forward_ad=True,
        pickle=False)


def mseloss_no_reduce_test():
    input_size = (2, 3, 4, 5)
    target = torch.randn(*input_size)
    return dict(
        fullname='MSELoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        input_size=input_size,
        cpp_var_map={'i': '_get_input()', 'target': target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        supports_forward_ad=True,
        pickle=False)


def mseloss_no_reduce_scalar_test():
    input_size = ()
    target = torch.randn(input_size)
    return dict(
        fullname='MSELoss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
        input_size=input_size,
        cpp_var_map={'i': '_get_input()', 'target': target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        supports_forward_ad=True,
        pickle=False)


def nllloss_no_reduce_test():
    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(15, 10).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
        pickle=False)

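# Sanity sketch (illustrative): with reduction='none' and no weight, NLL loss simply
# picks out the negated log-probability of each sample's target class.
def _demo_nll_reference():
    i = torch.rand(15, 10).log()
    t = torch.empty(15).uniform_().mul(10).floor().long()
    expected = -i.gather(1, t.unsqueeze(1)).squeeze(1)
    assert torch.allclose(F.nll_loss(i, t, reduction='none'), expected)
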
  623. def nllloss_no_reduce_ignore_index_test():
  624. t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
  625. kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
  626. return dict(
  627. fullname='NLLLoss_no_reduce_ignore_index',
  628. constructor=wrap_functional(
  629. lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
  630. reduction=str(kwargs['reduction']))),
  631. cpp_function_call='''F::nll_loss(
  632. i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
  633. input_fn=lambda: torch.rand(15, 10).log(),
  634. cpp_var_map={'i': '_get_input()', 't': t},
  635. reference_fn=lambda i, *_:
  636. loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
  637. pickle=False)
  638. def nllloss_no_reduce_weights_test():
  639. t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
  640. weight = torch.rand(10)
  641. def kwargs(i):
  642. return {'weight': weight.type_as(i), 'reduction': 'none'}
  643. return dict(
  644. fullname='NLLLoss_no_reduce_weights',
  645. constructor=wrap_functional(
  646. lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
  647. cpp_function_call='''F::nll_loss(
  648. i, t.to(i.options()).to(torch::kLong),
  649. F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
  650. input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
  651. cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
  652. reference_fn=lambda i, *_:
  653. loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
  654. pickle=False)
  655. def nllloss_no_reduce_weights_ignore_index_test():
  656. t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
  657. weight = torch.rand(10)
  658. def kwargs(i):
  659. return {'weight': weight.type_as(i), 'reduction': 'none',
  660. 'ignore_index': 2}
  661. return dict(
  662. fullname='NLLLoss_no_reduce_weights_ignore_index',
  663. constructor=wrap_functional(
  664. lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
  665. cpp_function_call='''F::nll_loss(
  666. i, t.to(i.options()).to(torch::kLong),
  667. F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
  668. input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
  669. cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
  670. reference_fn=lambda i, *_:
  671. loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
  672. pickle=False)
  673. def nllloss_no_reduce_weights_ignore_index_neg_test():
  674. t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
  675. weight = torch.rand(10)
  676. def kwargs(i):
  677. return {'weight': weight.type_as(i), 'reduction': 'none',
  678. 'ignore_index': -1}
  679. return dict(
  680. fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
  681. constructor=wrap_functional(
  682. lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
  683. cpp_function_call='''F::nll_loss(
  684. i, t.to(i.options()).to(torch::kLong),
  685. F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
  686. input=torch.rand(15, 10).add(1e-2).log(),
  687. cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
  688. reference_fn=lambda i, *_:
  689. loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
  690. pickle=False)
def nllloss2d_no_reduce_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLoss2d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False)


def nllloss2d_no_reduce_ignore_index_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
    return dict(
        fullname='NLLLoss2d_no_reduce_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
                                 reduction=str(kwargs['reduction']))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False)


def nllloss2d_no_reduce_weights_test():
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none'}

    return dict(
        fullname='NLLLoss2d_no_reduce_weights',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False)


def nlllossNd_no_reduce_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs = {'reduction': 'none'}
    return dict(
        fullname='NLLLossNd_no_reduce',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False)


def nlllossNd_no_reduce_ignore_index_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
    return dict(
        fullname='NLLLossNd_no_reduce_ignore_index',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
                                 reduction=str(kwargs['reduction']))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
        pickle=False)


def nlllossNd_no_reduce_weights_test():
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {'weight': weight.type_as(i), 'reduction': 'none'}

    return dict(
        fullname='NLLLossNd_no_reduce_weights',
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
        cpp_function_call='''F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
        reference_fn=lambda i, *_:
            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
        pickle=False)


def smoothl1loss_no_reduce_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='SmoothL1Loss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)


def smoothl1loss_no_reduce_scalar_test():
    t = torch.randn(())
    return dict(
        fullname='SmoothL1Loss_no_reduce_scalar',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)


def smoothl1loss_beta_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='SmoothL1Loss_beta',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
        supports_forward_ad=True,
        pickle=False)


def smoothl1loss_zero_beta_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='SmoothL1Loss_zero_beta',
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
        cpp_function_call='''F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
        supports_forward_ad=True,
        pickle=False)


def huberloss_delta_test():
    t = torch.randn(2, 3, 4)
    return dict(
        fullname='HuberLoss_delta',
        constructor=wrap_functional(
            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
        cpp_function_call='''F::huber_loss(
            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
        supports_forward_ad=True,
        pickle=False)


def multilabelmarginloss_0d_no_reduce_test():
    t = torch.zeros(()).long()
    return dict(
        fullname='MultiLabelMarginLoss_0d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(()),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multilabelmarginloss_1d_no_reduce_test():
    t = Variable(torch.rand(10).mul(10).floor().long())
    return dict(
        fullname='MultiLabelMarginLoss_1d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multilabelmarginloss_index_neg_test():
    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
    return dict(
        fullname='MultiLabelMarginLoss_index_neg',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)
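
# Target-encoding note for the MultiLabelMarginLoss specs (a sketch, not
# harness code): each target row lists class indices and is terminated by the
# first -1; entries after it are ignored. That is why
# multilabelmarginloss_index_neg_test clamps its random targets at min=-1.
#
#     inp = torch.randn(1, 4)
#     tgt = torch.tensor([[2, 0, -1, 1]])   # only classes 2 and 0 count; the trailing 1 is ignored
#     F.multilabel_margin_loss(inp, tgt, reduction='none')
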
def multilabelmarginloss_no_reduce_test():
    t = Variable(torch.rand(5, 10).mul(10).floor().long())
    return dict(
        fullname='MultiLabelMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def hingeembeddingloss_no_reduce_test():
    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
    return dict(
        fullname='HingeEmbeddingLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
        check_sum_reduction=True,
        pickle=False)


def hingeembeddingloss_margin_no_reduce_test():
    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
    return dict(
        fullname='HingeEmbeddingLoss_margin_no_reduce',
        constructor=wrap_functional(
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
        cpp_function_call='''F::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
        check_sum_reduction=True,
        pickle=False)


def softmarginloss_no_reduce_test():
    t = torch.randn(5, 5)
    return dict(
        fullname='SoftMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::soft_margin_loss(
            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 5),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
        supports_forward_ad=True,
        pickle=False)


def multilabelsoftmarginloss_no_reduce_test():
    t = torch.rand(5, 10).mul(2).floor()
    return dict(
        fullname='MultiLabelSoftMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
        cpp_function_call='''F::multilabel_soft_margin_loss(
            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
        check_gradgrad=False,
        pickle=False)


def multilabelsoftmarginloss_weights_no_reduce_test():
    t = torch.rand(5, 10).mul(2).floor()
    weights = torch.rand(10)
    return dict(
        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
                                                    weight=weights.type_as(i), reduction='none')),
        cpp_function_call='''F::multilabel_soft_margin_loss(
            i, t.to(i.options()),
            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        reference_fn=lambda i, *_:
            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_no_reduce_test():
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_1d_no_reduce_test():
    t = torch.rand(1).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_1d_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_1d_input_0d_target_no_reduce_test():
    t = torch.rand(()).mul(8).floor().long()
    return dict(
        fullname='multimarginloss_1d_input_0d_target_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_p_no_reduce_test():
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_p_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_margin_no_reduce_test():
    t = torch.rand(5).mul(8).floor().long()
    return dict(
        fullname='MultiMarginLoss_margin_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
                                                  margin=0.5, reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def multimarginloss_weights_no_reduce_test():
    t = torch.rand(5).mul(8).floor().long()
    weights = torch.rand(10)
    return dict(
        fullname='MultiMarginLoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
                                          reduction='none')),
        cpp_function_call='''F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
        input_fn=lambda: torch.randn(5, 10),
        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
        reference_fn=lambda i, *_:
            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
                                                  weight=weights, reduction='none'),
        check_sum_reduction=True,
        check_gradgrad=False,
        pickle=False)


def fractional_max_pool2d_test(test_case, return_indices=False):
    random_samples = torch.empty((1, 3, 2), dtype=torch.double).uniform_()
    if test_case == 'ratio':
        out = dict(
            constructor=lambda: nn.FractionalMaxPool2d(
                2, output_ratio=0.5, _random_samples=random_samples, return_indices=return_indices),
            cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions(2)
                                    .output_ratio(0.5)
                                    ._random_samples(random_samples)''',
            input_size=(1, 3, 5, 7),
            cpp_var_map={'random_samples': random_samples},
            fullname='FractionalMaxPool2d_ratio')
    elif test_case == 'size':
        out = dict(
            constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
                4, 3), _random_samples=random_samples, return_indices=return_indices),
            cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions({2, 3})
                                    .output_size(std::vector<int64_t>({4, 3}))
                                    ._random_samples(random_samples)''',
            input_size=(1, 3, 7, 6),
            cpp_var_map={'random_samples': random_samples},
            fullname='FractionalMaxPool2d_size')
    if return_indices:
        # to get the return_indices behavior we have to call
        # `forward_with_indices` in C++ and the return type switches from
        # Tensor to tuple<Tensor, Tensor> which complicates testing considerably.
        out['test_cpp_api_parity'] = False
        out['fullname'] = '%s_return_indices' % out['fullname']
    return out
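
# Why _random_samples is pinned above (an illustrative sketch): fractional max
# pooling draws its pooling regions from random samples, so two modules only
# produce identical outputs when they share those samples. Passing them
# explicitly keeps the Python module, its C++ parity twin, and CPU/CUDA runs
# deterministic and comparable.
#
#     samples = torch.empty((1, 3, 2), dtype=torch.double).uniform_()
#     m1 = nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=samples)
#     m2 = nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=samples)
#     x = torch.randn(1, 3, 5, 7, dtype=torch.double)
#     assert torch.equal(m1(x), m2(x))
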
def fractional_max_pool2d_no_batch_dim_test(test_case, use_random_samples):
    if use_random_samples:
        # random_samples enables CPU and GPU checks to be consistent
        random_samples = torch.empty((1, 3, 2), dtype=torch.double).uniform_()
        if test_case == 'ratio':
            return dict(
                constructor=lambda: nn.FractionalMaxPool2d(
                    2, output_ratio=0.5, _random_samples=random_samples),
                cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions(2)
                                        .output_ratio(0.5)
                                        ._random_samples(random_samples)''',
                input_size=(3, 5, 7),
                cpp_var_map={'random_samples': random_samples},
                reference_fn=single_batch_reference_fn,
                fullname='FractionalMaxPool2d_ratio_no_batch_dim')
        elif test_case == 'size':
            return dict(
                constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
                    4, 3), _random_samples=random_samples),
                cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions({2, 3})
                                        .output_size(std::vector<int64_t>({4, 3}))
                                        ._random_samples(random_samples)''',
                input_size=(3, 7, 6),
                cpp_var_map={'random_samples': random_samples},
                reference_fn=single_batch_reference_fn,
                fullname='FractionalMaxPool2d_size_no_batch_dim')
    else:
        # cannot check CUDA here because the RNG differs between CPU and CUDA
        if test_case == 'ratio':
            return dict(
                constructor=lambda: nn.FractionalMaxPool2d(
                    2, output_ratio=0.5),
                cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions(2)
                                        .output_ratio(0.5)''',
                input_size=(3, 5, 7),
                reference_fn=single_batch_reference_fn,
                test_cuda=False,
                fullname='FractionalMaxPool2d_ratio_no_batch_dim_no_random_samples')
        elif test_case == 'size':
            return dict(
                constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
                    4, 3)),
                cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions({2, 3})
                                        .output_size(std::vector<int64_t>({4, 3}))''',
                input_size=(3, 7, 6),
                reference_fn=single_batch_reference_fn,
                test_cuda=False,
                fullname='FractionalMaxPool2d_size_no_batch_dim_no_random_samples')

def fractional_max_pool3d_test(test_case, return_indices=False):
    random_samples = torch.empty((2, 4, 3), dtype=torch.double).uniform_()
    if test_case == 'ratio':
        out = dict(
            constructor=lambda: nn.FractionalMaxPool3d(
                2, output_ratio=0.5, _random_samples=random_samples, return_indices=return_indices),
            cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions(2)
                                    .output_ratio(0.5)
                                    ._random_samples(random_samples)''',
            input_size=(2, 4, 5, 5, 5),
            cpp_var_map={'random_samples': random_samples},
            fullname='FractionalMaxPool3d_ratio')
    elif test_case == 'size':
        out = dict(
            constructor=lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=(
                4, 4, 4), _random_samples=random_samples, return_indices=return_indices),
            cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions({2, 2, 2})
                                    .output_size(std::vector<int64_t>({4, 4, 4}))
                                    ._random_samples(random_samples)''',
            input_size=(2, 4, 7, 7, 7),
            cpp_var_map={'random_samples': random_samples},
            fullname='FractionalMaxPool3d_size')
    elif test_case == 'asymsize':
        out = dict(
            constructor=lambda: nn.FractionalMaxPool3d((4, 2, 3), output_size=(
                10, 3, 2), _random_samples=random_samples, return_indices=return_indices),
            cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions({4, 2, 3})
                                    .output_size(std::vector<int64_t>({10, 3, 2}))
                                    ._random_samples(random_samples)''',
            input_size=(2, 4, 16, 7, 5),
            cpp_var_map={'random_samples': random_samples},
            fullname='FractionalMaxPool3d_asymsize')
    if return_indices:
        # to get the return_indices behavior we have to call
        # `forward_with_indices` in C++ and the return type switches from
        # Tensor to tuple<Tensor, Tensor> which complicates testing considerably.
        out['test_cpp_api_parity'] = False
        out['fullname'] = '%s_return_indices' % out['fullname']
    return out

def fractional_max_pool3d_no_batch_dim_test(test_case, use_random_samples):
    if use_random_samples:
        # random_samples enables CPU and GPU checks to be consistent
        random_samples = torch.empty((2, 4, 3), dtype=torch.double).uniform_()
        if test_case == 'ratio':
            return dict(
                constructor=lambda: nn.FractionalMaxPool3d(
                    2, output_ratio=0.5, _random_samples=random_samples),
                cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions(2)
                                        .output_ratio(0.5)
                                        ._random_samples(random_samples)''',
                input_size=(4, 5, 5, 5),
                cpp_var_map={'random_samples': random_samples},
                reference_fn=single_batch_reference_fn,
                fullname='FractionalMaxPool3d_ratio_no_batch_dim')
        elif test_case == 'size':
            return dict(
                constructor=lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=(
                    4, 4, 4), _random_samples=random_samples),
                cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions({2, 2, 2})
                                        .output_size(std::vector<int64_t>({4, 4, 4}))
                                        ._random_samples(random_samples)''',
                input_size=(4, 7, 7, 7),
                cpp_var_map={'random_samples': random_samples},
                reference_fn=single_batch_reference_fn,
                fullname='FractionalMaxPool3d_size_no_batch_dim')
    else:
        # cannot check CUDA here because the RNG differs between CPU and CUDA
        if test_case == 'ratio':
            return dict(
                constructor=lambda: nn.FractionalMaxPool3d(
                    2, output_ratio=0.5),
                cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions(2)
                                        .output_ratio(0.5)''',
                input_size=(4, 5, 5, 5),
                reference_fn=single_batch_reference_fn,
                test_cuda=False,
                fullname='FractionalMaxPool3d_ratio_no_batch_dim_no_random_samples')
        elif test_case == 'size':
            return dict(
                constructor=lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=(
                    4, 4, 4)),
                cpp_constructor_args='''torch::nn::FractionalMaxPool3dOptions({2, 2, 2})
                                        .output_size(std::vector<int64_t>({4, 4, 4}))''',
                input_size=(4, 7, 7, 7),
                reference_fn=single_batch_reference_fn,
                test_cuda=False,
                fullname='FractionalMaxPool3d_size_no_batch_dim_no_random_samples')

def single_batch_reference_fn(input, parameters, module):
    """Reference function for modules that also support inputs without a batch dimension.

    The input is unsqueezed into a single-item batch before the module is
    called, and the output is squeezed back to compare with the no-batch input.
    """
    def unsqueeze_inp(inp):
        if isinstance(inp, (list, tuple)):
            return [t.unsqueeze(0) for t in inp]
        return inp.unsqueeze(0)

    single_batch_input = unsqueeze_inp(input)
    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
    with freeze_rng_state():
        return module(*single_batch_input).squeeze(0)
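
# Usage sketch for single_batch_reference_fn (illustrative only; the harness
# supplies `parameters` and `module` itself). For a module that accepts inputs
# both with and without a batch dimension, the unbatched forward should match
# the squeezed single-item-batch forward:
#
#     m = nn.InstanceNorm1d(3, eps=1e-3, momentum=0.3)
#     x = torch.randn(3, 15)                       # no batch dimension
#     ref = single_batch_reference_fn(x, None, m)  # m(x.unsqueeze(0)).squeeze(0)
#     assert torch.allclose(m(x), ref)
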
new_module_tests = [
    poissonnllloss_no_reduce_test(),
    bceloss_no_reduce_test(),
    bceloss_weights_no_reduce_test(),
    bce_with_logistic_legacy_enum_test(),
    bce_with_logistic_no_reduce_test(),
    bceloss_no_reduce_scalar_test(),
    bceloss_weights_no_reduce_scalar_test(),
    bce_with_logistic_no_reduce_scalar_test(),
    kldivloss_with_target_no_reduce_test(),
    kldivloss_no_reduce_test(),
    kldivloss_no_reduce_scalar_test(),
    kldivloss_with_log_target_no_reduce_test(),
    kldivloss_no_reduce_log_target_test(),
    kldivloss_no_reduce_scalar_log_target_test(),
    l1loss_no_reduce_test(),
    l1loss_no_reduce_complex_test(),
    l1loss_no_reduce_scalar_test(),
    mseloss_no_reduce_test(),
    mseloss_no_reduce_scalar_test(),
    nllloss_no_reduce_test(),
    nllloss_no_reduce_ignore_index_test(),
    nllloss_no_reduce_weights_test(),
    nllloss_no_reduce_weights_ignore_index_test(),
    nllloss_no_reduce_weights_ignore_index_neg_test(),
    nllloss2d_no_reduce_test(),
    nllloss2d_no_reduce_weights_test(),
    nllloss2d_no_reduce_ignore_index_test(),
    nlllossNd_no_reduce_test(),
    nlllossNd_no_reduce_weights_test(),
    nlllossNd_no_reduce_ignore_index_test(),
    smoothl1loss_no_reduce_test(),
    smoothl1loss_no_reduce_scalar_test(),
    smoothl1loss_beta_test(),
    smoothl1loss_zero_beta_test(),
    huberloss_delta_test(),
    multilabelmarginloss_0d_no_reduce_test(),
    multilabelmarginloss_1d_no_reduce_test(),
    multilabelmarginloss_index_neg_test(),
    multilabelmarginloss_no_reduce_test(),
    hingeembeddingloss_no_reduce_test(),
    hingeembeddingloss_margin_no_reduce_test(),
    softmarginloss_no_reduce_test(),
    multilabelsoftmarginloss_no_reduce_test(),
    multilabelsoftmarginloss_weights_no_reduce_test(),
    multimarginloss_no_reduce_test(),
    multimarginloss_1d_no_reduce_test(),
    multimarginloss_1d_input_0d_target_no_reduce_test(),
    multimarginloss_p_no_reduce_test(),
    multimarginloss_margin_no_reduce_test(),
    multimarginloss_weights_no_reduce_test(),
    fractional_max_pool2d_test('ratio'),
    fractional_max_pool2d_test('size'),
    fractional_max_pool2d_no_batch_dim_test('ratio', True),
    fractional_max_pool2d_no_batch_dim_test('ratio', False),
    fractional_max_pool2d_no_batch_dim_test('size', True),
    fractional_max_pool2d_no_batch_dim_test('size', False),
    fractional_max_pool2d_test('ratio', return_indices=True),
    fractional_max_pool3d_test('ratio'),
    fractional_max_pool3d_test('size'),
    fractional_max_pool3d_test('asymsize'),
    fractional_max_pool3d_test('ratio', return_indices=True),
    fractional_max_pool3d_no_batch_dim_test('ratio', True),
    fractional_max_pool3d_no_batch_dim_test('ratio', False),
    fractional_max_pool3d_no_batch_dim_test('size', True),
    fractional_max_pool3d_no_batch_dim_test('size', False),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(10,),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(10)',
        input_size=(4, 10),
        cudnn=True,
        check_eval=True,
        desc='affine',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(5,),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(5)',
        input_size=(4, 5, 3),
        cudnn=True,
        check_eval=True,
        desc='3d_input',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(10, 1e-3, None),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(c10::nullopt)',
        input_size=(4, 10),
        cudnn=True,
        check_eval=True,
        desc='affine_simple_average',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(10, 1e-3, 0.3, False),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(0.3).affine(false)',
        input_size=(4, 10),
        cudnn=True,
        check_eval=True,
        desc='not_affine',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(10, 1e-3, 0.3, True, False),
        cpp_constructor_args='''torch::nn::BatchNorm1dOptions(10)
                                .eps(1e-3).momentum(0.3).affine(true).track_running_stats(false)''',
        input_size=(4, 10),
        cudnn=True,
        check_eval=True,
        desc='not_tracking_stats',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(5, 1e-3, 0.3, False),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)',
        input_size=(4, 5, 3),
        cudnn=True,
        check_eval=True,
        desc='3d_input_not_affine',
    ),
    dict(
        module_name='BatchNorm1d',
        constructor_args=(5, 1e-3, 0.3, False),
        cpp_constructor_args='torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)',
        input_size=(0, 5, 9),
        cudnn=True,
        check_eval=True,
        desc='zero_batch',
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::BatchNorm2dOptions(3)',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(3, 1e-3, None),
        cpp_constructor_args='torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(c10::nullopt)',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
        desc='2d_simple_average',
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(3, 1e-3, 0.8),
        cpp_constructor_args='torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8)',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
        desc='momentum',
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(3, 1e-3, 0.8, False),
        cpp_constructor_args='torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8).affine(false)',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
        desc='not_affine',
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(3, 1e-3, 0.8, True, False),
        cpp_constructor_args='''torch::nn::BatchNorm2dOptions(3)
                                .eps(1e-3).momentum(0.8).affine(true).track_running_stats(false)''',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
        desc='not_tracking_stats',
    ),
    dict(
        module_name='BatchNorm2d',
        constructor_args=(5, 1e-3, 0.3, False),
        cpp_constructor_args='torch::nn::BatchNorm2dOptions(5).eps(1e-3).momentum(0.3).affine(false)',
        input_size=(0, 5, 2, 2),
        cudnn=True,
        check_eval=True,
        desc='zero_batch',
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::BatchNorm3dOptions(3)',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(3, 1e-3, None),
        cpp_constructor_args='torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(c10::nullopt)',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        desc='3d_simple_average',
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(3, 1e-3, 0.7),
        cpp_constructor_args='torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7)',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        desc='momentum',
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(3, 1e-3, 0.7, False),
        cpp_constructor_args='torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7).affine(false)',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        desc='not_affine',
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(3, 1e-3, 0.7, True, False),
        cpp_constructor_args='''torch::nn::BatchNorm3dOptions(3)
                                .eps(1e-3).momentum(0.7).affine(true).track_running_stats(false)''',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        desc='not_tracking_stats',
    ),
    dict(
        module_name='BatchNorm3d',
        constructor_args=(5, 1e-3, 0.3, False),
        cpp_constructor_args='torch::nn::BatchNorm3dOptions(5).eps(1e-3).momentum(0.3).affine(false)',
        input_size=(0, 5, 2, 2, 2),
        cudnn=True,
        check_eval=True,
        desc='zero_batch',
    ),
    dict(
        module_name='InstanceNorm1d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm1dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(4, 3, 15),
        cudnn=True,
        check_eval=True,
    ),
    dict(
        module_name='InstanceNorm1d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm1dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(4, 3, 15),
        cudnn=True,
        check_eval=True,
        desc='tracking_stats',
    ),
    dict(
        module_name='InstanceNorm1d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm1dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(3, 15),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='tracking_stats_no_batch_dim',
    ),
    dict(
        module_name='InstanceNorm1d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm1dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(3, 15),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='InstanceNorm2d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm2dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
    ),
    dict(
        module_name='InstanceNorm2d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm2dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(2, 3, 6, 6),
        cudnn=True,
        check_eval=True,
        desc='tracking_stats',
    ),
    dict(
        module_name='InstanceNorm2d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm2dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(3, 6, 6),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='InstanceNorm2d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm2dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(3, 6, 6),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='tracking_stats_no_batch_dim',
    ),
    dict(
        module_name='InstanceNorm3d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm3dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
    ),
    dict(
        module_name='InstanceNorm3d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm3dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(2, 3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        desc='tracking_stats',
    ),
    dict(
        module_name='InstanceNorm3d',
        constructor_args=(3, 1e-3, 0.3),
        cpp_constructor_args='torch::nn::InstanceNorm3dOptions(3).eps(1e-3).momentum(0.3)',
        input_size=(3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='InstanceNorm3d',
        constructor_args=(3, 1e-3, 0.3, False, True),
        cpp_constructor_args='''torch::nn::InstanceNorm3dOptions(3)
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)''',
        input_size=(3, 4, 4, 4),
        cudnn=True,
        check_eval=True,
        ref=single_batch_reference_fn,
        desc='tracking_stats_no_batch_dim',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([5], 1e-3),
        cpp_constructor_args='torch::nn::LayerNormOptions({5}).eps(1e-3)',
        input_size=(4, 5, 5),
        cudnn=True,
        check_eval=True,
        check_half=True,
        desc='1d_elementwise_affine',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([5], 1e-3, False),
        cpp_constructor_args='torch::nn::LayerNormOptions({5}).eps(1e-3).elementwise_affine(false)',
        input_size=(4, 5, 5),
        cudnn=True,
        check_eval=True,
        check_half=True,
        desc='1d_no_elementwise_affine',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([2, 2, 5], 1e-3),
        cpp_constructor_args='torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3)',
        input_size=(4, 2, 2, 5),
        cudnn=True,
        check_eval=True,
        check_half=True,
        desc='3d_elementwise_affine',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([2, 2, 5], 1e-3, False),
        cpp_constructor_args='torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3).elementwise_affine(false)',
        input_size=(4, 2, 2, 5),
        cudnn=True,
        check_eval=True,
        check_half=True,
        desc='3d_no_elementwise_affine',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([56, 56, 56], 1e-5, False),
        cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
        input_size=(4, 56, 56, 56),
        cudnn=True,
        check_eval=True,
        gradcheck_fast_mode=True,
        check_half=True,
        desc='3d_no_affine_large_feature',
    ),
    dict(
        module_name='LayerNorm',
        constructor_args=([5], 1e-3),
        cpp_constructor_args='torch::nn::LayerNormOptions({5}).eps(1e-3)',
        input_size=(0, 5),
        cudnn=True,
        check_eval=True,
        check_half=True,
        desc='1d_empty_elementwise_affine',
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 6, 1e-3),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 6).eps(1e-3)',
        input_size=(4, 6, 5),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='1d_affine',
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 12, 1e-3),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 12).eps(1e-3)',
        input_size=(4, 12),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='1d_affine_GN',
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(1, 6, 1e-3),
        cpp_constructor_args='torch::nn::GroupNormOptions(1, 6).eps(1e-3)',
        input_size=(150, 6),
        cudnn=True,
        check_eval=True,
        desc='1d_affine_large_batch',  # for large batch_size
        check_bfloat16=True,
        test_cpu=False,
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(5, 5, 1e-3, False),
        cpp_constructor_args='torch::nn::GroupNormOptions(5, 5).eps(1e-3).affine(false)',
        input_size=(4, 5, 5),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='1d_no_affine_IN',  # this setting is equivalent to InstanceNorm
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(1, 10, 1e-3, False),
        cpp_constructor_args='torch::nn::GroupNormOptions(1, 10).eps(1e-3).affine(false)',
        input_size=(4, 10),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='1d_no_affine_LN',  # this setting is equivalent to LayerNorm
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 6, 1e-3),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 6).eps(1e-3)',
        input_size=(4, 6, 2, 3),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='2d_affine',
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 6, 1e-3),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 6).eps(1e-3)',
        input_size=(4, 6, 28, 28),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='2d_affine_large_feature',
        test_cpu=False,
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 51, 1e-5, False),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 51).eps(1e-5).affine(false)',
        input_size=(2, 51, 28, 28),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='2d_no_affine_large_feature',
        test_cpu=False,
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(3, 3, 1e-3, False),
        cpp_constructor_args='torch::nn::GroupNormOptions(3, 3).eps(1e-3).affine(false)',
        input_size=(4, 3, 2, 3),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='2d_no_affine_IN',  # this setting is equivalent to InstanceNorm
    ),
    dict(
        module_name='GroupNorm',
        constructor_args=(1, 3, 1e-3, False),
        cpp_constructor_args='torch::nn::GroupNormOptions(1, 3).eps(1e-3).affine(false)',
        input_size=(4, 3, 2, 3),
        cudnn=True,
        check_eval=True,
        check_bfloat16=True,
        desc='2d_no_affine_LN',  # this setting is equivalent to LayerNorm
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 3),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 3, 2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
        input_size=(2, 4, 10),
        cudnn=True,
        desc='stride',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 3, 1, 1),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
        input_size=(2, 4, 10),
        cudnn=True,
        desc='pad1',
        with_tf32=True,
        tf32_precision=0.01,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 5, 1, 2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
        input_size=(2, 4, 10),
        cudnn=True,
        desc='pad2',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 4, 3, 1, 1),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
        input_size=(1, 4, 1),
        cudnn=True,
        desc='pad1size1',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 4, 5, 1, 2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
        input_size=(1, 4, 1),
        cudnn=True,
        desc='pad2size1',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Conv1d',
        constructor_args=(4, 5, 3),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
        input_size=(0, 4, 10),
        cudnn=True,
        desc='zero_batch',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_dilated',
        constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
        input_size=(2, 4, 10),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_groups',
        constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
        input_size=(2, 4, 6),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_pad_valid',
        constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_pad_same',
        constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_pad_same2',
        constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv1d_pad_same_dilated',
        constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
        # kernel size matches the Python constructor above (was mistakenly 3)
        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame).dilation(2)',
        input_size=(2, 4, 10),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
  1901. dict(
  1902. fullname='ConvTranspose1d',
  1903. constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
  1904. cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
  1905. cudnn=True,
  1906. input_size=(1, 3, 7),
  1907. with_tf32=True,
  1908. tf32_precision=0.005,
  1909. ),
  1910. dict(
  1911. module_name='ConvTranspose1d',
  1912. constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
  1913. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
  1914. .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
  1915. input_size=(1, 3, 6),
  1916. cudnn=True,
  1917. desc='no_bias',
  1918. with_tf32=True,
  1919. tf32_precision=0.005,
  1920. ),
  1921. dict(
  1922. module_name='ConvTranspose1d',
  1923. constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
  1924. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
  1925. .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
  1926. input_size=(1, 3, 6),
  1927. cudnn=True,
  1928. desc='dilated',
  1929. with_tf32=True,
  1930. tf32_precision=0.005,
  1931. ),
  1932. dict(
  1933. fullname='ConvTranspose1d_groups',
  1934. constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
  1935. cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
  1936. .stride(3).padding(1).output_padding(1).groups(2)''',
  1937. cudnn=True,
  1938. input_size=(2, 4, 7),
  1939. with_tf32=True,
  1940. tf32_precision=0.005,
  1941. ),
  1942. dict(
  1943. module_name='MaxPool1d',
  1944. constructor_args=(4,),
  1945. cpp_constructor_args='torch::nn::MaxPool1dOptions(4)',
  1946. input_size=(2, 10, 4),
  1947. ),
  1948. dict(
  1949. module_name='MaxPool1d',
  1950. constructor_args=(4, 4),
  1951. cpp_constructor_args='torch::nn::MaxPool1dOptions(4).stride(4)',
  1952. input_size=(2, 10, 4),
  1953. desc='stride',
  1954. ),
  1955. dict(
  1956. module_name='MaxPool1d',
  1957. fullname='MaxPool1d_return_indices',
  1958. constructor=lambda: nn.MaxPool1d(4, return_indices=True),
  1959. input_size=(2, 10, 4),
  1960. test_cpp_api_parity=False,
  1961. ),
  1962. dict(
  1963. module_name='Conv2d',
  1964. constructor_args=(3, 4, (3, 2)),
  1965. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
  1966. input_size=(2, 3, 7, 5),
  1967. cudnn=True,
  1968. check_with_long_tensor=True,
  1969. with_tf32=True,
  1970. tf32_precision=0.005,
  1971. ),
  1972. dict(
  1973. module_name='Conv2d',
  1974. constructor_args=(3, 4, (3, 3), (2, 2)),
  1975. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
  1976. input_size=(2, 3, 6, 6),
  1977. cudnn=True,
  1978. desc='strided',
  1979. check_with_long_tensor=True,
  1980. with_tf32=True,
  1981. tf32_precision=0.005,
  1982. ),
  1983. dict(
  1984. module_name='Conv2d',
  1985. constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
  1986. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
  1987. input_size=(2, 3, 6, 6),
  1988. cudnn=True,
  1989. desc='padding',
  1990. check_with_long_tensor=True,
  1991. with_tf32=True,
  1992. tf32_precision=0.005,
  1993. ),
  1994. dict(
  1995. module_name='Conv2d',
  1996. constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
  1997. cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
  1998. input_size=(2, 3, 8, 8),
  1999. cudnn=True,
  2000. desc='dilated',
  2001. check_with_long_tensor=True,
  2002. with_tf32=True,
  2003. tf32_precision=0.005,
  2004. ),
  2005. dict(
  2006. module_name='Conv2d',
  2007. constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
  2008. cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
  2009. .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
  2010. input_size=(2, 3, 6, 5),
  2011. cudnn=True,
  2012. desc='no_bias',
  2013. check_with_long_tensor=True,
  2014. with_tf32=True,
  2015. ),
  2016. dict(
  2017. module_name='Conv2d',
  2018. constructor_args=(3, 4, (3, 2)),
  2019. cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
  2020. input_size=(0, 3, 7, 5),
  2021. cudnn=True,
  2022. desc='zero_batch',
  2023. check_with_long_tensor=True,
  2024. with_tf32=True,
  2025. ),
  2026. dict(
  2027. fullname='Conv2d_groups',
  2028. constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
  2029. cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
  2030. input_size=(2, 4, 6, 5),
  2031. cudnn=True,
  2032. check_with_long_tensor=True,
  2033. with_tf32=True,
  2034. tf32_precision=0.005,
  2035. ),
  2036. dict(
  2037. fullname='Conv2d_groups_thnn',
  2038. constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
  2039. cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
  2040. input_size=(2, 4, 6, 5),
  2041. check_with_long_tensor=True,
  2042. with_tf32=True,
  2043. tf32_precision=0.005,
  2044. ),
  2045. dict(
  2046. fullname='Conv2d_pad_valid',
  2047. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
  2048. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
  2049. input_size=(2, 2, 6, 5),
  2050. cudnn=True,
  2051. with_tf32=True,
  2052. tf32_precision=0.005,
  2053. ),
  2054. dict(
  2055. fullname='Conv2d_pad_same',
  2056. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
  2057. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
  2058. input_size=(2, 2, 6, 5),
  2059. cudnn=True,
  2060. with_tf32=True,
  2061. tf32_precision=0.01,
  2062. ),
  2063. dict(
  2064. fullname='Conv2d_pad_same_dilated',
  2065. constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
  2066. cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
  2067. input_size=(2, 2, 6, 5),
  2068. cudnn=True,
  2069. with_tf32=True,
  2070. tf32_precision=0.005,
  2071. ),
    dict(
        module_name='ConvTranspose2d',
        constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
            .stride({3, 2}).padding(1).output_padding({1, 1})''',
        cudnn=True,
        input_size=(1, 3, 7, 6),
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.01,
    ),
    dict(
        module_name='ConvTranspose2d',
        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
            .stride({2, 3})
            .padding(1)
            .output_padding({1, 1})
            .groups(1)
            .bias(false)
            .dilation({2, 2})''',
        input_size=(1, 3, 6, 7),
        cudnn=True,
        desc='dilated',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='ConvTranspose2d',
        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
            .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
        input_size=(1, 3, 6, 7),
        cudnn=True,
        desc='no_bias',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='ConvTranspose2d_groups',
        constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
        cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
        input_size=(1, 2, 4, 5),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.01,
    ),
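    # Depthwise convolutions: groups equals the number of input channels, so
    # each input channel is convolved with its own filter(s). For example
    # (illustrative): nn.Conv2d(4, 4, (3, 3), groups=4) on a (2, 4, 6, 6)
    # input yields a (2, 4, 4, 4) output; the "with_multiplier" variant maps
    # each of the 4 input channels to 2 output channels.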
    dict(
        fullname='Conv2d_depthwise',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_with_multiplier',
        constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_strided',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_padded',
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
        input_size=(2, 4, 6, 6),
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv2d_depthwise_dilated',
        constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
        input_size=(2, 4, 5, 5),
        with_tf32=True,
        tf32_precision=0.005,
    ),
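    # The pooling entries accept both batched (N, C, H, W) and unbatched
    # (C, H, W) inputs, hence the '3d_input'/'4d_input' pair below. The
    # return_indices cases skip the C++ parity check, presumably because the
    # Python module then returns a (values, indices) tuple.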
    dict(
        module_name='MaxPool2d',
        constructor_args=((3, 3), (2, 2), (1, 1)),
        cpp_constructor_args='torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})',
        input_size=(3, 7, 7),
        desc='3d_input',
    ),
    dict(
        module_name='MaxPool2d',
        constructor_args=((3, 3), (2, 2), (1, 1)),
        cpp_constructor_args='torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})',
        input_size=(1, 3, 7, 7),
        check_with_channels_last=True,
        desc='4d_input',
    ),
    dict(
        module_name='MaxPool2d',
        fullname='MaxPool2d_return_indices',
        constructor=lambda: nn.MaxPool2d((3, 3), (2, 2), (1, 1), return_indices=True),
        input_size=(1, 3, 7, 7),
        check_with_channels_last=True,
        test_cpp_api_parity=False,
    ),
    dict(
        module_name='AvgPool1d',
        constructor_args=(2,),
        cpp_constructor_args='torch::nn::AvgPool1dOptions(2)',
        input_size=(2, 3, 6),
    ),
    dict(
        module_name='AvgPool1d',
        constructor_args=((2,), (2,)),
        cpp_constructor_args='torch::nn::AvgPool1dOptions(2).stride(2)',
        input_size=(2, 3, 6),
        desc='stride',
    ),
    dict(
        module_name='AvgPool1d',
        constructor_args=(2, 2, 1),
        cpp_constructor_args='torch::nn::AvgPool1dOptions(2).stride(2).padding(1)',
        input_size=(2, 3, 6),
        desc='stride_pad',
    ),
    dict(
        module_name='AvgPool1d',
        constructor_args=(2,),
        cpp_constructor_args='torch::nn::AvgPool1dOptions(2)',
        input_size=(3, 6),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AvgPool2d',
        constructor_args=((2, 2),),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2})',
        input_size=(2, 3, 6, 6),
    ),
    dict(
        module_name='AvgPool2d',
        constructor_args=((2, 2),),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2})',
        input_size=(3, 6, 6),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AvgPool2d',
        constructor_args=((2, 2), (2, 2)),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2})',
        input_size=(2, 3, 6, 6),
        desc='stride',
    ),
    dict(
        module_name='AvgPool2d',
        constructor_args=((2, 2), (2, 2), (1, 1)),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1})',
        input_size=(2, 3, 6, 6),
        desc='stride_pad',
    ),
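    # divisor_override replaces the averaging denominator: with
    # divisor_override=1 the "average" over each window degenerates into a
    # plain sum, which is presumably why these variants can also run the
    # exact long-tensor checks.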
    dict(
        fullname='AvgPool2d_divisor',
        constructor=lambda: nn.AvgPool2d((2, 2), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2}).divisor_override(1)',
        input_size=(2, 3, 6, 6),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool2d_divisor_stride',
        constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).divisor_override(1)',
        input_size=(2, 3, 6, 6),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool2d_divisor_stride_pad',
        constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), (1, 1), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1}).divisor_override(1)',
        input_size=(2, 3, 6, 6),
        check_with_long_tensor=True,
    ),
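    # LPPool computes (sum over the window of x ** p) ** (1 / p). Note that
    # the fractional-norm (p=1.5) variants draw inputs from torch.rand rather
    # than a signed distribution, presumably to keep the fractional power
    # well defined.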
    dict(
        module_name='LPPool2d',
        constructor_args=(2, 2, 2),
        cpp_constructor_args='torch::nn::LPPool2dOptions(2, 2).stride(2)',
        input_size=(1, 3, 7, 7),
    ),
    dict(
        module_name='LPPool2d',
        constructor_args=(1.5, 2),
        cpp_constructor_args='torch::nn::LPPool2dOptions(1.5, 2)',
        input_fn=lambda: torch.rand(1, 3, 7, 7),
        desc='norm',
    ),
    dict(
        module_name='LPPool1d',
        constructor_args=(1.5, 2),
        cpp_constructor_args='torch::nn::LPPool1dOptions(1.5, 2)',
        input_fn=lambda: torch.rand(1, 3, 7),
        desc='norm',
    ),
    dict(
        module_name='LPPool1d',
        constructor_args=(2, 2, 3),
        cpp_constructor_args='torch::nn::LPPool1dOptions(2, 2).stride(3)',
        input_size=(1, 3, 7),
    ),
    dict(
        module_name='LPPool1d',
        constructor_args=(2, 2, 3),
        cpp_constructor_args='torch::nn::LPPool1dOptions(2, 2).stride(3)',
        input_size=(3, 7),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='LocalResponseNorm',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::LocalResponseNormOptions(3)',
        input_size=(1, 5, 7),
        desc='1d',
    ),
    dict(
        module_name='LocalResponseNorm',
        constructor_args=(2,),
        cpp_constructor_args='torch::nn::LocalResponseNormOptions(2)',
        input_size=(1, 5, 7, 7),
        desc='2d_uneven_pad',
    ),
    dict(
        module_name='LocalResponseNorm',
        constructor_args=(1, 1., 0.5, 2.),
        cpp_constructor_args='torch::nn::LocalResponseNormOptions(1).alpha(1.).beta(0.5).k(2.)',
        input_size=(1, 5, 7, 7, 7),
        desc='3d_custom_params',
    ),
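    # The padding layers below take pad widths innermost-dimension first:
    # (left, right) for 1d, (left, right, top, bottom) for 2d, and
    # (left, right, top, bottom, front, back) for 3d. Each family is also
    # tested without a batch dimension and with complex128 inputs
    # (with skip_half=True for the complex cases).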
    dict(
        module_name='ReflectionPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
        input_size=(2, 3, 8),
    ),
    dict(
        module_name='ReflectionPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
        input_size=(3, 8),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReflectionPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
        input_fn=lambda: torch.rand(2, 3, 8, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ReflectionPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReflectionPad2dOptions({1, 2, 3, 4})',
        input_size=(2, 3, 8, 8),
    ),
    dict(
        module_name='ReflectionPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReflectionPad2dOptions({1, 2, 3, 4})',
        input_size=(3, 8, 8),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReflectionPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReflectionPad2dOptions({1, 2, 3, 4})',
        input_fn=lambda: torch.rand(2, 3, 8, 8, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ReflectionPad3d',
        constructor_args=((1, 2, 0, 2, 1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad3dOptions({1, 2, 0, 2, 1, 2})',
        input_size=(2, 3, 8, 8, 8),
    ),
    dict(
        module_name='ReflectionPad3d',
        constructor_args=((1, 2, 0, 2, 1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad3dOptions({1, 2, 0, 2, 1, 2})',
        input_size=(3, 8, 8, 8),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReflectionPad3d',
        constructor_args=((1, 2, 0, 2, 1, 2),),
        cpp_constructor_args='torch::nn::ReflectionPad3dOptions({1, 2, 0, 2, 1, 2})',
        input_fn=lambda: torch.rand(2, 3, 8, 8, 8, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ReplicationPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
        input_size=(2, 3, 4),
    ),
    dict(
        module_name='ReplicationPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
        input_size=(3, 4),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReplicationPad1d',
        constructor_args=((1, 2),),
        cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
        input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ReplicationPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReplicationPad2dOptions({1, 2, 3, 4})',
        input_size=(2, 3, 4, 4),
    ),
    dict(
        module_name='ReplicationPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReplicationPad2dOptions({1, 2, 3, 4})',
        input_size=(3, 4, 4),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReplicationPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ReplicationPad2dOptions({1, 2, 3, 4})',
        input_fn=lambda: torch.rand(2, 3, 4, 4, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ZeroPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ZeroPad2dOptions({1, 2, 3, 4})',
        input_size=(2, 3, 4, 4),
    ),
    dict(
        module_name='ZeroPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ZeroPad2dOptions({1, 2, 3, 4})',
        input_size=(3, 4, 4),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ZeroPad2d',
        constructor_args=((1, 2, 3, 4),),
        cpp_constructor_args='torch::nn::ZeroPad2dOptions({1, 2, 3, 4})',
        input_fn=lambda: torch.rand(2, 3, 4, 4, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ZeroPad2d',
        constructor_args=((-1, -1, -1, -2),),
        cpp_constructor_args='torch::nn::ZeroPad2dOptions({-1, -1, -1, -2})',
        input_size=(2, 3, 4, 4),
        desc='negative_dims',
    ),
    dict(
        module_name='ConstantPad1d',
        constructor_args=((1, 2), 2.),
        cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
        input_size=(2, 3, 4),
    ),
    dict(
        module_name='ConstantPad1d',
        constructor_args=((1, 2), 2.),
        cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
        input_size=(3, 4),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ConstantPad1d',
        constructor_args=((1, 2), 2.),
        cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
        input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ConstantPad2d',
        constructor_args=((1, 2, 3, 4), 2.),
        cpp_constructor_args='torch::nn::ConstantPad2dOptions({1, 2, 3, 4}, 2.)',
        input_size=(2, 3, 4, 4),
    ),
    dict(
        module_name='ConstantPad2d',
        constructor_args=((1, 2, 3, 4), 2.),
        cpp_constructor_args='torch::nn::ConstantPad2dOptions({1, 2, 3, 4}, 2.)',
        input_size=(3, 4, 4),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ConstantPad2d',
        constructor_args=((1, 2, 3, 4), 2.),
        cpp_constructor_args='torch::nn::ConstantPad2dOptions({1, 2, 3, 4}, 2.)',
        input_fn=lambda: torch.rand(2, 3, 4, 4, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
    dict(
        module_name='ConstantPad3d',
        constructor_args=((1, 2, 3, 4, 1, 0), 2.),
        cpp_constructor_args='torch::nn::ConstantPad3dOptions({1, 2, 3, 4, 1, 0}, 2.)',
        input_size=(2, 3, 4, 4, 5),
    ),
    dict(
        module_name='ConstantPad3d',
        constructor_args=((1, 2, 3, 4, 1, 0), 2.),
        cpp_constructor_args='torch::nn::ConstantPad3dOptions({1, 2, 3, 4, 1, 0}, 2.)',
        input_size=(3, 4, 4, 5),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ConstantPad3d',
        constructor_args=((1, 2, 3, 4, 1, 0), 2.),
        cpp_constructor_args='torch::nn::ConstantPad3dOptions({1, 2, 3, 4, 1, 0}, 2.)',
        input_fn=lambda: torch.rand(2, 3, 4, 4, 5, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
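    # Conv3d entries: note the looser tf32_precision (0.05 vs. the 0.005 used
    # for most 2d cases), presumably because 3d convolutions accumulate over
    # larger windows and so drift further under TF32.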
    dict(
        module_name='Conv3d',
        constructor_args=(2, 3, (2, 3, 2)),
        cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
        input_size=(1, 2, 4, 5, 4),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
        cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
            .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
        input_size=(1, 2, 3, 4, 5),
        cudnn=True,
        desc='no_bias',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
        cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {1, 1, 1})
            .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
        input_size=(1, 2, 3, 4, 5),
        cudnn=True,
        desc='1x1x1_no_bias',
        check_with_long_tensor=False,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(3, 4, 2, 2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
        input_size=(2, 3, 5, 5, 5),
        cudnn=True,
        desc='stride',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(3, 4, 2, 2, 1),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
        input_size=(2, 3, 5, 5, 5),
        cudnn=True,
        desc='stride_padding',
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Conv3d',
        constructor_args=(3, 4, (2, 3, 4)),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
        input_size=(0, 3, 3, 4, 5),
        cudnn=True,
        check_with_long_tensor=True,
        desc='zero_batch',
        with_tf32=True,
    ),
    dict(
        fullname='Conv3d_groups',
        constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
        input_size=(1, 2, 4, 5, 4),
        cudnn=True,
        check_with_long_tensor=True,
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        fullname='Conv3d_dilated',
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
        input_size=(2, 3, 5, 5, 5),
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        fullname='Conv3d_dilated_strided',
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
        input_size=(2, 3, 5, 5, 5),
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        fullname='Conv3d_pad_valid',
        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
        input_size=(2, 3, 6, 5, 4),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        fullname='Conv3d_pad_same',
        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
        input_size=(2, 3, 6, 5, 4),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        fullname='Conv3d_pad_same_dilated',
        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
        input_size=(2, 3, 6, 5, 4),
        cudnn=True,
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='ConvTranspose3d',
        constructor_args=(2, 3, (2, 3, 2)),
        cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
        cudnn=True,
        input_size=(1, 2, 4, 5, 4),
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='ConvTranspose3d',
        constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
        cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
            .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
        cudnn=True,
        input_size=(1, 2, 4, 5, 4),
        desc='dilated',
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='MaxPool3d',
        constructor_args=((2, 2, 2),),
        cpp_constructor_args='torch::nn::MaxPool3dOptions({2, 2, 2})',
        input_size=(2, 3, 5, 5, 5),
    ),
    dict(
        module_name='MaxPool3d',
        constructor_args=(2, (2, 2, 2)),
        cpp_constructor_args='torch::nn::MaxPool3dOptions(2).stride({2, 2, 2})',
        input_size=(2, 3, 5, 5, 5),
        desc='stride',
    ),
    dict(
        module_name='MaxPool3d',
        constructor_args=(2, 2, (1, 1, 1)),
        cpp_constructor_args='torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1})',
        input_size=(2, 3, 5, 5, 5),
        desc='stride_padding',
    ),
    dict(
        module_name='MaxPool3d',
        fullname='MaxPool3d_return_indices',
        constructor=lambda: nn.MaxPool3d(2, 2, (1, 1, 1), return_indices=True),
        input_size=(2, 3, 5, 5, 5),
        test_cpp_api_parity=False,
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=((2, 2, 2),),
        cpp_constructor_args='torch::nn::AvgPool3dOptions({2, 2, 2})',
        input_size=(2, 3, 4, 4, 4),
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=((2, 2, 2),),
        cpp_constructor_args='torch::nn::AvgPool3dOptions({2, 2, 2})',
        input_size=(3, 4, 4, 4),
        desc='no_batch_dim',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=(2, (2, 2, 2)),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride({2, 2, 2})',
        input_size=(2, 3, 5, 5, 5),
        desc='stride',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=(2, 2, (1, 1, 1)),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})',
        input_size=(2, 3, 5, 5, 5),
        desc='stride_pad',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=(4, 2, (1, 2, 1)),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1})',
        input_size=(2, 3, 5, 5, 5),
        desc='stride_pad_gpu_fixedkw_output',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=((2, 4, 8), 1, (1, 1, 2)),
        cpp_constructor_args='torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2})',
        input_size=(2, 3, 2, 4, 8),
        desc='stride_pad_gpu_general_output',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=(3, 1, 0),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(3).stride(1).padding(0)',
        input_size=(2, 3, 4, 4, 4),
        desc='stride1_pad0_gpu_input',
    ),
    dict(
        module_name='AvgPool3d',
        constructor_args=(2, 2, (1, 1, 1)),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})',
        input_size=(2, 3, 4, 4, 4),
        desc='stride_pad_gpu_input_nooverlap',
    ),
    dict(
        fullname='AvgPool3d_divisor',
        constructor=lambda: nn.AvgPool3d((2, 2, 2), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions({2, 2, 2}).divisor_override(1)',
        input_size=(2, 3, 4, 4, 4),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride',
        constructor=lambda: nn.AvgPool3d(2, (2, 2, 2), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride({2, 2, 2}).divisor_override(1)',
        input_size=(2, 3, 5, 5, 5),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride_pad',
        constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)',
        input_size=(2, 3, 5, 5, 5),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride_pad_gpu_fixedkw_output',
        constructor=lambda: nn.AvgPool3d(4, 2, (1, 2, 1), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1}).divisor_override(1)',
        input_size=(2, 3, 5, 5, 5),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride_pad_gpu_general_output',
        constructor=lambda: nn.AvgPool3d((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2}).divisor_override(1)',
        input_size=(2, 3, 2, 4, 8),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride1_pad0_gpu_input',
        constructor=lambda: nn.AvgPool3d(3, 1, 0, divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(3).stride(1).padding(0).divisor_override(1)',
        input_size=(2, 3, 4, 4, 4),
        check_with_long_tensor=True,
    ),
    dict(
        fullname='AvgPool3d_divisor_stride_pad_gpu_input_nooverlap',
        constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
        cpp_constructor_args='torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)',
        input_size=(2, 3, 4, 4, 4),
        check_with_long_tensor=True,
    ),
    dict(
        module_name='ReplicationPad3d',
        constructor_args=((1, 2, 3, 3, 2, 1),),
        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
        input_size=(2, 3, 2, 2, 2),
    ),
    dict(
        module_name='ReplicationPad3d',
        constructor_args=((1, 2, 3, 3, 2, 1),),
        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
        input_size=(3, 2, 2, 2),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='ReplicationPad3d',
        constructor_args=((1, 2, 3, 3, 2, 1),),
        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
        input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
        skip_half=True,
        desc='complex',
    ),
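    # Embedding / EmbeddingBag take integer index inputs, so these entries
    # build long tensors via input_fn instead of using input_size. For the
    # padding_idx variants, torch.stack([torch.randperm(3), torch.randperm(3)])
    # guarantees that index 1 (the padding index) appears in every bag;
    # entries at padding_idx do not contribute to the bag's reduction.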
    dict(
        module_name='Embedding',
        constructor_args=(4, 3),
        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
        check_gradgrad=False,
    ),
    dict(
        module_name='Embedding',
        constructor_args=(4, 3),
        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
        check_gradgrad=False,
        desc='discontiguous',
    ),
    dict(
        module_name='EmbeddingBag',
        constructor_args=(4, 3),
        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
        check_gradgrad=False,
        desc='mean',
    ),
    dict(
        module_name='EmbeddingBag',
        constructor_args=(4, 3),
        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
        check_gradgrad=False,
        desc='discontiguous',
    ),
    dict(
        module_name='EmbeddingBag',
        constructor_args=(4, 3, None, 2., False, 'sum'),
        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
            .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
        check_gradgrad=False,
        desc='sum',
    ),
    dict(
        module_name='EmbeddingBag',
        constructor_args=(4, 3, None, 2., False, 'max'),
        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
            .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
        check_gradgrad=False,
        desc='max',
    ),
    dict(
        fullname='EmbeddingBag_mean_padding_idx',
        constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
        check_gradgrad=False,
    ),
    dict(
        fullname='EmbeddingBag_sum_padding_idx',
        constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
            .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
        check_gradgrad=False,
    ),
    dict(
        fullname='EmbeddingBag_max_padding_idx',
        constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
            .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
        check_gradgrad=False,
    ),
    dict(
        fullname='EmbeddingBag_sparse',
        constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)',
        input_fn=lambda: torch.randperm(2).repeat(1, 2),
        check_gradgrad=False,
        has_sparse_gradients=True,
    ),
    dict(
        fullname='Embedding_sparse',
        constructor=lambda: nn.Embedding(4, 3, sparse=True),
        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)',
        input_fn=lambda: torch.randperm(2).repeat(1, 2),
        check_gradgrad=False,
        has_sparse_gradients=True,
    ),
    dict(
        module_name='PixelShuffle',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
        input_size=(1, 9, 4, 4),
    ),
    dict(
        module_name='PixelUnshuffle',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
        input_size=(1, 1, 12, 12),
    ),
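    # F.interpolate tests: each mode (nearest / linear / bilinear / bicubic /
    # trilinear) is crossed with size-vs-scale_factor, scalar and tuple
    # arguments, align_corners, and zero-batch inputs. As a quick,
    # illustrative shape check:
    #   F.interpolate(torch.randn(1, 2, 4), scale_factor=4., mode='nearest')
    # returns a (1, 2, 16) tensor.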
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
        input_size=(1, 2, 4),
        fullname='interpolate_nearest_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
        input_size=(0, 2, 4),
        fullname='interpolate_nearest_1d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(12,), scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)''',
        input_size=(1, 2, 3),
        fullname='interpolate_nearest_tuple_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''',
        input_size=(1, 2, 4),
        fullname='interpolate_nearest_scale_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kLinear)
            .align_corners(false)''',
        input_size=(1, 2, 4),
        fullname='interpolate_linear_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4,), scale_factor=None, mode='linear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4}))
            .scale_factor(c10::nullopt)
            .mode(torch::kLinear)
            .align_corners(false)''',
        input_size=(1, 2, 3),
        fullname='interpolate_linear_tuple_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4.}))
            .mode(torch::kLinear)
            .align_corners(false)''',
        input_size=(1, 2, 4),
        fullname='interpolate_linear_scale_1d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kLinear)
            .align_corners(false)''',
        input_size=(0, 2, 4),
        fullname='interpolate_linear_1d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kLinear)
            .align_corners(true)''',
        input_size=(1, 2, 4),
        fullname='interpolate_linear_1d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4.}))
            .mode(torch::kLinear)
            .align_corners(true)''',
        input_size=(1, 2, 4),
        fullname='interpolate_linear_scale_1d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({2, 2}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(1, 128, 1, 1),
        fullname='interpolate_nearest_2d_launch_configs',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_nearest_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 16}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(1, 2, 3, 4),
        fullname='interpolate_nearest_tuple_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4., 4.}))
            .mode(torch::kNearest)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_nearest_scale_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(0, 2, 4, 4),
        fullname='interpolate_nearest_2d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(0, 2, 4, 4),
        fullname='interpolate_bilinear_2d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(1, 2, 2, 3),
        fullname='interpolate_bilinear_tuple_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4., 4.}))
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_scale_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 2.}))
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_scale_tuple_shared_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 1.}))
            .mode(torch::kBilinear)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_scale_tuple_skewed_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBilinear)
            .align_corners(true)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_tuple_2d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bilinear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 1.}))
            .mode(torch::kBilinear)
            .align_corners(true)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(0, 2, 4, 4),
        fullname='interpolate_bicubic_2d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(1, 2, 2, 3),
        fullname='interpolate_bicubic_tuple_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4., 4.}))
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_scale_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.), mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 2.}))
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_scale_tuple_shared_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bicubic', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 1.}))
            .mode(torch::kBicubic)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_scale_tuple_skewed_2d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kBicubic)
            .align_corners(true)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_tuple_2d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.), mode='bicubic', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({2., 1.}))
            .mode(torch::kBicubic)
            .align_corners(true)''',
        input_size=(1, 2, 4, 4),
        fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(1, 2, 4, 4, 4),
        fullname='interpolate_nearest_3d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(0, 2, 4, 4, 4),
        fullname='interpolate_nearest_3d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 16, 16}))
            .scale_factor(c10::nullopt)
            .mode(torch::kNearest)''',
        input_size=(1, 2, 3, 4, 4),
        fullname='interpolate_nearest_tuple_3d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({4., 4., 4.}))
            .mode(torch::kNearest)''',
        input_size=(1, 2, 4, 4, 4),
        fullname='interpolate_nearest_scale_3d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kTrilinear)
            .align_corners(false)''',
        input_size=(1, 2, 4, 4, 4),
        fullname='interpolate_trilinear_3d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({12, 12, 12}))
            .scale_factor(c10::nullopt)
            .mode(torch::kTrilinear)
            .align_corners(false)''',
        input_size=(0, 2, 4, 4, 4),
        fullname='interpolate_trilinear_3d_zero_dim',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, mode='trilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kTrilinear)
            .align_corners(false)''',
        input_size=(1, 2, 2, 3, 3),
        fullname='interpolate_trilinear_tuple_3d',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({3., 3., 3.}))
            .mode(torch::kTrilinear)
            .align_corners(false)''',
        input_size=(1, 2, 3, 4, 5),
        fullname='interpolate_trilinear_scale_3d',
        # See https://github.com/pytorch/pytorch/issues/5006
        precision=3e-4,
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None, mode='trilinear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(std::vector<int64_t>({4, 6, 6}))
            .scale_factor(c10::nullopt)
            .mode(torch::kTrilinear)
            .align_corners(true)''',
        input_size=(1, 2, 2, 3, 3),
        fullname='interpolate_trilinear_tuple_3d_align_corners',
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
        cpp_options_args='''F::InterpolateFuncOptions()
            .size(c10::nullopt)
            .scale_factor(std::vector<double>({3., 3., 3.}))
            .mode(torch::kTrilinear)
            .align_corners(true)''',
        input_size=(1, 2, 3, 4, 4),
        fullname='interpolate_trilinear_scale_3d_align_corners',
        # See https://github.com/pytorch/pytorch/issues/5006
        precision=3e-4,
        pickle=False,
    ),
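    # Adaptive pooling: the constructor takes the desired *output* size, and a
    # None entry in a size tuple keeps that dimension's input size. The max
    # variants draw inputs from _rand_tensor_non_equal, presumably so the
    # argmax (and hence the backward pass) is unambiguous.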
    dict(
        module_name='AdaptiveMaxPool1d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool1dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5),
    ),
    dict(
        module_name='AdaptiveMaxPool1d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool1dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(3, 5),
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveMaxPool2d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool2dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc='single',
    ),
    dict(
        module_name='AdaptiveMaxPool2d',
        constructor_args=((3, 4),),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool2dOptions({3, 4})',
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc='tuple',
    ),
    dict(
        module_name='AdaptiveMaxPool2d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool2dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(3, 5, 6),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveMaxPool2d',
        constructor_args=((3, None),),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool2dOptions({3, c10::nullopt})',
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc='tuple_none',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc='single',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(3, 5, 6, 7),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=((3, 4, 5),),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})',
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc='tuple',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=((3, None, 5),),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions({3, c10::nullopt, 5})',
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc='tuple_none',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions(3)',
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 12, 9, 3),
        desc='single_nonatomic',
    ),
    dict(
        module_name='AdaptiveMaxPool3d',
        constructor_args=((3, 4, 5),),
        cpp_constructor_args='torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})',
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 6, 4, 10),
        desc='tuple_nonatomic',
    ),
    dict(
        module_name='AdaptiveAvgPool1d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool1dOptions(3)',
        input_fn=lambda: torch.rand(1, 3, 5),
    ),
    dict(
        module_name='AdaptiveAvgPool1d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool1dOptions(3)',
        input_fn=lambda: torch.rand(3, 5),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveAvgPool1d',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool1dOptions(1)',
        input_fn=lambda: torch.rand(1, 3, 5),
        desc='one_output',
    ),
    dict(
        module_name='AdaptiveAvgPool2d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool2dOptions(3)',
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc='single',
    ),
    dict(
        module_name='AdaptiveAvgPool2d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool2dOptions(3)',
        input_fn=lambda: torch.rand(3, 5, 6),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveAvgPool2d',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool2dOptions(1)',
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc='single_1x1output',
    ),
    dict(
        module_name='AdaptiveAvgPool2d',
        constructor_args=((3, 4),),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool2dOptions({3, 4})',
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc='tuple',
    ),
    dict(
        module_name='AdaptiveAvgPool2d',
        constructor_args=((3, None),),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool2dOptions({3, c10::nullopt})',
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc='tuple_none',
    ),
    dict(
        module_name='AdaptiveAvgPool3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool3dOptions(3)',
        input_fn=lambda: torch.rand(2, 3, 5, 2, 7),
        desc='single',
    ),
    dict(
        module_name='AdaptiveAvgPool3d',
        constructor_args=(3,),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool3dOptions(3)',
        input_fn=lambda: torch.rand(3, 5, 2, 7),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='AdaptiveAvgPool3d',
        constructor_args=((3, 4, 5),),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool3dOptions({3, 4, 5})',
        input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
        desc='tuple',
    ),
    dict(
        module_name='AdaptiveAvgPool3d',
        constructor_args=((None, 4, 5),),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool3dOptions({c10::nullopt, 4, 5})',
        input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
        desc='tuple_none',
    ),
    dict(
        module_name='AdaptiveAvgPool3d',
        constructor_args=((3, 2, 2),),
        cpp_constructor_args='torch::nn::AdaptiveAvgPool3dOptions({3, 2, 2})',
        input_fn=lambda: torch.rand(1, 1, 3, 2, 6),
        desc='last_dim',
    ),
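    # The activation entries pair each module with a closed-form reference_fn,
    # e.g. CELU with alpha=2 is checked against
    #   torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1))
    # and each activation is also exercised on a 0-dim ('scalar') input.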
    dict(
        module_name='SELU',
        input_size=(3, 2, 5),
        check_inplace=True,
    ),
    dict(
        module_name='SELU',
        input_size=(),
        check_inplace=True,
        desc='scalar',
    ),
    dict(
        module_name='CELU',
        input_size=(3, 2, 5),
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::CELUOptions().alpha(2.)',
        check_inplace=True,
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
    ),
    dict(
        module_name='CELU',
        input_size=(),
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::CELUOptions().alpha(2.)',
        check_inplace=True,
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
        desc='scalar',
    ),
    dict(
        module_name='GLU',
        input_size=(5, 6),
    ),
    dict(
        module_name='GLU',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::GLUOptions(1)',
        input_size=(5, 6, 7),
        desc='dim',
    ),
    dict(
        module_name='GELU',
        constructor_args=('none',),
        cpp_constructor_args='torch::nn::GELUOptions().approximate("none")',
        input_size=(),
        desc='scalar',
        reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
    ),
    dict(
        module_name='GELU',
        constructor_args=('none',),
        cpp_constructor_args='torch::nn::GELUOptions().approximate("none")',
        input_size=(3, 2, 5),
        reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
    ),
    dict(
        module_name='SiLU',
        input_size=(),
        desc='scalar',
        reference_fn=lambda x, *_: x * torch.sigmoid(x),
    ),
    dict(
        module_name='SiLU',
        input_size=(5, 6, 7),
        reference_fn=lambda x, *_: x * torch.sigmoid(x),
    ),
    dict(
        module_name='Mish',
        input_size=(),
        desc='scalar',
        reference_fn=lambda x, *_: x * torch.tanh(F.softplus(x)),
    ),
    dict(
        module_name='Mish',
        input_size=(5, 6, 7),
        reference_fn=lambda x, *_: x * torch.tanh(F.softplus(x)),
    ),
  3613. dict(
  3614. constructor=wrap_functional(F.softmax, dim=-1),
  3615. cpp_options_args='F::SoftmaxFuncOptions(-1)',
  3616. input_size=(2, 128), # trigger the last-dim algo in CUDA
  3617. fullname='softmax_lastdim',
  3618. pickle=False,
  3619. ),
  3620. dict(
  3621. constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
  3622. cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
  3623. input_size=(2, 128),
  3624. fullname='softmax_lastdim_dtype',
  3625. pickle=False,
  3626. test_cuda=False
  3627. ),
  3628. dict(
  3629. constructor=wrap_functional(F.softmax, dim=1),
  3630. cpp_options_args='F::SoftmaxFuncOptions(1)',
  3631. input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
  3632. fullname='softmax_spatial_special',
  3633. pickle=False,
  3634. ),
  3635. dict(
  3636. constructor=wrap_functional(F.softmax, dim=1),
  3637. cpp_options_args='F::SoftmaxFuncOptions(1)',
  3638. input_size=(2, 2, 4, 4), # regular spatial algorithm
  3639. fullname='softmax_spatial',
  3640. pickle=False,
  3641. ),
  3642. dict(
  3643. constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
  3644. cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
  3645. input_size=(2, 2, 4, 4), # regular spatial algorithm
  3646. fullname='softmax_spatial_dtype',
  3647. pickle=False,
  3648. test_cuda=False
  3649. ),
  3650. dict(
  3651. constructor=wrap_functional(F.softmax, dim=0),
  3652. cpp_options_args='F::SoftmaxFuncOptions(0)',
  3653. input_size=(2, 3, 4, 5),
  3654. fullname='softmax_functional_dim0',
  3655. test_cuda=False,
  3656. pickle=False,
  3657. ),
  3658. dict(
  3659. constructor=wrap_functional(F.softmax, dim=3),
  3660. cpp_options_args='F::SoftmaxFuncOptions(3)',
  3661. input_size=(2, 3, 4, 5),
  3662. fullname='softmax_functional_dim3',
  3663. test_cuda=False,
  3664. pickle=False,
  3665. ),
  3666. dict(
  3667. constructor=wrap_functional(F.softmax, dim=-1),
  3668. cpp_options_args='F::SoftmaxFuncOptions(-1)',
  3669. input_size=(),
  3670. fullname='softmax_functional_scalar',
  3671. test_cuda=False,
  3672. pickle=False,
  3673. ),
  3674. dict(
  3675. constructor=wrap_functional(F.log_softmax, dim=-1),
  3676. cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
  3677. input_size=(2, 128), # trigger the last-dim algo in CUDA
  3678. fullname='log_softmax_lastdim',
  3679. pickle=False,
  3680. ),
  3681. dict(
  3682. constructor=wrap_functional(F.log_softmax, dim=1),
  3683. cpp_options_args='F::LogSoftmaxFuncOptions(1)',
  3684. input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
  3685. fullname='log_softmax_spatial_special',
  3686. pickle=False,
  3687. ),
  3688. dict(
  3689. constructor=wrap_functional(F.log_softmax, dim=1),
  3690. cpp_options_args='F::LogSoftmaxFuncOptions(1)',
  3691. input_size=(2, 2, 4, 4), # regular spatial algorithm
  3692. fullname='log_softmax_spatial',
  3693. pickle=False,
  3694. ),
  3695. dict(
  3696. constructor=wrap_functional(F.log_softmax, dim=0),
  3697. cpp_options_args='F::LogSoftmaxFuncOptions(0)',
  3698. input_size=(2, 3, 4, 5),
  3699. fullname='log_softmax_dim0',
  3700. pickle=False,
  3701. ),
  3702. dict(
  3703. constructor=wrap_functional(F.log_softmax, dim=3),
  3704. cpp_options_args='F::LogSoftmaxFuncOptions(3)',
  3705. input_size=(2, 3, 4, 5),
  3706. fullname='log_softmax_dim3',
  3707. pickle=False,
  3708. ),
  3709. dict(
  3710. constructor=wrap_functional(F.log_softmax, dim=0),
  3711. cpp_options_args='F::LogSoftmaxFuncOptions(0)',
  3712. input_size=(),
  3713. fullname='log_softmax_scalar',
  3714. pickle=False,
  3715. ),
  3716. dict(
  3717. module_name='Softmax2d',
  3718. input_size=(3, 4, 5),
  3719. reference_fn=single_batch_reference_fn,
  3720. desc='no_batch_dim',
  3721. ),
  3722. dict(
  3723. module_name='Softmax',
  3724. constructor_args=(-1,),
  3725. cpp_constructor_args='torch::nn::SoftmaxOptions(-1)',
  3726. input_size=(4, 5),
  3727. reference_fn=single_batch_reference_fn,
  3728. desc='no_batch_dim',
  3729. ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(-1,),
        cpp_constructor_args='torch::nn::LogSoftmaxOptions(-1)',
        input_size=(4, 5),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        fullname='Unfold',
        constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
        cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
        input_size=(2, 4, 3, 3),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname='Fold',
        constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
        cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
        input_size=(2, 16, 4),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname='Fold_no_batch_dim_input',
        constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
        cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
        input_size=(16, 4),
        check_gradgrad=False,
        ref=single_batch_reference_fn,
        test_cuda=True,
    ),
    dict(
        fullname='Unfold_int_input',
        constructor=lambda: nn.Unfold(2, 1, 0, 1),
        cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
        input_size=(2, 4, 3, 3),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname='Fold_int_input',
        constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
        cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
        input_size=(2, 16, 4),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname='Fold_no_batch_dim_int_input',
        constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
        cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
        input_size=(16, 4),
        ref=single_batch_reference_fn,
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        module_name='Threshold',
        constructor_args=(2., 1.),
        cpp_constructor_args='torch::nn::ThresholdOptions(2., 1.)',
        input_size=(),
        check_inplace=True,
        desc='threshold_value_scalar'
    ),
    dict(
        module_name='ReLU',
        input_size=(),
        check_inplace=True,
        desc='scalar'
    ),
    dict(
        module_name='ReLU6',
        input_size=(),
        check_inplace=True,
        desc='scalar'
    ),
    dict(
        module_name='RReLU',
        constructor_args=(0.1, 0.9),
        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
        input_size=(),
        desc='with_up_down_scalar',
        test_cuda=False,
    ),
    dict(
        module_name='Hardtanh',
        input_size=(),
        reference_fn=lambda i, *_: i.clamp(-1, 1),
        desc='scalar'
    ),
    dict(
        module_name='Sigmoid',
        input_size=(),
        desc='scalar',
    ),
    dict(
        module_name='Tanh',
        input_size=(),
        desc='scalar',
    ),
    dict(
        module_name='Softmax',
        constructor_args=(0,),
        cpp_constructor_args='torch::nn::SoftmaxOptions(0)',
        input_size=(),
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(0, True)),
        desc='scalar',
    ),
    dict(
        module_name='LogSoftmax',
        constructor_args=(0,),
        cpp_constructor_args='torch::nn::LogSoftmaxOptions(0)',
        input_size=(),
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
        desc='multiparam_scalar',
    ),
    dict(
        module_name='ELU',
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::ELUOptions().alpha(2.)',
        input_size=(),
        desc='scalar',
    ),
    dict(
        module_name='Hardshrink',
        constructor_args=(2.,),
        cpp_constructor_args='torch::nn::HardshrinkOptions(2.)',
        input_size=(),
        desc='scalar',
    ),
    dict(
        module_name='LeakyReLU',
        constructor_args=(0.5,),
        cpp_constructor_args='torch::nn::LeakyReLUOptions().negative_slope(0.5)',
        input_size=(),
        check_inplace=True,
        desc='with_negval_scalar'
    ),
    dict(
        module_name='LogSigmoid',
        input_size=(),
        reference_fn=lambda i, *_: i.sigmoid().log(),
        desc='scalar'
    ),
    dict(
        module_name='Softplus',
        constructor_args=(2, -100),
        cpp_constructor_args='torch::nn::SoftplusOptions().beta(2).threshold(-100)',
        input_size=(),
        reference_fn=(
            lambda i, *_: ((i * 2) > -100).type_as(i) * i
            + ((i * 2) <= -100).type_as(i) * 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i))
        ),
        desc='beta_threshold_scalar',
    ),
    dict(
        module_name='Softshrink',
        constructor_args=(1,),
        cpp_constructor_args='torch::nn::SoftshrinkOptions(1)',
        input_size=(),
        desc='lambda_scalar',
    ),
    dict(
        module_name='PReLU',
        input_size=(),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
        desc='scalar',
    ),
    dict(
        module_name='Softsign',
        input_size=(),
        reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
        desc='scalar',
    ),
    dict(
        module_name='Softmin',
        constructor_args=(0,),
        cpp_constructor_args='torch::nn::SoftminOptions(0)',
        input_size=(),
        desc='scalar',
    ),
    dict(
        module_name='Softmin',
        constructor_args=(-1,),
        cpp_constructor_args='torch::nn::SoftminOptions(-1)',
        input_size=(3, 4, 10),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='Tanhshrink',
        input_size=(),
        desc='scalar',
    ),
    dict(
        fullname='Padding12_1dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'),
        cpp_options_args='F::PadFuncOptions({1, 2}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding31_1dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'),
        cpp_options_args='F::PadFuncOptions({3, 1}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding33_1dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'),
        cpp_options_args='F::PadFuncOptions({3, 3}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (3, 3)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding1221_2dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'),
        cpp_options_args='F::PadFuncOptions({1, 2, 2, 1}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
        reference_fn=lambda i, *_: padding2d_circular(i, (1, 2, 2, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding2322_2dcircular',
        constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'),
        cpp_options_args='F::PadFuncOptions({2, 3, 2, 2}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
        reference_fn=lambda i, *_: padding2d_circular(i, (2, 3, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding3331_2dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'),
        cpp_options_args='F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
        reference_fn=lambda i, *_: padding2d_circular(i, (3, 3, 3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding122112_3dcircular',
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'),
        cpp_options_args='F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, *_: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding322112_3dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'),
        cpp_options_args='F::PadFuncOptions({3, 2, 2, 1, 1, 2}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, *_: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname='Padding332122_3dcircular',
        constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'),
        cpp_options_args='F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular)',
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
        reference_fn=lambda i, *_: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        module_name='PairwiseDistance',
        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
    ),
    dict(
        module_name='PairwiseDistance',
        input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
        desc='broadcast_lhs'
    ),
    dict(
        module_name='PairwiseDistance',
        input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
        desc='broadcast_rhs'
    ),
    dict(
        module_name='PairwiseDistance',
        constructor_args=(1.5, 1e-05, True),
        cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
        desc='with_non_default_args',
    ),
    dict(
        module_name='PairwiseDistance',
        input_fn=lambda: (torch.randn(8), torch.randn(8)),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='TransformerEncoderLayer',
        constructor_args=(4, 2, 16, 0.0),
        cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
            .dim_feedforward(16)
            .dropout(0.0)''',
        input_size=(2, 3, 4),
        desc='relu_activation',
        with_tf32=True,
        tf32_precision=0.1,
        # TODO(#50743): figure out the error
        # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
        # at non-singleton dimension 2
        check_batched_grad=False,
        check_gradgrad=False,
    ),
    dict(
        module_name='TransformerEncoderLayer',
        constructor_args=(4, 2, 8, 0.0, F.gelu),
        cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
            .dim_feedforward(8)
            .dropout(0.0)
            .activation(torch::kGELU)''',
        input_size=(2, 3, 4),
        check_gradgrad=False,
        desc='gelu_activation',
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='TransformerDecoderLayer',
        constructor_args=(4, 2, 8, 0.0),
        cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
            .dim_feedforward(8)
            .dropout(0.0)''',
        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
        check_gradgrad=False,
        desc='relu_activation',
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='TransformerDecoderLayer',
        constructor_args=(4, 2, 8, 0.0, F.gelu),
        cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
            .dim_feedforward(8)
            .dropout(0.0)
            .activation(torch::kGELU)''',
        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
        check_gradgrad=False,
        desc='gelu_activation',
        with_tf32=True,
        tf32_precision=0.05,
    ),
    dict(
        module_name='Transformer',
        constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
        cpp_constructor_args='''torch::nn::TransformerOptions()
            .d_model(4)
            .nhead(2)
            .num_encoder_layers(2)
            .num_decoder_layers(2)
            .dim_feedforward(8)
            .dropout(0.0)
            .activation(torch::kReLU)''',
        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
        check_gradgrad=False,
        desc='multilayer_coder',
        with_tf32=True,
        tf32_precision=0.02,
    ),
    dict(
        module_name='Linear',
        constructor_args=(3, 5),
        cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
        input_fn=lambda: torch.rand(3),
        reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
        desc='no_batch_dim',
        with_tf32=True,
        tf32_precision=0.005,
    ),
    dict(
        module_name='Flatten',
        cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
        constructor_args=(-3, -1),
        input_size=(3, 4, 5),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
    dict(
        module_name='Unflatten',
        cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
        constructor_args=(-2, torch.Size([2, 2])),
        input_size=(3, 4, 5),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
    ),
]

# add conv padding mode tests:
for padding_mode, cpp_padding_mode in zip(
        ['reflect', 'circular', 'replicate', 'zeros'],
        ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
    # conv signature:
    # in_channels, out_channels, kernel_size, stride=1,
    # padding=0, dilation=1, groups=1,
    # bias=True, padding_mode='zeros'
    for d in (1, 2, 3):
        if d == 3 and padding_mode == 'reflect':
            # FIXME: remove after implementing reflection pad 3d
            # https://github.com/pytorch/pytorch/issues/27655
            continue
        padding = tuple(range(1, d + 1))
        cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
        input_size = (2, 2) + (4,) * d
        output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
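        # Sanity check of the simplification above: with input size 4, kernel 3 and
        # stride 2, (4 + 2 * p - 3) // 2 + 1 = (2 * p + 1) // 2 + 1 = p + 1
        # for any integer p >= 1.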
        new_module_tests.append(
            dict(
                module_name='Conv{}d'.format(d),
                constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
                cpp_constructor_args='''torch::nn::Conv{}dOptions(2, 3, 3)
                    .stride(2)
                    .padding({})
                    .dilation(1)
                    .groups(1)
                    .bias(true)
                    .padding_mode({})'''.format(d, cpp_padding, cpp_padding_mode),
                input_size=input_size,
                output_size=output_size,
                cudnn=True,
                desc='{}_stride2_pad2'.format(padding_mode),
                with_tf32=True,
                tf32_precision=0.05
            ),
        )


# Check that non-linear activations work with no batch dimensions
non_linear_activations_no_batch = [
    'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
    'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
    'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
    'Tanhshrink', 'Threshold'
]
non_linear_activations_extra_info: Dict[str, dict] = {
    'CELU': {'constructor_args': (2.,)},
    'Threshold': {'constructor_args': (2., 1.)},
    'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False},
    'Hardswish': {'check_gradgrad': False, 'check_jit': False},
    # For RReLU, tests comparing CPU and GPU results fail because the RNG
    # differs between CPU and GPU
    'RReLU': {'test_cuda': False},
}
for non_linear_activation in non_linear_activations_no_batch:
    activation_test_info = dict(
        module_name=non_linear_activation,
        input_size=(4,),
        reference_fn=single_batch_reference_fn,
        desc='no_batch_dim',
        test_cpp_api_parity=False,
    )
    extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
    activation_test_info.update(extra_info)
    new_module_tests.append(activation_test_info)
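

# The reference implementations below re-express each criterion in plain
# Python/torch ops so module outputs can be checked against them. A minimal
# sketch of the intended pairing (the names here are illustrative, not part
# of this file's harness):
#
#   module = nn.KLDivLoss(reduction='batchmean')
#   expected = kldivloss_reference(inp, tgt, reduction='batchmean')
#   torch.testing.assert_close(module(inp, tgt), expected)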
def kldivloss_reference(input, target, reduction='mean'):
    result = target * (target.log() - input)
    if reduction == 'mean':
        return result.mean()
    elif reduction == 'sum':
        return result.sum()
    elif reduction == 'batchmean' and result.dim() != 0:
        return result.sum() / result.size(0)
    return result
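

# Same pointwise formula as above, but with the target supplied in log-space:
# exp(target) * (target - input).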
def kldivloss_log_target_reference(input, target, reduction='mean'):
    result = torch.exp(target) * (target - input)
    if reduction == 'mean':
        return result.mean()
    elif reduction == 'sum':
        return result.sum()
    elif reduction == 'batchmean' and result.dim() != 0:
        return result.sum() / result.size(0)
    return result


def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
                        reduction='mean'):
    assert input.dim() >= 3
    N = input.size(0)
    C = input.size(1)
    out_size = (N,) + input.size()[2:]
    output = torch.zeros(out_size).type_as(input)

    if weight is None:
        weight = torch.ones(C).type_as(input)

    total_weight = 0
    for tup in product(*[range(size) for size in out_size]):
        t_nx = target[tup]
        norm = 0. if ignore_index == t_nx else weight[t_nx].item()
        input_index = list(tup)
        input_index.insert(1, t_nx)
        output[tup] = -input[tuple(input_index)] * norm
        total_weight += norm

    if reduction == 'mean':
        return output.sum() / total_weight
    elif reduction == 'sum':
        return output.sum()
    return output
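

# Cross-entropy with probabilistic (soft) targets: log_softmax over the class
# dim, then a weighted negative dot product with the target distribution.
# Label smoothing mixes the target with a uniform distribution over the C classes.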
def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
                                             label_smoothing=0.0):
    assert input.dim() >= 2

    input = torch.log_softmax(input, 1)
    C = input.size(1)
    if weight is None:
        weight = torch.ones(C).type_as(input)
    weight = weight.view(1, C, *(1 for _ in input.shape[2:]))

    if label_smoothing > 0.0:
        assert label_smoothing <= 1.0
        target = (target * (1 - label_smoothing) + label_smoothing / C)

    output = -(input * target * weight).sum(dim=1)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
                                                reduction='mean', label_smoothing=0.0):
    log_softmax_input = torch.log_softmax(input, 1)
    nllloss = F.nll_loss(
        log_softmax_input,
        target,
        weight,
        ignore_index=ignore_index,
        reduction=reduction)

    if label_smoothing == 0.0:
        return nllloss

    assert 0.0 < label_smoothing <= 1.0

    input = torch.log_softmax(input, 1)
    C = input.size(1)
    if weight is not None:
        input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))

    smooth_loss = -torch.sum(input, 1)

    ignore_mask = target == ignore_index
    smooth_loss.masked_fill_(ignore_mask, 0.0)

    if reduction == 'mean':
        if weight is not None:
            # TODO: This code path can be removed if #61309 is resolved
            # loss is normalized by the weights to be consistent with nll_loss_nd
            ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
        else:
            ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
    elif reduction == 'sum':
        ret = torch.sum(smooth_loss)
    else:
        ret = smooth_loss
    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
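

# Note on the function above: the smoothed loss decomposes as
# (1 - label_smoothing) * nll + (label_smoothing / C) * smooth_term, where
# smooth_term aggregates -log_softmax over all classes (the uniform-target part).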
def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
                                 label_smoothing=0.0):
    if input.shape == target.shape:
        return cross_entropy_loss_prob_target_reference(
            input,
            target,
            weight=weight,
            reduction=reduction,
            label_smoothing=label_smoothing)
    else:
        return cross_entropy_loss_indices_target_reference(
            input, target, weight=weight, reduction=reduction,
            ignore_index=ignore_index, label_smoothing=label_smoothing
        )


def nllloss_reference(input, target, weight=None, ignore_index=-100,
                      reduction='mean'):

    def nll_loss_helper(input, target, weight, ignore_index):
        if target == ignore_index:
            return (0, 0)
        norm = 1 if weight is None else weight[target]
        result = -input[target] * norm
        return (result, norm)

    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
                          for i, t in zip(input, target)]
    losses, weights = zip(*losses_and_weights)
    losses_tensor = input.new_tensor(losses)
    if reduction == 'mean':
        return sum(losses_tensor) / sum(weights)
    elif reduction == 'sum':
        return sum(losses_tensor)
    else:
        return losses_tensor
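

# Smooth L1 on d = |input - target|: 0.5 * d**2 / beta for d < beta,
# d - 0.5 * beta otherwise; with beta == 0 it degenerates to plain L1.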
def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
    abs_diff = (input - target).abs()
    ge_beta_mask = (abs_diff >= beta).type_as(abs_diff)
    lt_beta_mask = (abs_diff < beta).type_as(abs_diff)
    # when beta <= 0 we should just use l1_loss
    if beta == 0:
        output = abs_diff
    else:
        output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output
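

# Huber on d = |input - target|: 0.5 * d**2 for d < delta,
# delta * (d - 0.5 * delta) otherwise; algebraically this equals
# delta * smoothl1loss(beta=delta) before reduction.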
def huberloss_reference(input, target, reduction='mean', delta=1.0):
    abs_diff = (input - target).abs()
    ge_delta_mask = (abs_diff >= delta)
    lt_delta_mask = (abs_diff < delta)
    output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def _multilabelmarginloss_reference(input, target):
    targets = []
    for target_index in target:
        if target_index < 0:
            break
        targets.append(target_index)

    sum = 0
    for target_index in targets:
        for i in range(0, len(input)):
            if i not in targets:
                sum += max(0, 1 - input[target_index] + input[i])

    return sum


def multilabelmarginloss_reference(input, target, reduction='mean'):
    # make everything 2-dimensional
    input_dim = input.dim()
    if input.dim() < 2:
        assert target.dim() < 2
        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
        target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0)

    n = input.size(0)
    dim = input.size(1)
    output = input.new(n).zero_()
    for i in range(0, n):
        output[i] = _multilabelmarginloss_reference(input[i], target[i])

    if reduction == 'mean':
        return output.mean() / dim
    elif reduction == 'sum':
        return output.sum() / dim
    elif input_dim < 2:
        # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
        # back to correct dimensionality
        return output.squeeze() / dim
    else:
        return output / dim
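

# Hinge embedding: x itself when target == 1, max(0, margin - x) when target == -1.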
def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
    margin_clamp = (margin - input).clamp(min=0).type_as(input)
    output = torch.where(target == 1, input, margin_clamp)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def softmarginloss_reference(input, target, reduction='mean'):
    output = (1 + (-input * target).exp()).log()
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def _multimarginloss_reference(input, target_idx, p, margin, weight):
    if weight is None:
        weight = input.new(len(input)).fill_(1)

    output = 0
    for i in range(0, len(input)):
        if i != target_idx:
            output += max(0, weight[target_idx] * (margin - input[target_idx] + input[i]) ** p)
    return output


def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
    if input.dim() < 2:
        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)

    target_dim = target.dim()
    if target.dim() == 0:
        target = target.unsqueeze(0)

    n = input.size(0)
    dim = input.size(1)
    output = input.new(n)
    for x in range(0, n):
        output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)

    if reduction == 'mean':
        return output.mean() / dim
    elif reduction == 'sum':
        return output.sum() / dim
    elif target_dim == 0:
        return output.squeeze(0) / dim
    return output / dim


def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    def _cos(a, b):
        cos = a.new(a.size(0))
        for i in range(0, a.size(0)):
            cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
        return cos

    output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
                                reduction='mean'):
    d_p = torch.pairwise_distance(anchor, positive, p, eps)
    d_n = torch.pairwise_distance(anchor, negative, p, eps)
    if swap:
        d_s = torch.pairwise_distance(positive, negative, p, eps)
        d_n = torch.min(d_n, d_s)

    output = torch.clamp(margin + d_p - d_n, min=0.0)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    output = (-target * (input1 - input2) + margin).clamp(min=0)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output


# This directly follows Graves et al.'s paper; in contrast to the production
# implementation, it does not use log-space.
def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
    targets = targets.long()
    cum_target_lengths = target_lengths.cumsum(0)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
        probs = log_probs[:input_length, i].exp()
        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        alpha[0] = probs[0, blank]
        alpha[1] = probs[0, targets_prime[1]]
        mask_third = (targets_prime[:-2] != targets_prime[2:])
        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += alpha[:-1]
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next
        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == 'mean':
        return (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
    elif reduction == 'sum':
        return output.sum()
    output = output.to(dt)
    return output
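

# Circular padding helpers: each padded dimension is wrapped around by
# concatenating slices taken from the opposite edge. The pad tuple follows
# F.pad's convention, with the first pair applying to the last dimension.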
def padding1d_circular(input, pad):
    r"""input:
            [[[0., 1., 2.],
              [3., 4., 5.]]]
        pad: (1, 2)
        output:
            [[[2., 0., 1., 2., 0., 1.],
              [5., 3., 4., 5., 3., 4.]]]
    """
    return torch.cat([input[:, :, -pad[0]:], input,
                      input[:, :, 0:pad[1]]], dim=2)
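

# e.g. padding1d_circular(torch.arange(6.).reshape(1, 2, 3), (1, 2)) reproduces
# the docstring output above.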
def padding2d_circular(input, pad):
    r"""input:
            [[[[0., 1., 2.],
               [3., 4., 5.]]]]
        pad: (1, 2, 2, 1)
        output:
            [[[[2., 0., 1., 2., 0., 1.],
               [5., 3., 4., 5., 3., 4.],
               [2., 0., 1., 2., 0., 1.],
               [5., 3., 4., 5., 3., 4.],
               [2., 0., 1., 2., 0., 1.]]]]
    """
    input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2)
    return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3)


def padding3d_circular(input, pad):
    r"""input:
            [[[[[ 0.,  1.,  2.],
                [ 3.,  4.,  5.]],
               [[ 6.,  7.,  8.],
                [ 9., 10., 11.]]]]]
        pad: (1, 2, 2, 1, 1, 2)
        output:
            [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.]],
               [[ 2.,  0.,  1.,  2.,  0.,  1.],
                [ 5.,  3.,  4.,  5.,  3.,  4.],
                [ 2.,  0.,  1.,  2.,  0.,  1.],
                [ 5.,  3.,  4.,  5.,  3.,  4.],
                [ 2.,  0.,  1.,  2.,  0.,  1.]],
               [[ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.]],
               [[ 2.,  0.,  1.,  2.,  0.,  1.],
                [ 5.,  3.,  4.,  5.,  3.,  4.],
                [ 2.,  0.,  1.,  2.,  0.,  1.],
                [ 5.,  3.,  4.,  5.,  3.,  4.],
                [ 2.,  0.,  1.,  2.,  0.,  1.]],
               [[ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.],
                [11.,  9., 10., 11.,  9., 10.],
                [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
    """
    input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2)
    input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3)
    return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)


loss_reference_fns: Dict[str, Callable] = {
    'KLDivLoss': kldivloss_reference,
    'KLDivLoss_log_target': kldivloss_log_target_reference,
    'NLLLoss': nllloss_reference,
    'NLLLossNd': nlllossNd_reference,
    'SmoothL1Loss': smoothl1loss_reference,
    'HuberLoss': huberloss_reference,
    'MultiLabelMarginLoss': multilabelmarginloss_reference,
    'HingeEmbeddingLoss': hingeembeddingloss_reference,
    'SoftMarginLoss': softmarginloss_reference,
    'MultiMarginLoss': multimarginloss_reference,
    'CosineEmbeddingLoss': cosineembeddingloss_reference,
    'TripletMarginLoss': tripletmarginloss_reference,
    'MarginRankingLoss': marginrankingloss_reference,
    'CTCLoss': ctcloss_reference,
    'CrossEntropyLoss': cross_entropy_loss_reference
}
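

# criterion_tests entries follow the same schema as new_module_tests: each dict
# names a criterion, describes how to build its inputs/targets, and optionally
# supplies a reference_fn to check the module's output against.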
criterion_tests = [
    dict(
        module_name='L1Loss',
        input_size=(2, 3, 4),
        target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
        reference_fn=lambda i, t, _: 1. / i.numel() *
            sum((a - b).abs().sum() for a, b in zip(i, t)),
        check_complex=True,
    ),
    dict(
        module_name='NLLLoss',
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args=(None, None, 2),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight({}).ignore_index(2)',
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
        desc='ignore_index',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10))',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m)),
        desc='weights',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), None, 2),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
        desc='weights_ignore_index',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), None, -1),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)',
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.empty(15).uniform_().mul(10 + 1).floor().long() - 1,
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
        desc='weights_ignore_index_neg',
        check_bfloat16=True,
    ),
    dict(
        module_name='KLDivLoss',
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10),
        reference_fn=lambda i, t, m:
            kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='KLDivLoss',
        constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
        cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10).log(),
        reference_fn=lambda i, t, m:
            kldivloss_log_target_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
        desc='log_target',
    ),
    dict(
        module_name='MSELoss',
        input_size=(2, 3, 4, 5),
        target_fn=lambda: torch.randn((2, 3, 4, 5), requires_grad=True),
        reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
            (i.numel() if get_reduction(m) == 'mean' else 1)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='BCELoss',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
            (i.numel() if get_reduction(m) else 1),
        check_bfloat16=True,
    ),
    dict(
        module_name='BCELoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args='torch::nn::BCELossOptions().weight(torch::rand(10))',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
            (i.numel() if get_reduction(m) else 1),
        desc='weights',
        check_bfloat16=True,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(15, 10),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name='CrossEntropyLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))',
        input_size=(15, 10),
        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
        desc='weights',
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m:
            hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        constructor_args=(0.5,),
        cpp_constructor_args='torch::nn::HingeEmbeddingLossOptions().margin(0.5)',
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m:
            hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
        desc='margin',
        check_sum_reduction=True,
    ),
    dict(
        module_name='MultiLabelMarginLoss',
        input_size=(10,),
        target_fn=lambda: torch.rand(10).mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
        desc='1d',
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=True,
    ),
    dict(
        module_name='MultiLabelMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=True,
    ),
    dict(
        module_name='MultiLabelSoftMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        input_size=(10,),
        target_fn=lambda: torch.rand(1).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, reduction=get_reduction(m)),
        desc='1d',
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(2,),
        cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(2)',
        input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)),
        desc='p',
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(1, 0.5),
        cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(0.5)',
        legacy_constructor_args=(1, None, 0.5),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
        desc='margin',
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(1, 1., torch.rand(10).double()),
        cpp_constructor_args='torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))',
        legacy_constructor_args=(1, torch.rand(10).double()),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
        desc='weights',
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='SmoothL1Loss',
        input_size=(5, 10),
        target_fn=lambda: torch.randn((5, 10), requires_grad=True),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m, b=1.0:
            smoothl1loss_reference(i, t, reduction=get_reduction(m), beta=b),
    ),
    dict(
        module_name='HuberLoss',
        input_size=(5, 10),
        target_fn=lambda: torch.randn((5, 10), requires_grad=True),
        check_sum_reduction=True,
        check_half=True,
        check_bfloat16=True,
        reference_fn=lambda i, t, m:
            huberloss_reference(i, t, reduction=get_reduction(m)),
    ),
    dict(
        module_name='SoftMarginLoss',
        input_size=(5, 5),
        target_fn=lambda: torch.randn(5, 5).sign(),
        reference_fn=lambda i, t, m:
            softmarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m:
            cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        constructor_args=(0.7,),
        cpp_constructor_args='torch::nn::CosineEmbeddingLossOptions().margin(0.7)',
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m:
            cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
        desc='margin',
        check_sum_reduction=True,
    ),
    dict(
        module_name='MarginRankingLoss',
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m:
            marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='MarginRankingLoss',
        constructor_args=(0.5,),
        cpp_constructor_args='torch::nn::MarginRankingLossOptions().margin(0.5)',
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m:
            marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
        desc='margin',
        check_sum_reduction=True,
    ),
    dict(
        module_name='BCEWithLogitsLoss',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
    ),
    dict(
        module_name='BCEWithLogitsLoss',
        constructor_args=(torch.rand(10),),
        cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        desc='weights',
    ),
    dict(
        module_name='BCEWithLogitsLoss',
        constructor_args=(torch.rand(()),),
        cpp_constructor_args='torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))',
        input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(()).gt(0).double(),
        desc='scalar_weights'
    ),
    dict(
        module_name='NLLLoss',
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='2d',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(3),),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight(torch::rand(3))',
        input_size=(2, 3, 5, 5),
        target=torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['NLLLossNd'](i, t, weight=get_weight(m)),
        desc='2d_weights',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args=(None, None, 1),
        cpp_constructor_args='torch::nn::NLLLossOptions().weight({}).ignore_index(1)',
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['NLLLossNd'](i, t, ignore_index=1),
        desc='2d_ignore_index',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        input_size=(2, 3, 5, 5, 2, 2),
        target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='higher_dim',
        check_bfloat16=True,
    ),
    dict(
        module_name='NLLLoss',
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['NLLLossNd'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='dim_is_3',
        check_bfloat16=True,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='2d',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        constructor_args_fn=lambda: (torch.rand(3),),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(3))',
        input_size=(2, 3, 5, 5),
        target=torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, weight=get_weight(m)),
        desc='2d_weights',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        constructor_args=(None, None, 1),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight({}).ignore_index(1)',
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, ignore_index=1),
        desc='2d_ignore_index',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(2, 3, 5, 5, 2, 2),
        target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='higher_dim',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='dim_is_3',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(5, 3),
        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='2d_prob_target',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(5, 3, 4),
        target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='3d_prob_target',
        check_bfloat16=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(5, 3, 4, 2),
        target_fn=lambda: torch.rand(5, 3, 4, 2).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        desc='4d_prob_target',
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_2d_prob_target_smoothing_sum_reduction',
        constructor=lambda *args, **kwargs: nn.CrossEntropyLoss(reduction='sum',
            label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
        input_size=(5, 3),
        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_2d_prob_target_smoothing',
        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
        input_size=(5, 3),
        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_2d_prob_target_smoothing_weight',
        constructor_args_fn=lambda: (torch.rand(3).abs(),),
        constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(3).abs())',
        input_size=(5, 3),
        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_3d_prob_target_smoothing_sum_reduction',
        constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum',
            label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
        input_size=(5, 3, 4),
        target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_3d_prob_target_smoothing',
        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
        input_size=(5, 3, 4),
        target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_3d_indices_target_smoothing',
        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
        check_bfloat16=False,
    ),
    dict(
        fullname='CrossEntropyLoss_3d_indices_target_smoothing_ignore_index',
        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=1),
        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(1)',
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m:
  5073. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1),
  5074. check_bfloat16=False,
  5075. ),
  5076. dict(
  5077. fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction',
  5078. constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15),
  5079. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
  5080. input_size=(2, 3, 5),
  5081. target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
  5082. reference_fn=lambda i, t, m:
  5083. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
  5084. check_bfloat16=False,
  5085. ),
  5086. dict(
  5087. fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction_ignore_index',
  5088. constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15,
  5089. ignore_index=1),
  5090. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum).ignore_index(1)',
  5091. input_size=(2, 3, 5),
  5092. target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
  5093. reference_fn=lambda i, t, m:
  5094. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1),
  5095. check_bfloat16=False,
  5096. ),
  5097. dict(
  5098. fullname='CrossEntropyLoss_2d_indices_target_smoothing',
  5099. constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
  5100. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
  5101. input_size=(15, 10),
  5102. target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
  5103. reference_fn=lambda i, t, m:
  5104. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
  5105. check_bfloat16=False,
  5106. ),
  5107. dict(
  5108. fullname='CrossEntropyLoss_2d_indices_target_smoothing_sum_reduction',
  5109. constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15),
  5110. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
  5111. input_size=(15, 10),
  5112. target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
  5113. reference_fn=lambda i, t, m:
  5114. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
  5115. check_bfloat16=False,
  5116. ),
  5117. dict(
  5118. fullname='CrossEntropyLoss_2d_indices_target_smoothing_ignore_index',
  5119. constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=3),
  5120. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(3)',
  5121. input_size=(15, 10),
  5122. target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
  5123. reference_fn=lambda i, t, m:
  5124. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=3),
  5125. check_bfloat16=False,
  5126. ),
  5127. dict(
  5128. fullname='CrossEntropyLoss_2d_indices_target_smoothing_weight',
  5129. constructor_args_fn=lambda: (torch.rand(10).abs(),),
  5130. constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15),
  5131. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(10).abs())',
  5132. input_size=(15, 10),
  5133. target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
  5134. reference_fn=lambda i, t, m:
  5135. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15),
  5136. check_bfloat16=False,
  5137. ),
  5138. dict(
  5139. module_name='CrossEntropyLoss',
  5140. constructor_args_fn=lambda: (torch.rand(3),),
  5141. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(3))',
  5142. input_size=(5, 3),
  5143. target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
  5144. reference_fn=lambda i, t, m:
  5145. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m)),
  5146. check_sum_reduction=True,
  5147. desc='2d_prob_target_weights',
  5148. check_bfloat16=False,
  5149. ),
  5150. dict(
  5151. module_name='CrossEntropyLoss',
  5152. constructor_args_fn=lambda: (torch.rand(3),),
  5153. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(3))',
  5154. input_size=(5, 3, 4),
  5155. target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
  5156. reference_fn=lambda i, t, m:
  5157. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m)),
  5158. check_sum_reduction=True,
  5159. desc='3d_prob_target_weights',
  5160. check_bfloat16=False,
  5161. ),
  5162. dict(
  5163. module_name='CrossEntropyLoss',
  5164. constructor_args_fn=lambda: (torch.rand(3),),
  5165. cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(3))',
  5166. input_size=(5, 3, 4, 2),
  5167. target_fn=lambda: torch.rand(5, 3, 4, 2).softmax(dim=1),
  5168. reference_fn=lambda i, t, m:
  5169. loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m)),
  5170. check_sum_reduction=True,
  5171. desc='4d_prob_target_weights',
  5172. check_bfloat16=False,
  5173. ),
  5174. dict(
  5175. module_name='PoissonNLLLoss', # Default is log_input=True, full=False
  5176. input_size=(2, 3, 4, 5),
  5177. target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
  5178. reference_fn=lambda i, t, _: (i.exp() - t.mul(i)).mean(),
  5179. desc='no_full_loss',
  5180. ),
  5181. dict(
  5182. module_name='PoissonNLLLoss',
  5183. constructor_args=(False, False), # log_input=False, full=False
  5184. cpp_constructor_args='torch::nn::PoissonNLLLossOptions().log_input(false).full(false)',
  5185. input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
  5186. target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
  5187. reference_fn=lambda i, t, _: (i - t.mul((i + 1e-8).log())).mean(),
  5188. desc='no_full_loss_no_log_input',
  5189. ),
  5190. dict(
  5191. module_name='PoissonNLLLoss',
  5192. constructor_args=(True, True), # log_input=True, full=True
  5193. cpp_constructor_args='torch::nn::PoissonNLLLossOptions().log_input(true).full(true)',
  5194. input_size=(2, 3, 4, 5),
  5195. target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
  5196. reference_fn=lambda i, t, _:
  5197. (i.exp() - t.mul(i) + (t.mul(t.log()) - t + 0.5 * (2. * pi * t).log()).masked_fill(t <= 1, 0)).mean(),
  5198. desc='full_loss',
  5199. ),
  5200. dict(
  5201. module_name='PoissonNLLLoss',
  5202. constructor_args=(False, True), # log_input=False, full=True
  5203. cpp_constructor_args='torch::nn::PoissonNLLLossOptions().log_input(false).full(true)',
  5204. input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
  5205. target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
  5206. reference_fn=lambda i, t, _: (
  5207. i - t.mul((i + 1e-8).log()) + (t.mul(t.log()) - t + 0.5 * (2. * pi * t).log()).masked_fill(t <= 1, 0)
  5208. ).mean(),
  5209. desc='full_loss_no_log_input',
  5210. ),
  5211. dict(
  5212. module_name='L1Loss',
  5213. input_size=(),
  5214. target_fn=lambda: torch.randn((), requires_grad=True),
  5215. reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
  5216. desc='scalar',
  5217. check_complex=True,
  5218. ),
  5219. dict(
  5220. module_name='KLDivLoss',
  5221. input_fn=lambda: torch.rand(()).log(),
  5222. target_fn=lambda: torch.rand(()),
  5223. reference_fn=lambda i, t, m:
  5224. kldivloss_reference(i, t, get_reduction(m)),
  5225. check_sum_reduction=True,
  5226. desc='scalar',
  5227. ),
  5228. dict(
  5229. module_name='KLDivLoss',
  5230. constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
  5231. cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
  5232. input_fn=lambda: torch.rand(()).log(),
  5233. target_fn=lambda: torch.rand(()).log(),
  5234. reference_fn=lambda i, t, m:
  5235. kldivloss_log_target_reference(i, t, get_reduction(m)),
  5236. check_sum_reduction=True,
  5237. desc='scalar_log_target',
  5238. ),
  5239. dict(
  5240. module_name='MSELoss',
  5241. input_size=(),
  5242. target_fn=lambda: torch.randn((), requires_grad=True),
  5243. reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
  5244. (i.numel() if get_reduction(m) == 'mean' else 1)),
  5245. check_sum_reduction=True,
  5246. desc='scalar',
  5247. check_bfloat16=True,
  5248. ),
  5249. dict(
  5250. module_name='MSELoss',
  5251. input_fn=lambda: torch.ones(5, 68, 64, 64, dtype=torch.float) / 10,
  5252. target_fn=lambda: torch.zeros(5, 68, 64, 64, dtype=torch.float),
  5253. reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
  5254. (i.numel() if get_reduction(m) == 'mean' else 1)),
  5255. check_forward_only=True,
  5256. desc='prec',
  5257. check_bfloat16=True,
  5258. ),
  5259. dict(
  5260. module_name='BCELoss',
  5261. constructor_args_fn=lambda: (torch.rand(()),),
  5262. cpp_constructor_args='torch::nn::BCELossOptions().weight(torch::rand({}))',
  5263. input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
  5264. target_fn=lambda: torch.rand(()).gt(0).double(),
  5265. reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
  5266. (i.numel() if get_reduction(m) == 'mean' else 1),
  5267. desc='scalar_weights',
  5268. check_bfloat16=True,
  5269. ),
  5270. dict(
  5271. module_name='HingeEmbeddingLoss',
  5272. constructor_args=(0.5,),
  5273. cpp_constructor_args='torch::nn::HingeEmbeddingLossOptions().margin(0.5)',
  5274. input_size=(),
  5275. target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1),
  5276. desc='scalar_margin',
  5277. check_sum_reduction=True,
  5278. ),
  5279. dict(
  5280. module_name='SmoothL1Loss',
  5281. input_size=(),
  5282. target_fn=lambda: torch.randn((), requires_grad=True),
  5283. check_sum_reduction=True,
  5284. reference_fn=lambda i, t, m, b=1.0:
  5285. smoothl1loss_reference(i, t, reduction=get_reduction(m), beta=b),
  5286. desc='scalar',
  5287. ),
  5288. dict(
  5289. module_name='MultiLabelSoftMarginLoss',
  5290. constructor_args=(torch.rand(10),),
  5291. cpp_constructor_args='torch::nn::MultiLabelSoftMarginLossOptions().weight(torch::rand(10))',
  5292. input_fn=lambda: torch.randn(5, 10),
  5293. target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
  5294. reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() /
  5295. (i.numel() if get_reduction(m) == 'mean' else i.size(1) if get_reduction(m) == 'sum' else 1),
  5296. desc='weights',
  5297. check_sum_reduction=True,
  5298. check_gradgrad=False,
  5299. ),
  5300. dict(
  5301. module_name='CTCLoss',
  5302. constructor_args=(14,), # blank=14
  5303. extra_args=([50, 50, 50], [30, 25, 20]), # input_lengths, target_lengths
  5304. input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5305. target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
  5306. reference_fn=lambda i, t, il, tl, m:
  5307. ctcloss_reference(i, t, il, tl, blank=14, reduction=get_reduction(m)),
  5308. desc='lengths_intlists',
  5309. check_forward_only=True,
  5310. check_sum_reduction=True,
  5311. check_gradgrad=False,
  5312. check_half=False,
  5313. # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
  5314. test_cpp_api_parity=False,
  5315. check_jit=False,
  5316. ),
  5317. dict(
  5318. module_name='CTCLoss',
  5319. constructor_args=(14,), # blank=14
  5320. cpp_constructor_args='torch::nn::CTCLossOptions().blank(14)',
  5321. extra_args=(torch.tensor([50, 50, 50]), torch.tensor([30, 25, 20])), # input_lengths, target_lengths
  5322. input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5323. target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
  5324. reference_fn=lambda i, t, il, tl, m:
  5325. ctcloss_reference(i, t, il, tl, blank=14, reduction=get_reduction(m)),
  5326. desc='lengths_tensors',
  5327. check_forward_only=True,
  5328. check_sum_reduction=True,
  5329. check_gradgrad=False,
  5330. check_half=False,
  5331. ),
  5332. # Test is flaky
  5333. # See https://github.com/pytorch/pytorch/issues/29380.
  5334. # dict(
  5335. # module_name='CTCLoss',
  5336. # desc='1d_target',
  5337. # constructor_args=(14,), # blank=14
  5338. # extra_args=([50, 50, 50], [30, 25, 20]), # input_lengths, target_lengths
  5339. # input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5340. # target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
  5341. # reference_fn=lambda i, t, il, tl, m:
  5342. # ctcloss_reference(i, t, il, tl, blank=14, reduction=get_reduction(m)),
  5343. # check_sum_reduction=True,
  5344. # check_gradgrad=False,
  5345. # check_half=False,
  5346. # ),
  5347. dict(
  5348. module_name='CTCLoss',
  5349. desc='2d_int_target_lengths_intlists',
  5350. constructor_args=(0,), # blank=0
  5351. extra_args=([50, 50, 50], [30, 25, 20]), # input_lengths, target_lengths
  5352. input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5353. target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
  5354. reference_fn=lambda i, t, il, tl, m:
  5355. ctcloss_reference(i, t, il, tl, blank=0, reduction=get_reduction(m)),
  5356. check_forward_only=True,
  5357. check_sum_reduction=True,
  5358. check_gradgrad=False,
  5359. check_half=False,
  5360. # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
  5361. test_cpp_api_parity=False,
  5362. check_jit=False,
  5363. ),
  5364. dict(
  5365. module_name='CTCLoss',
  5366. desc='2d_int_target_lengths_tensors',
  5367. constructor_args=(0,), # blank=0
  5368. cpp_constructor_args='torch::nn::CTCLossOptions().blank(0)',
  5369. extra_args=(torch.tensor([50, 50, 50]), torch.tensor([30, 25, 20])), # input_lengths, target_lengths
  5370. input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5371. target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
  5372. reference_fn=lambda i, t, il, tl, m:
  5373. ctcloss_reference(i, t, il, tl, blank=0, reduction=get_reduction(m)),
  5374. check_forward_only=True,
  5375. check_sum_reduction=True,
  5376. check_gradgrad=False,
  5377. check_half=False,
  5378. ),
  5379. dict(
  5380. module_name='CTCLoss',
  5381. desc='2d_lengths_tensors',
  5382. constructor_args=(0,), # blank=0
  5383. cpp_constructor_args='torch::nn::CTCLossOptions().blank(0)',
  5384. extra_args=(torch.tensor([50, 50, 50]), torch.tensor([30, 25, 20])), # input_lengths, target_lengths
  5385. input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
  5386. target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
  5387. reference_fn=lambda i, t, il, tl, m:
  5388. ctcloss_reference(i, t, il, tl, blank=0, reduction=get_reduction(m)),
  5389. check_forward_only=True,
  5390. check_sum_reduction=True,
  5391. check_gradgrad=False,
  5392. check_half=False,
  5393. ),
  5394. ]
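
# Each entry above is consumed by the test harness; roughly (illustrative only, the
# exact plumbing lives in the test files), a dict becomes a CriterionTest like:
#     test = CriterionTest(lambda: nn.CrossEntropyLoss(), input_size=(5, 3),
#                          target_fn=lambda: torch.rand(5, 3).mul(3).floor().long())
#     test(test_case)  # forward vs. reference_fn, then gradcheck/gradgradcheck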

def single_batch_reference_criterion_fn(*args):
    """Reference function for criterion supporting no batch dimensions.

    The criterion is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.
    """
    criterion = args[-1]

    def unsqueeze_inp(inp):
        if isinstance(inp, (list, tuple)):
            return [t.unsqueeze(0) for t in inp]
        return inp.unsqueeze(0)

    def flatten(xs):
        result = []
        if isinstance(xs, (list, tuple)):
            for x in xs:
                result.extend(flatten(x))
        else:
            result.append(xs)
        return result

    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])

    output = criterion(*single_batch_input_args)
    reduction = get_reduction(criterion)

    if reduction == 'none':
        return output.squeeze(0)
    # reduction is 'sum' or 'mean' which results in a scalar
    return output
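
# Illustration: an unbatched input of shape (3,) is evaluated as a batch of one,
# i.e. shape (1, 3); for reduction='none' the singleton batch dim is squeezed back off:
#     single_batch_reference_criterion_fn(
#         torch.randn(3), torch.randn(3), nn.MSELoss(reduction='none')).shape == (3,)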

# Check that regression criteria work with no batch dimensions
regression_criterion_no_batch = [
    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
]
reductions = ['none', 'mean', 'sum']
for name, reduction in product(regression_criterion_no_batch, reductions):
    regression_test_info = dict(
        fullname="{}_no_batch_dim_{}".format(name, reduction),
        # Bind the loop variables as defaults: the constructor runs after the loop
        # has finished, so a plain closure would see only the final values.
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
        input_size=(3, ),
        target_size=(3, ),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=False,
    )
    criterion_tests.append(regression_test_info)
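
# The loops above and below generate tests named e.g. test_L1Loss_no_batch_dim_mean
# (TestBase.get_name prepends 'test_' to the fullname).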

for reduction in reductions:
    regression_test_info = dict(
        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
        # Bind reduction as a default for the same late-binding reason as above.
        constructor=lambda *args, reduction=reduction: nn.KLDivLoss(reduction=reduction),
        input_fn=lambda: torch.rand((3,)).log(),
        target_fn=lambda: torch.rand((3,)),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=False,
    )
    criterion_tests.append(regression_test_info)

# Check that classification criteria work with no batch dimensions
# List of tuples of (name, input_fn, target_fn)
classification_criterion_no_batch = [
    ('BCELoss', lambda: torch.sigmoid(torch.randn(9)), lambda: torch.randn(9)),
    ('BCEWithLogitsLoss', lambda: torch.randn(9), lambda: torch.randn(9)),
    ('HingeEmbeddingLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
    ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
    ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
    ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
    ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
    ('TripletMarginLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.randn(9)),
    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9), lambda: torch.randn(9)),
]
classification_criterion_no_batch_extra_info: Dict[str, dict] = {
    'MultiLabelMarginLoss': {'check_gradgrad': False},
}
# TODO : Fix these discrepancies
classification_cpp_parity = {
    'BCELoss': False,
    'BCEWithLogitsLoss': False,
    'HingeEmbeddingLoss': False,
    'NLLLoss': False,
    'SoftMarginLoss': False,
}
reductions = ['none', 'mean', 'sum']
for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
                                                      reductions):
    classification_test_info = dict(
        fullname="{}_no_batch_dim_{}".format(name, reduction),
        # As above, bind both loop variables so the constructor sees the right reduction.
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
        input_fn=lambda f=input_fn: f(),
        target_fn=lambda f=target_fn: f(),
        reference_fn=single_batch_reference_criterion_fn,
        test_cpp_api_parity=True,
        has_parity=classification_cpp_parity.get(name, True),
    )
    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
    classification_test_info.update(extra_info)
    criterion_tests.append(classification_test_info)
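
# For example, the unbatched NLLLoss case feeds log-probabilities of shape (3,) and a
# scalar class-index target; single_batch_reference_criterion_fn lifts these to shapes
# (1, 3) and (1,) before calling the criterion.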

class NNTestCase(TestCase):

    # _forward is defined in classes inheriting from NNTestCase
    @abstractmethod
    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
        raise NotImplementedError

    @abstractmethod
    def _zero_grad_parameters(self, module: nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def _backward(self, module: nn.Module,
                  input: _TensorOrTensors, output: torch.Tensor,
                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
                  create_graph: bool = False):
        raise NotImplementedError

    def _jacobian(self, input, num_out):
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        if isinstance(x, torch.Tensor):
            if x.is_sparse:
                return x.to_dense().view(-1)
            else:
                return x.view(-1)
        else:
            return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        if isinstance(input, torch.Tensor):
            if input.requires_grad and input.grad is not None:
                input.grad.zero_()
                input.grad.detach_()
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
        output = self._forward(module, input)
        output_size = output.nelement()

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(_iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            param, d_param = self._get_parameters(module)
            # Replace None gradients (parameters that received no grad) with zeros
            d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]

            d_out = torch.zeros_like(output)
            flat_d_out = d_out.view(-1)
            flat_d_out[i] = 1

            if jacobian_parameters:
                self._zero_grad_parameters(module)
            # Tensors will accumulate gradient from multiple steps
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)

            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x.contiguous().view(-1)
            if jacobian_parameters:
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res: Tuple[torch.Tensor, ...] = tuple()
        if jacobian_input:
            res += jacobian_inp,
        if jacobian_parameters:
            res += jacobian_param,

        return res

    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
        def fw(*input):
            return self._forward(module, input).detach()

        res: Tuple[torch.Tensor, ...] = tuple()
        if jacobian_input:
            res += _get_numerical_jacobian(fw, input, eps=1e-6),
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            to_cat = []
            for p in param:
                jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
                # get_numerical_jacobian returns a list of tuples but we require a tensor
                to_cat.append(jacobian[0][0])
            res += (torch.cat(to_cat, 0),)
        return res

    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
        analytical_t = list(_iter_tensors(analytical))
        numerical_t = list(_iter_tensors(numerical))

        differences = []
        for a, n in zip(analytical_t, numerical_t):
            if a.numel() != 0:
                differences.append(a.add(n, alpha=-1).abs().max())
            # TODO: compare structure (ensure analytic jacobian has correct shape)
        if len(differences) > 0:
            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
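
# check_jacobian builds the analytical Jacobian column by column: for each scalar
# output o_i it backpropagates a one-hot grad_output, so column i holds d o_i / d input
# (and d o_i / d params), then compares against finite differences with eps=1e-6.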

class TestBase:
    _required_arg_names = {'constructor_args', 'input', 'extra_args'}

    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
                if name in {'constructor_args', 'extra_args'}:
                    kwargs[name] = tuple()
                else:
                    raise ValueError("{}: Specify {} by a value, a function to generate it, or its size!"
                                     .format(self.get_name(), name))
        self._extra_kwargs = kwargs
        self._arg_cache = {}

    def get_name(self):
        if self.fullname is not None:
            return 'test_' + self.fullname

        test_name = 'test_' + self.constructor.__name__
        if self.desc:
            test_name += '_' + self.desc
        return test_name

    def _unpack(self, value):
        if isinstance(value, torch.Tensor):
            return value
        elif is_iterable(value):
            return type(value)(self._unpack(v) for v in value)
        else:
            return value

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', True)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', True)

    def _get_arg(self, name, unpack):
        assert name in self._required_arg_names

        if name not in self._arg_cache:
            fn_name = name + '_fn'
            size_name = name + '_size'

            if name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[name]
            elif fn_name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[fn_name]()
            else:
                assert size_name in self._extra_kwargs, \
                    "Missing `{}`, `{}` or `{}` for {}".format(name, size_name, fn_name, self.get_name())

                def map_tensor_sizes(sizes):
                    if isinstance(sizes, list):
                        return [map_tensor_sizes(s) for s in sizes]
                    elif isinstance(sizes, torch.Tensor):
                        return sizes.double()
                    else:
                        return torch.randn(sizes)

                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])

        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]

    def _get_input(self, unpack=True):
        return self._get_arg('input', unpack)

    def __call__(self, test_case):
        raise NotImplementedError
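
# A required arg such as 'input' can be supplied in three ways, resolved by _get_arg:
#     input=torch.randn(2, 3)              # a concrete value
#     input_fn=lambda: torch.randn(2, 3)   # a function that generates it
#     input_size=(2, 3)                    # a size; torch.randn of that size is used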

class ModuleTest(TestBase):

    @abstractmethod
    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
        raise NotImplementedError

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jacobian_input = kwargs.get('jacobian_input', True)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.should_test_pickle = kwargs.get('pickle', True)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.FIXME_no_cuda_gradgrad_comparison = \
            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
        self.precision = kwargs.get('precision', 2e-4)
        self.check_forward_only = kwargs.get('check_forward_only', False)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            ref_input = deepcopy(input)
            ref_module = deepcopy(module)
            expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
            test_case.assertEqual(out, expected_out, exact_dtype=False)
        if self.check_forward_only:
            return
        self.test_noncontig(test_case, module, input)

        if self.should_test_pickle:
            # TODO: do this with in-memory files as soon as torch.save will support it
            with tempfile.TemporaryFile() as f:
                test_case._forward(module, input)
                torch.save(module, f)
                f.seek(0)
                module_copy = torch.load(f)
                test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))

        self._do_test(test_case, module, input)

    def noncontiguize(self, obj):
        if isinstance(obj, list):
            return [self.noncontiguize(o) for o in obj]
        elif isinstance(obj, tuple):
            return tuple(self.noncontiguize(o) for o in obj)
        tensor = obj
        ndim = tensor.dim()
        # Always making only the last dimension noncontiguous is easy to hide
        # bugs because .view(-1) will still work. So try to find a dim with size
        # > 1 and make that non-contiguous, i.e., stack + select on the
        # dimension directly after that.
        dim = ndim
        for d in range(ndim):
            if tensor.size(d) > 1:
                dim = d + 1
                break
        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
        noncontig.requires_grad = tensor.requires_grad
        return noncontig
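
    # For example, a (2, 3) tensor stacked with an empty twin along dim=1 yields a
    # contiguous (2, 2, 3) tensor with strides (6, 3, 1); select(1, 1) then returns a
    # (2, 3) view with strides (6, 1), which is non-contiguous (contiguous would be (3, 1)).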

    def test_noncontig(self, test_case, module, input):
        # check no scalars, can't make non-contig
        if isinstance(input, torch.Tensor) and input.dim() == 0:
            return
        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            if getattr(module, "return_indices", False):
                output = output[0]
            grad_output = output.new(output.shape).normal_()
            output = output.clone()
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            # Some ops, e.g., nn.Flatten, return gradient that shares
            # storage with the grad_output. Hence we copy here.
            go = deepcopy(grad_output if contig_g else nc_grad_output)
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                if getattr(module, "return_indices", False):
                    out = out[0]
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)

    def test_cuda(self, test_case):
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        cpu_input = self._get_input()
        type_map = {torch.double: torch.float}
        cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)

        gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)

        cpu_module = self.constructor(*self.constructor_args)
        gpu_module = self.constructor(*self.constructor_args).float().cuda()
        cpu_param = test_case._get_parameters(cpu_module)
        gpu_param = test_case._get_parameters(gpu_module)
        for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
            gpu_p.data.copy_(cpu_p)

        test_case._zero_grad_input(cpu_input_tuple)
        test_case._zero_grad_input(gpu_input_tuple)
        test_case._zero_grad_parameters(cpu_module)
        test_case._zero_grad_parameters(gpu_module)
        cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
        gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
        if getattr(cpu_module, "return_indices", False):
            cpu_output = cpu_output[0]
            gpu_output = gpu_output[0]
        test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)

        # Run backwards on CPU and GPU and compare results
        for _ in range(5):
            cpu_gradOutput = cpu_output.clone().normal_()
            gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
            cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
            gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
            for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)

        # Run double-backwards on CPU and GPU and compare results
        if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
            cpu_output = cpu_module(*cpu_input_tuple)
            gpu_output = gpu_module(*gpu_input_tuple)
            if getattr(cpu_module, "return_indices", False):
                cpu_output = cpu_output[0]
                gpu_output = gpu_output[0]

            cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
            gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
            gpu_gradOutput.requires_grad = True

            cpu_gradInputs = torch.autograd.grad(
                cpu_output,
                cpu_input_tuple + tuple(cpu_module.parameters()),
                cpu_gradOutput,
                create_graph=True)
            gpu_gradInputs = torch.autograd.grad(
                gpu_output,
                gpu_input_tuple + tuple(gpu_module.parameters()),
                gpu_gradOutput,
                create_graph=True)

            for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
                test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)

            # We mix output into the second backwards computation so that
            # torch.autograd.grad doesn't complain that some inputs
            # are unreachable (which can happen if you differentiate
            # only on the gradient).
            cpu_gg = torch.autograd.grad(
                cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs),
                cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
                retain_graph=True)
            gpu_gg = torch.autograd.grad(
                gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs),
                gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
                retain_graph=True)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
            for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
                test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)

        self.test_noncontig(test_case, gpu_module, gpu_input_tuple)

class InputVariableMixin:
    def _get_input(self):
        input = TestBase._get_input(self, False)  # type: ignore[arg-type]

        def map_variables(i):
            if isinstance(i, torch.Tensor):
                if i.is_floating_point() or i.is_complex():
                    i.requires_grad = True
                return i
            else:
                return type(i)(map_variables(elem) for elem in i)

        return map_variables(input)

class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cudnn = kwargs.get('cudnn', False)
        self.check_inplace = kwargs.get('check_inplace', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.skip_double = kwargs.get('skip_double', False)
        self.skip_half = kwargs.get('skip_half', False)
        self.with_tf32 = kwargs.get('with_tf32', False)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.test_cpu = kwargs.get('test_cpu', True)
        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)
        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)

    def _check_gradients(self, test_case, module, input_tuple):
        params = tuple(x for x in module.parameters())
        num_inputs = len(input_tuple)

        def fn_to_gradcheck(*inputs_and_params, **kwargs):
            assert not kwargs
            return test_case._forward(module, inputs_and_params[:num_inputs])

        # gradcheck doesn't support operators that take in dense inputs but
        # return sparse parameters. This only happens in the case of nn.Embedding
        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
        # is a slightly different version of gradcheck that can handle this.
        if self.has_sparse_gradients:
            assert num_inputs == 1
            test_input_jacobian = torch.is_floating_point(input_tuple[0])
            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
        else:
            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
                                           check_batched_grad=self.check_batched_grad,
                                           fast_mode=self.gradcheck_fast_mode,
                                           check_forward_ad=self.supports_forward_ad))

        if self.check_gradgrad:
            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
                                               check_batched_grad=self.check_batched_grad,
                                               fast_mode=self.gradcheck_fast_mode,
                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))
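
    # fn_to_gradcheck takes (inputs + parameters) as explicit positional arguments, so
    # gradcheck also perturbs and checks gradients with respect to the parameters, even
    # though the forward reads them from the module rather than from the argument list.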

    def _do_test(self, test_case, module, input):
        num_threads = torch.get_num_threads()
        torch.set_num_threads(1)
        input_tuple = input if isinstance(input, tuple) else (input,)

        self._check_gradients(test_case, module, input_tuple)

        # check if module can be printed
        module.__repr__()

        if self.check_inplace:
            # check if the inplace variant of the module gives the same result
            # as the out-of-place

            # check_inplace doesn't support multiple input tensors, since we don't have any modules
            # that modify the inputs in-place and that accept more than one input
            assert len(input_tuple) == 1
            input = input_tuple[0]

            module_ip = self.constructor(*self.constructor_args, inplace=True)

            input_version = input._version
            with freeze_rng_state():
                output = module(input)
            test_case.assertEqual(input._version, input_version)

            input_ip = deepcopy(input)
            input_ip_clone = input_ip.clone()
            with freeze_rng_state():
                output_ip = module_ip(input_ip_clone)
            test_case.assertNotEqual(input_ip_clone._version, input_version)
            test_case.assertEqual(output, output_ip)
            grad = output.data.clone().normal_()
            if input.grad is not None:
                with torch.no_grad():
                    input.grad.zero_()
            if input_ip.grad is not None:
                with torch.no_grad():
                    input_ip.grad.zero_()
            output.backward(grad)
            output_ip.backward(grad)
            test_case.assertEqual(input.grad, input_ip.grad)

        def assert_module_parameters_are(tensor_type, device_id=None):
            for p in module.parameters():
                test_case.assertIsInstance(p, tensor_type)
                if device_id is not None:
                    test_case.assertEqual(p.get_device(), device_id)

        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
            # check that cuda() moves module parameters to correct GPU device,
            # and that float() casts parameters correctly
            input_tuple = tuple(t.cuda() for t in input_tuple)
            module.float().cuda()
            module(*input_tuple)
            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

            if torch.cuda.device_count() > 1:
                input_tuple = tuple(t.cuda(1) for t in input_tuple)
                module.cuda(1)
                with torch.cuda.device(1):
                    module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
        else:
            # check that float()/double() casters work correctly
            def to_type(tensor, real, complex):
                if tensor.is_complex():
                    return tensor.to(complex)
                elif tensor.is_floating_point():
                    return tensor.to(real)
                else:
                    return tensor

            def to_half(x):
                # TODO: torch.complex32 when properly supported
                return to_type(x, torch.float16, None)

            def to_single(x):
                return to_type(x, torch.float32, torch.complex64)

            def to_double(x):
                return to_type(x, torch.float64, torch.complex128)

            # to float
            input_tuple = tuple(to_single(t) for t in input_tuple)
            module.float()
            module(*input_tuple)
            assert_module_parameters_are(torch.FloatTensor)

            # and back to double
            input_tuple = tuple(to_double(t) for t in input_tuple)
            module.double()
            module(*input_tuple)
            assert_module_parameters_are(torch.DoubleTensor)

            if TEST_CUDA and self.should_test_cuda:
                # check that cuda() moves module parameters to correct GPU device,
                # and that float() casts parameters correctly

                # to GPU0
                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
                module.float().cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # to CPU
                input_tuple = tuple(t.cpu() for t in input_tuple)
                module.cpu()
                module(*input_tuple)
                assert_module_parameters_are(torch.FloatTensor)

                # back to GPU0
                input_tuple = tuple(t.cuda() for t in input_tuple)
                module.cuda()
                module(*input_tuple)
                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                # test that the forward of the module runs correctly without cuDNN
                if self.cudnn:
                    with torch.backends.cudnn.flags(enabled=False):
                        module(*input_tuple)
                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]

                if torch.cuda.device_count() >= 2:
                    # test cross-GPU transfer works
                    # to GPU1
                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
                    module.cuda(1)
                    with torch.cuda.device(1):
                        module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]

                if not self.skip_double:
                    # test double()
                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
                    module.double().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]

                # test half()
                if not self.skip_half:
                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
                    module.half().cuda()
                    module(*input_tuple)
                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
        torch.set_num_threads(num_threads)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
    # TODO: check that criterions don't ignore grad_output

    _required_arg_names = TestBase._required_arg_names.union({'target'})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.check_forward_only = kwargs.get('check_forward_only', False)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.check_half = kwargs.get('check_half', True)
        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
        self.check_complex = kwargs.get('check_complex', False)
        self.test_cpu = kwargs.get('test_cpu', True)
        self.with_tf32 = kwargs.get('with_tf32', True)
        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
        self.check_batched_grad = kwargs.get('check_batched_grad', True)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        # Check that these methods don't raise errors
        module.__repr__()
        str(module)

        target = self._get_target()

        if self.reference_fn is not None:
            out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
            ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
            expected_out = self.reference_fn(*ref_args)
            test_case.assertEqual(out, expected_out)

        if self.check_forward_only:
            return

        params = tuple(x for x in module.parameters())
        if not isinstance(input, tuple):
            inputs = (input,) + params + (target,)

            def apply_fn(input, target, *params):
                return module(input, target)
        else:
            inputs = input + params + (target,)

            def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
                return module(input1, input2, target)

        gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

        if self.check_gradgrad:
            gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)

    def test_cuda(self, test_case, dtype, extra_args=None):
        def convert_dtype(obj, dtype, requires_grad=False):
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')

        cpu_input = self._get_input()
        cpu_target = self._get_target()
        cpu_module = self.constructor(*self.constructor_args)
        gpu_module = self.constructor(*self.constructor_args)

        # Convert input, target and module parameters to dtype
        cpu_input = convert_dtype(cpu_input, dtype, True)
        if cpu_target.is_floating_point() or cpu_target.is_complex():
            cpu_target = convert_dtype(cpu_target, dtype)
        cpu_module.type(dtype)
        gpu_module.type(dtype)

        # GPU setup
        gpu_input = to_gpu(cpu_input)
        gpu_target = to_gpu(cpu_target)
        gpu_module.cuda()

        # torch.HalfTensor doesn't support most operations, converting back to default
        if dtype in {torch.half, torch.bfloat16}:
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            # Loss modules with weights require consistent input/module weight types
            cpu_module = self.constructor(*self.constructor_args)

        cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
        gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
        # dtype used to be able to be None, so set precision in this way instead of a precision map
        test_case.assertEqual(cpu_output, gpu_output,
                              atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

        cpu_gradInput = test_case._backward_criterion(
            cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
        gpu_gradInput = test_case._backward_criterion(
            gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
        # dtype used to be able to be None, so set precision in this way instead of a precision map
        test_case.assertEqual(cpu_gradInput, gpu_gradInput,
                              atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)

    def _get_target(self):
        return self._get_arg('target', False)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', False)
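
# Note on test_cuda above: for torch.half / torch.bfloat16 the CPU side is regenerated
# at its default (full-precision) dtype and the comparison is loosened to atol=1e-1,
# since the low-precision GPU result is only expected to track it approximately.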

def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
    # fp32 compute
    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
    if scale_factor is not None:
        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
    out1 = op(input1)
    grad_input1 = torch.randn_like(out1, device=device)
    out1.backward(grad_input1)

    # bfloat16 compute
    op_bfp16 = op.bfloat16()
    input2 = input1.detach().bfloat16().requires_grad_()
    grad_input2 = grad_input1.bfloat16()
    out2 = op_bfp16(input2)
    out2.backward(grad_input2)

    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
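
# Typical (illustrative) call from a test, comparing an op's fp32 and bfloat16 paths:
#     _test_bfloat16_ops(self, torch.nn.LogSoftmax(dim=1), 'cpu', inp_dims=(4, 8), prec=1e-2)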

def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
    if not inference:
        inp.requires_grad_(True)
    out = module(inp)
    if not inference:
        gO = torch.rand_like(out)
        out.backward(gO)
    if check_size:
        test_case.assertEqual(out.size(), inp.size())
    if not inference:
        for p in module.parameters():
            if p.requires_grad:
                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
        test_case.assertEqual(inp.grad, torch.zeros_like(inp))
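
# Illustrative call with an empty batch: the input contributes no elements, so every
# parameter gradient (and the input gradient) must come back as all zeros:
#     _test_module_empty_input(self, nn.Linear(3, 4), torch.randn(0, 3), check_size=False)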

def _create_basic_net():
    class Layer(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
            self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7))

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = Layer()
            self.dummy_param = nn.Parameter(torch.empty(3, 5))
            self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1))

    l = Layer()
    n = Net()
    s = nn.Sequential(n, n)

    return l, n, s
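
# Note: s wraps the same Net instance twice, so nn.Module's identity-based
# deduplication means s.parameters() yields each of Net's two parameters once:
#     l, n, s = _create_basic_net()
#     assert len(list(s.parameters())) == 2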