# distributed_test.py
import copy
import itertools
import math
import os
import random
import sys
import tempfile
import time
from collections import namedtuple, OrderedDict
from contextlib import contextmanager, suppress
from datetime import timedelta
from functools import reduce
from typing import Union, NamedTuple, Callable, Any

import numpy as np

import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
import torch.nn.functional as F
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.algorithms.ddp_comm_hooks import (
    post_localSGD_hook as post_localSGD,
    powerSGD_hook as powerSGD,
    default_hooks as default,
    quantization as quantization_hooks,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.distributed.distributed_c10d import (
    get_world_size,
    _get_default_group,
    AllreduceOptions,
    GroupMember,
)
from torch.distributed.utils import (
    _verify_param_shape_across_processes,
    _sync_module_states,
)
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
    init_multigpu_helper,
    initialize_temp_directories,
    cleanup_temp_dir,
    simple_sparse_reduce_tests,
    skip_if_rocm,
    skip_if_small_worldsize,
    skip_if_odd_worldsize,
    skip_if_lt_x_gpu,
    nccl_skip_if_lt_x_gpu,
    skip_if_no_gpu,
    require_n_gpus_for_nccl_backend,
    requires_nccl_version,
    captured_output,
    with_nccl_blocking_wait,
    with_dist_debug_levels,
    verify_ddp_error_logged,
    DistTestCases,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    FILE_SCHEMA,
    IS_FBCODE,
    NO_MULTIPROCESSING_SPAWN,
    IS_SANDCASTLE,
    parametrize,
    sandcastle_skip,
    sandcastle_skip_if,
)
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer

from torch.utils.data.distributed import DistributedSampler

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

if sys.platform == "win32":
    import msvcrt
else:
    import fcntl

class NetWithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)
        self.register_buffer('buffer', torch.randn(1, 2))

    def forward(self, x):
        self.buffer.add_(1)
        return self.b(self.a(x))


class Foo:
    def __init__(self, x):
        # Can be tensor or int
        self.x = x

    def __eq__(self, other):
        def eq(value, other):
            if isinstance(value, torch.Tensor):
                return torch.equal(value, other)
            return value == other

        for attr, value in self.__dict__.items():
            other_value = other.__dict__[attr]
            if not eq(value, other_value):
                return False
        return True


f = Foo(10)
f.bar = 1

foo_cpu_tensor = Foo(torch.randn(3, 3))

COLLECTIVES_OBJECT_TEST_LIST = [
    {"key1": 3, "key2": 4, "key3": {"nested": True}},
    f,
    foo_cpu_tensor,
    "foo",
    [1, 2, True, "string", [4, 5, "nested"]],
]

# Allowlist of distributed backends where profiling collectives is supported.
PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.NCCL,
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported with use_cuda=True
CUDA_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported for p2p ops
SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.MPI,
    dist.Backend.GLOO,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
EXPECTED_FIELDS = ("a", "b")
TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)


class TestNamedTupleInput_1(NamedTuple):
    a: torch.tensor
    b: torch.tensor


skipIfNoTorchVision = sandcastle_skip_if(not HAS_TORCHVISION, "no torchvision")
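# Runtime configuration for these tests comes from the environment: BACKEND must
# be set by the harness, while INIT_METHOD falls back to "env://" if unset.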
BACKEND = os.environ["BACKEND"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")

DEFAULT_TIMEOUT = 300
CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
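
# Collect profiler events whose names end with the given postfix; handles both
# torch.profiler.profile objects and the legacy autograd profiler.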
def get_profiling_event(postfix, profiler):
    event_list = (
        profiler.events()
        if isinstance(profiler, torch.profiler.profile)
        else profiler.function_events
    )
    return [event for event in event_list if event.name.endswith(postfix)]


# Base error message substring on unfinished reductions.
ddp_prev_reduction_unfinished_str = (
    "Expected to have finished reduction in the prior iteration"
)
# Error message substring when find_unused_parameters=True has not been passed
ddp_recommend_find_unused_params_str = (
    "passing the keyword argument `find_unused_parameters=True`"
)
# Error message substring when find_unused_parameters=True is enabled
ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
# Error message substring for possibility of not all model outputs being used
# in loss computation
ddp_outputs_not_used_in_loss_str = (
    "`forward` function outputs participate in calculating loss"
)
# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
ddp_suggest_debug_mode_str = (
    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
)


class DDPUnevenTestInput(NamedTuple):
    name: str
    model: nn.Module
    inp: Union[torch.tensor, tuple]
    sync_interval: int
    throw_on_early_termination: bool = False
    hook: Callable = None
    state: Any = None


class _FC2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 50, bias=True)
        self.fc.bias.requires_grad = False

    def forward(self, x):
        x = self.fc(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = _FC2()
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        )

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


class LargeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1000, 2000, bias=False)
        self.fc2 = nn.Linear(2000, 500, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class Task(nn.Module):
    def __init__(self):
        super().__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        return self.p + x


class BatchNormNet(nn.Module):
    def __init__(self, affine=True):
        super().__init__()
        self.fc1 = nn.Linear(2, 40, bias=False)
        self.bn = nn.BatchNorm1d(4, affine=affine)
        self.fc2 = nn.Linear(40, 4, bias=False)

    def forward(self, x):
        x = torch.reshape(self.fc1(x), (-1, 4, 10))
        x = self.bn(x)
        x = torch.reshape(x, (-1, 40))
        x = self.fc2(x)
        return F.softmax(x, dim=1)


class UnusedParamTwoLinLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 10, bias=False)
        self.c = nn.Linear(5, 5, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


class DictOutputModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = UnusedParamTwoLinLayerNet()

    def forward(self, x):
        predictions = self.module(x)
        loss = (predictions[0] + predictions[1]).sum()
        return {
            "predictions": predictions,
            "loss": loss,
        }


class TwoLinLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


class EmbeddingNetDifferentParams(nn.Module):
    """
    A module containing an embedding with a different dimension or a different
    number of parameters depending on the rank.
    """

    def __init__(self, rank, diff_num_params=False):
        super().__init__()
        embedding_dim = 500 if diff_num_params or rank == 0 else 50
        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
        self.lin = nn.Linear(embedding_dim, 1)
        if diff_num_params:
            self.lin2 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        x = self.embedding(x)
        return self.lin(x)


class ControlFlowToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(10, 10, bias=False)
        self.lin2 = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        # The second layer is used depending on the input x.
        use_second_layer = torch.equal(x, torch.ones(20, 10, device=x.device))
        if use_second_layer:
            return self.lin2(F.relu(self.lin1(x)))
        else:
            return F.relu(self.lin1(x))


DDP_NET = Net()
BN_NET = BatchNormNet()
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)

def get_timeout(test_id):
    test_name = test_id.split(".")[-1]
    if test_name in CUSTOMIZED_TIMEOUT:
        return CUSTOMIZED_TIMEOUT[test_name]
    else:
        return DEFAULT_TIMEOUT


default_pg_timeout = 60

CUSTOM_PG_TIMEOUT = {
    # This test runs slowly and needs additional time to complete; otherwise it
    # can be taken down by NCCL_ASYNC_ERROR_HANDLING.
    "test_ddp_uneven_inputs": 300,
    # This test has a short timeout since it tests being taken down by
    # NCCL_ASYNC_ERROR_HANDLING, which we want to happen quickly.
    "test_ddp_model_diff_across_ranks": 5,
}


def require_backend(backends):
    if BACKEND not in backends:
        return sandcastle_skip("Test requires backend to be one of %s" % backends)
    return lambda func: func


def require_backends_available(backends):
    def check(backend):
        if backend == dist.Backend.GLOO:
            return dist.is_gloo_available()
        if backend == dist.Backend.NCCL:
            return dist.is_nccl_available()
        if backend == dist.Backend.MPI:
            return dist.is_mpi_available()
        if backend == dist.Backend.UCC:
            return dist.is_ucc_available()
        if backend in DistTestCases.backend_feature["plugin"]:
            return True
        return False

    if not all(check(dist.Backend(backend)) for backend in backends):
        return sandcastle_skip("Test requires backends to be available %s" % backends)
    return lambda func: func


def require_world_size(world_size):
    if int(os.environ["WORLD_SIZE"]) < world_size:
        return sandcastle_skip("Test requires world size of %d" % world_size)
    return lambda func: func
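
# Cross-process file lock over a lockfile in TEMP_DIR; uses msvcrt on Windows
# and fcntl elsewhere. Used below to serialize access to the barrier directory.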
@contextmanager
def _lock():
    TEMP_DIR = os.environ["TEMP_DIR"]
    lockfile = os.path.join(TEMP_DIR, "lockfile")
    with open(lockfile, "w") as lf:
        try:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
                yield
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
                yield
        finally:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
            lf.close()


@contextmanager
def _rank_temp_file():
    if dist.get_rank() == 0:
        fd, name = tempfile.mkstemp()
        os.close(fd)
    else:
        name = None
    object_list = [name]
    dist.broadcast_object_list(object_list)
    name = object_list[0]
    try:
        yield name
    finally:
        if dist.get_rank() == 0:
            os.remove(name)
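
# Helpers that build test tensors filled with a known value (defaulting to the
# size/dim argument when no value is given), optionally placed on a CUDA device.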
def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, size, size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)


def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
    if value is None:
        value = dim
    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)


def _create_autograd_profiler():
    return torch.autograd.profiler.profile(record_shapes=True)


def _create_torch_profiler():
    return torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
        ],
        record_shapes=True,
    )
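
# File-based barrier: each process writes its current barrier id to a per-pid
# file under TEMP_DIR/barrier, then polls until the expected number of
# processes have written an id at least as large, or the timeout expires.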
class Barrier:
    barrier_id = 0

    @classmethod
    def init(cls):
        cls.barrier_id = 0
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        for f_name in os.listdir(barrier_dir):
            os.unlink(os.path.join(barrier_dir, f_name))

    @classmethod
    def sync(cls, wait_for=None, timeout=10):
        if wait_for is None:
            wait_for = dist.get_world_size()
        cls.barrier_id += 1
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        pid = str(os.getpid())
        barrier_file = os.path.join(barrier_dir, pid)
        with _lock():
            with open(barrier_file, "w") as f:
                f.write(str(cls.barrier_id))

        start_time = time.time()
        while True:
            arrived = 0
            with _lock():
                for f_name in os.listdir(barrier_dir):
                    with open(os.path.join(barrier_dir, f_name), "r") as f:
                        data = f.read()
                        if int(data) >= cls.barrier_id:
                            arrived += 1
            if arrived == wait_for:
                break

            if time.time() - start_time > timeout:
                raise RuntimeError("barrier timeout")
            time.sleep(0.1)
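
# Base multiprocess test case: each rank runs the test body in its own process
# and joins the default process group for the backend named by the BACKEND
# environment variable before the test executes.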
class TestDistBackend(MultiProcessTestCase):
    @classmethod
    def setUpClass(cls):
        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
        # Not setting MASTER_PORT; a random free port is picked instead.
        super().setUpClass()

    def setUp(self):
        super().setUp()
        # Initialize temp directories.
        initialize_temp_directories()
        # Initialize Barrier.
        Barrier.init()
        # Skip return code checking for the following tests, as they are
        # expected to crash a process due to NCCL_ASYNC_ERROR_HANDLING.
        self.skip_return_code_checks = []

    def tearDown(self):
        cleanup_temp_dir()
        super().tearDown()

    @property
    def init_method(self):
        return "{}{file_name}".format(FILE_SCHEMA, file_name=self.file_name)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        # Enable DDP + ReplicatedTensor.
        from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        _set_ddp_with_replicated_tensor(True)

        if BACKEND == "nccl" and not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        if torch.cuda.is_available() and torch.cuda.device_count() < int(
            self.world_size
        ):
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        try:
            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
            timeout = timedelta(seconds=pg_timeout_seconds)
            dist.init_process_group(
                init_method=self.init_method,
                backend=BACKEND,
                world_size=int(self.world_size),
                rank=self.rank,
                timeout=timeout,
            )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        # Execute a barrier prior to running the test to ensure that every
        # process has finished initialization and that a test immediately
        # exiting due to a skip doesn't cause flakiness.
        self._barrier()

        self.run_test(test_name, pipe)
        self._barrier()
        dist.destroy_process_group()
        sys.exit(0)

    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
    # run these tests under various other world sizes.
    @property
    def world_size(self):
        return os.environ["WORLD_SIZE"]

class DistributedTest:
    class _DistTestBase:
        def _barrier(self, *args, **kwargs):
            Barrier.sync(*args, **kwargs)

        def _init_group_test(self, **kwargs):
            group = [1, 2]
            group_id = dist.new_group(group, **kwargs)
            rank = dist.get_rank()
            if rank not in group:
                return ([], None, rank)

            return (group, group_id, rank)

        def _init_full_group_test(self, **kwargs):
            group = list(range(0, dist.get_world_size()))
            group_id = dist.new_group(**kwargs)
            rank = dist.get_rank()
            return (group, group_id, rank)

        def _init_global_test(self):
            group = list(range(0, dist.get_world_size()))
            group_id = dist.group.WORLD
            rank = dist.get_rank()
            return (group, group_id, rank)

        def _verify_buffers_equal(self, m1, m2):
            # Verify buffers across models.
            m1_buf_dict = {k: v for k, v in m1.module.named_buffers()}
            for name, buf in m2.module.named_buffers():
                self.assertEqual(buf, m1_buf_dict[name])

            # Verify buffers across ranks.
            m1_buffers = list(m1.buffers())
            m2_buffers = list(m2.buffers())
            for (buf1, buf2) in zip(m1_buffers, m2_buffers):
                gathered_bufs = [
                    torch.empty_like(buf1) for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_bufs, buf1)
                gathered_bufs_m2 = [
                    torch.empty_like(buf2) for _ in range(dist.get_world_size())
                ]
                for b in gathered_bufs:
                    self.assertEqual(b, buf1)
                dist.all_gather(gathered_bufs_m2, buf2)
                for b in gathered_bufs_m2:
                    self.assertEqual(b, buf2)

        def test_dump_DDP_relevant_env_vars(self):
            with captured_output() as (out, _):
                _dump_DDP_relevant_env_vars()
                lines = out.getvalue().splitlines()

            def format_line(var):
                return "env:%s=%s" % (
                    var,
                    os.environ[var] if var in os.environ else "N/A",
                )

            # Check relevant env vars.
            vars = [
                "MASTER_ADDR",
                "MASTER_PORT",
                "WORLD_SIZE",
                "NCCL_TOPO_DUMP_FILE",  # N/A
                "NCCL_ASYNC_ERROR_HANDLING",
            ]
            for var in vars:
                line = format_line(var)
                self.assertIn(line, lines)

            # Check irrelevant env vars.
            vars = [
                "xxx",
                "yyy",
                "zzz",
            ]
            for var in vars:
                line = format_line(var)
                self.assertNotIn(line, lines)
  557. # GET RANK
  558. def test_get_rank(self):
  559. test_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
  560. pid = str(os.getpid())
  561. num_processes = dist.get_world_size()
  562. with open(os.path.join(test_dir, pid), "w") as f:
  563. f.write(str(dist.get_rank()))
  564. self._barrier()
  565. all_ranks = set()
  566. for f_name in os.listdir(test_dir):
  567. with open(os.path.join(test_dir, f_name), "r") as f:
  568. all_ranks.add(int(f.read()))
  569. self.assertEqual(len(all_ranks), num_processes)
  570. self._barrier()
  571. if dist.get_rank() == 0:
  572. for f_name in os.listdir(test_dir):
  573. os.unlink(os.path.join(test_dir, f_name))
  574. self._barrier()
  575. def test_get_backend(self):
  576. if dist.get_world_size() > 2:
  577. group = [1, 2]
  578. else:
  579. group = [0, 1]
  580. group_id = dist.new_group(group)
  581. backend_str = BACKEND.lower()
  582. self.assertEqual(dist.get_backend(), backend_str)
  583. if dist.get_rank() in group:
  584. self.assertEqual(dist.get_backend(group_id), backend_str)
  585. else:
  586. with self.assertRaisesRegex(
  587. RuntimeError, "Invalid process group specified"
  588. ):
  589. dist.get_backend(group_id)
  590. def test_Backend_enum_class(self):
  591. # test parsing
  592. backend = BACKEND.lower()
  593. self.assertEqual(dist.Backend(BACKEND.upper()), backend)
  594. self.assertEqual(dist.Backend(BACKEND), backend)
  595. with self.assertRaises(ValueError):
  596. dist.Backend(None)
  597. with self.assertRaises(ValueError):
  598. dist.Backend(3)
  599. with self.assertRaises(ValueError):
  600. dist.Backend(["gloo"])
  601. # Test destroy
  602. def test_destroy_group(self):
  603. if dist.get_world_size() > 2:
  604. group = [1, 2]
  605. else:
  606. group = [0, 1]
  607. group_id = dist.new_group(group)
  608. self._barrier()
  609. dist.destroy_process_group(group_id)
  610. # Test get rank and size of group
  611. def test_get_rank_size_group(self):
  612. if dist.get_world_size() > 2:
  613. group = [1, 2]
  614. else:
  615. group = [0, 1]
  616. group_id = dist.new_group(group)
  617. if dist.get_rank() in group:
  618. self.assertEqual(dist.get_world_size(group_id), 2)
619. self.assertIn(dist.get_rank(group_id), range(2))
  620. else:
  621. self.assertEqual(dist.get_world_size(group_id), -1)
  622. self.assertEqual(dist.get_rank(group_id), -1)
  623. # Test destroy full groups
  624. def test_destroy_full_group(self):
  625. _, group_id, _ = self._init_full_group_test()
  626. self._barrier()
  627. dist.destroy_process_group(group_id)
  628. # Test get rank and size of full group
  629. def test_get_rank_size_full_group(self):
  630. _, group_id, _ = self._init_full_group_test()
  631. self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
  632. self.assertEqual(dist.get_rank(group_id), dist.get_rank())
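# Barrier timeout helper: only rank 0 of the group enters the barrier, so
# it is expected to time out (or fail the monitored barrier in DETAIL debug
# mode), while the other ranks skip the collective entirely.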
  633. def _test_barrier_timeout(self, group_id, timeout):
  634. local_rank = dist.get_rank(group_id)
  635. # Only execute barrier on rank == 0, causing it to timeout
  636. if local_rank == 0:
  637. expected_time = time.time() + timeout.total_seconds()
  638. # In debug mode, we execute a monitored_barrier before the
  639. # collective, so assert on that.
  640. if dist.get_debug_level() == dist.DebugLevel.DETAIL:
  641. exception_ctx = self.assertRaisesRegex(
  642. Exception, "failed to pass monitoredBarrier"
  643. )
  644. else:
  645. exception_ctx = self.assertRaisesRegex(
  646. Exception, " (Timed out|closed|timeout) "
  647. )
  648. with exception_ctx:
  649. dist.barrier(group_id)
  650. self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)
  651. else:
  652. pass
  653. @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports timeouts")
  654. @sandcastle_skip_if(
  655. not INIT_METHOD.startswith("file://"),
  656. "Requires file:// initialization method. "
  657. + "Both tcp:// and env:// rely on the TCP store for which "
  658. "reinitialization has proven racy.",
  659. )
  660. def test_barrier_timeout_global(self):
  661. dist.destroy_process_group()
  662. # Explicitly pass world size to the barrier because we've
  663. # just destroyed any state in torch.distributed.
  664. self._barrier(wait_for=int(os.environ["WORLD_SIZE"]))
  665. # Reinitialize global process group
  666. timeout = timedelta(seconds=1)
  667. dist.init_process_group(
  668. init_method=INIT_METHOD,
  669. backend=BACKEND,
  670. world_size=int(os.environ["WORLD_SIZE"]),
  671. rank=self.rank,
  672. timeout=timeout,
  673. )
  674. self._test_barrier_timeout(dist.group.WORLD, timeout)
  675. @skip_if_small_worldsize
  676. @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports timeouts")
  677. def test_barrier_timeout_group(self):
  678. timeout = timedelta(seconds=5)
  679. _, group_id, _ = self._init_group_test(timeout=timeout)
  680. if group_id is not None:
  681. self._test_barrier_timeout(group_id, timeout)
  682. @sandcastle_skip_if(BACKEND != "gloo", "Only gloo backend supports timeouts")
  683. def test_barrier_timeout_full_group(self):
  684. timeout = timedelta(seconds=1)
  685. _, group_id, _ = self._init_full_group_test(timeout=timeout)
  686. if group_id is not None:
  687. self._test_barrier_timeout(group_id, timeout)
688. # This test helper can only be used with the Gloo or NCCL backend
  689. # **and** both the Gloo and NCCL backends are available.
  690. # See the @skip annotations below.
  691. def _test_group_override_backend(self, initializer):
  692. if BACKEND == "gloo":
  693. new_backend = "nccl"
  694. elif BACKEND == "nccl":
  695. new_backend = "gloo"
  696. elif BACKEND in DistTestCases.backend_feature["plugin"]:
  697. new_backend = "gloo"
  698. group, group_id, rank = initializer(backend=new_backend)
  699. if group_id is None:
  700. return
  701. if new_backend == "gloo":
  702. self.assertTrue(isinstance(group_id, dist.ProcessGroupGloo))
  703. if new_backend == "nccl":
  704. self.assertTrue(isinstance(group_id, dist.ProcessGroupNCCL))
  705. self.assertEqual(rank, group[dist.get_rank(group_id)])
  706. self.assertEqual(len(group), dist.get_world_size(group_id))
  707. # Pin device (so we avoid NCCL race conditions/deadlocks).
  708. group_rank = dist.get_rank(group_id)
  709. torch.cuda.set_device(group_rank)
  710. # Run broadcast of CUDA tensor (so it works for both Gloo and NCCL).
  711. tensor = _build_tensor(2, value=group_rank).cuda()
  712. dist.broadcast(tensor, src=group[0], group=group_id)
  713. self.assertEqual(_build_tensor(2, value=0), tensor.to("cpu"))
  714. @require_backend(DistTestCases.backend_feature["gpu"])
  715. @require_backends_available(DistTestCases.backend_feature["gpu"])
  716. @require_world_size(3)
  717. @skip_if_lt_x_gpu(2)
  718. def test_backend_group(self):
  719. self._test_group_override_backend(self._init_group_test)
  720. @require_backend(DistTestCases.backend_feature["gpu"])
  721. @require_backends_available(DistTestCases.backend_feature["gpu"])
  722. @skip_if_lt_x_gpu(3)
  723. def test_backend_full_group(self):
  724. self._test_group_override_backend(self._init_full_group_test)
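# dist.new_subgroups tests: new_subgroups() partitions the world into
# equally sized, contiguous subgroups of the requested size, while
# new_subgroups_by_enumeration() takes an explicit list of ranks per subgroup.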
  725. @sandcastle_skip_if(
  726. BACKEND not in DistTestCases.backend_feature["subgroup"],
  727. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  728. )
  729. @require_world_size(4)
  730. @skip_if_lt_x_gpu(2)
  731. def test_new_subgroups(self):
  732. subgroup_size = 2
  733. cur_subgroup, subgroups = dist.new_subgroups(subgroup_size)
  734. world_size = dist.get_world_size()
  735. self.assertEqual(cur_subgroup.size(), subgroup_size)
  736. self.assertEqual(len(subgroups), world_size / subgroup_size)
  737. self.assertFalse(dist._rank_not_in_group(cur_subgroup))
  738. for subgroup in subgroups:
  739. dist.destroy_process_group(subgroup)
  740. @sandcastle_skip_if(
  741. BACKEND not in DistTestCases.backend_feature["subgroup"],
  742. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  743. )
  744. @skip_if_no_gpu
  745. def test_new_subgroups_group_size_exceeds_world_size(self):
  746. with self.assertRaisesRegex(
  747. ValueError, "must not exceed"
  748. ):
  749. dist.new_subgroups(100)
  750. @sandcastle_skip_if(
  751. BACKEND not in DistTestCases.backend_feature["subgroup"],
  752. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  753. )
  754. @require_world_size(4)
  755. @skip_if_lt_x_gpu(4)
  756. def test_new_subgroups_world_size_not_divisible_by_group_size(self):
  757. with self.assertRaisesRegex(
  758. ValueError, "The world size must be divisible by 'group_size'"
  759. ):
  760. dist.new_subgroups(3)
  761. @sandcastle_skip_if(
  762. BACKEND not in DistTestCases.backend_feature["subgroup"],
  763. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  764. )
  765. @require_world_size(4)
  766. @skip_if_lt_x_gpu(4)
  767. def test_new_subgroups_by_enumeration(self):
  768. group, group_id, rank = self._init_global_test()
  769. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  770. device_id = rank_to_GPU[rank][0]
  771. cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
  772. ranks_per_subgroup_list=[[0, 2], [1, 3]]
  773. )
  774. if device_id >= 4:
  775. self.assertIsNone(cur_subgroup)
  776. else:
  777. self.assertEqual(cur_subgroup.size(), 2)
  778. self.assertEqual(len(subgroups), 2)
  779. if device_id == 0 or device_id == 2:
  780. self.assertEqual(cur_subgroup, subgroups[0])
  781. else:
  782. self.assertEqual(cur_subgroup, subgroups[1])
  783. for subgroup in subgroups:
  784. dist.destroy_process_group(subgroup)
  785. @sandcastle_skip_if(
  786. BACKEND not in DistTestCases.backend_feature["subgroup"],
  787. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  788. )
  789. @require_world_size(4)
  790. @skip_if_lt_x_gpu(4)
  791. def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
  792. group, group_id, rank = self._init_global_test()
  793. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  794. device_id = rank_to_GPU[rank][0]
  795. world_size = get_world_size(group_id)
  796. with self.assertRaisesRegex(
  797. RuntimeError,
  798. "The new group's rank should be within the the world_size set by init_process_group",
  799. ):
  800. dist.new_subgroups_by_enumeration(
  801. ranks_per_subgroup_list=[[0, 1], [world_size, 2]]
  802. )
  803. @sandcastle_skip_if(
  804. BACKEND not in DistTestCases.backend_feature["subgroup"],
  805. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  806. )
  807. @skip_if_no_gpu
  808. def test_new_subgroups_by_enumeration_negative_input_rank(self):
  809. group, group_id, rank = self._init_global_test()
  810. with self.assertRaisesRegex(
  811. RuntimeError,
  812. "The new group's rank should be within the the world_size set by init_process_group",
  813. ):
  814. dist.new_subgroups_by_enumeration(
  815. ranks_per_subgroup_list=[[-1, -2], [-3, -4]]
  816. )
  817. @sandcastle_skip_if(
  818. BACKEND not in DistTestCases.backend_feature["subgroup"],
  819. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  820. )
  821. @require_world_size(4)
  822. @skip_if_lt_x_gpu(4)
  823. def test_new_subgroups_overlap_not_allowed(self):
  824. with self.assertRaisesRegex(
  825. ValueError, "Rank 1 has appeared in both subgroup"
  826. ):
  827. dist.new_subgroups_by_enumeration(
  828. ranks_per_subgroup_list=[[0], [1, 2], [1, 3]]
  829. )
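# Model averaging tests: average_parameters() averages parameters in place
# across the given process group (or the default group when None is passed).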
  830. @sandcastle_skip_if(
  831. BACKEND not in DistTestCases.backend_feature["subgroup"],
  832. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  833. )
  834. @skip_if_lt_x_gpu(2)
  835. def test_average_parameters(self):
  836. rank = dist.get_rank()
  837. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  838. device_id = rank_to_GPU[rank][0]
  839. model = nn.Sequential(
  840. nn.Conv2d(3, 3, kernel_size=3, padding=1),
  841. nn.ReLU(),
  842. nn.Linear(1, 5, bias=False),
  843. ).cuda(device_id)
  844. # Test global model averaging
  845. for p in model.parameters():
  846. p.data = torch.ones_like(p.data)
  847. model_averaging_utils.average_parameters(
  848. params=model.parameters(), process_group=None
  849. )
  850. # Every element will be the same as the input.
  851. for p in model.parameters():
  852. self.assertEqual(p.data, torch.ones_like(p.data))
  853. # Test partial model averaging
  854. for p in model.parameters():
  855. p.data = torch.ones_like(p.data) * rank
  856. group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
  857. model_averaging_utils.average_parameters(
  858. params=model.parameters(), process_group=group_nccl
  859. )
  860. if not dist._rank_not_in_group(group_nccl):
  861. # Every element on device 0 or 1 should be the average of 0 and 1, i.e., 0.5.
  862. for p in model.parameters():
  863. self.assertEqual(p.data, torch.ones_like(p.data) * 0.5)
  864. else:
  865. # Every element on device not in the subgroup should remain the same.
  866. for p in model.parameters():
  867. self.assertEqual(p.data, torch.ones_like(p.data) * rank)
  868. @sandcastle_skip_if(
  869. BACKEND not in DistTestCases.backend_feature["subgroup"],
  870. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  871. )
  872. @skip_if_lt_x_gpu(2)
  873. def test_periodic_model_averager(self):
  874. rank = dist.get_rank()
  875. world_size = dist.get_world_size()
  876. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  877. device_id = rank_to_GPU[rank][0]
  878. model = nn.Linear(1, 5, bias=False).cuda(device_id)
  879. param = next(model.parameters())
  880. tensor = torch.ones_like(param.data) * rank
  881. expected_avg_tensor = (
  882. torch.ones_like(param.data) * sum(range(world_size)) / world_size
  883. )
  884. period = 4
  885. for warmup_steps in [12, 13, 14, 15]:
  886. averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
  887. for step in range(0, 20):
  888. # Reset the parameters at every step.
  889. param.data = copy.deepcopy(tensor)
  890. for params in model.parameters():
  891. # mock grad
  892. params.grad = torch.ones_like(param.data)
  893. averager.average_parameters(model.parameters())
  894. if step >= warmup_steps and (step - warmup_steps) % period == 0:
  895. self.assertEqual(param.data, expected_avg_tensor)
  896. else:
  897. # No model averaging, so the parameters are not updated.
  898. self.assertEqual(param.data, tensor)
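# Same as the test above, but the averaging step is driven through the
# optimizer's param_groups rather than model.parameters().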
  899. @skip_if_lt_x_gpu(2)
  900. def test_periodic_model_averager_param_group(self):
  901. rank = dist.get_rank()
  902. world_size = dist.get_world_size()
  903. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  904. device_id = rank_to_GPU[rank][0]
  905. model = nn.Linear(1, 5, bias=False).cuda(device_id)
  906. param = next(model.parameters())
  907. opt = torch.optim.SGD(model.parameters(), lr=0.1)
  908. period = 4
  909. for warmup_steps in [12, 13, 14, 15]:
  910. averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
  911. for step in range(0, 20):
  912. # Reset the parameters at every step.
  913. for param_group in opt.param_groups:
  914. for params in param_group["params"]:
  915. # mock grad
  916. params.grad = torch.ones_like(param.data) * rank
  917. params.data = torch.ones_like(param.data) * rank
  918. averager.average_parameters(opt.param_groups)
  919. if step >= warmup_steps and (step - warmup_steps) % period == 0:
  920. for param_group in opt.param_groups:
  921. for params in param_group["params"]:
  922. if params.grad is None:
  923. continue
  924. self.assertEqual(param.data, torch.ones_like(param.data) * sum(range(world_size)) / world_size)
  925. else:
  926. # No model averaging, so the parameters are not updated.
  927. for param_group in opt.param_groups:
  928. for params in param_group["params"]:
  929. if params.grad is None:
  930. continue
  931. self.assertEqual(param.data, torch.ones_like(param.data) * rank)
  932. @sandcastle_skip_if(
  933. BACKEND not in DistTestCases.backend_feature["subgroup"],
  934. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  935. )
  936. @skip_if_lt_x_gpu(2)
  937. def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(self):
  938. rank = dist.get_rank()
  939. world_size = dist.get_world_size()
  940. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  941. device_id = rank_to_GPU[rank][0]
  942. model = nn.Linear(1, 5, bias=False).cuda(device_id)
  943. param = next(model.parameters())
  944. tensor = torch.ones_like(param.data) * rank
  945. expected_avg_tensor = (
  946. torch.ones_like(param.data) * sum(range(world_size)) / world_size
  947. )
  948. period = 4
  949. for warmup_steps in [12, 13, 14, 15]:
  950. averager = hierarchicalSGD.HierarchicalModelAverager(
  951. # Run the global averaging at a period of 4,
  952. # which is equivalent to the above periodic model averaging test case.
  953. period_group_size_dict=OrderedDict([(period, world_size)]), warmup_steps=warmup_steps
  954. )
  956. for step in range(0, 20):
  957. # Reset the parameters at every step.
  958. param.data = copy.deepcopy(tensor)
  959. for params in model.parameters():
  960. # mock grad
  961. params.grad = torch.ones_like(param.data)
  962. averager.average_parameters(model.parameters())
  963. if step >= warmup_steps and (step - warmup_steps) % period == 0:
  964. self.assertEqual(param.data, expected_avg_tensor)
  965. else:
  966. # No model averaging, so the parameters are not updated.
  967. self.assertEqual(param.data, tensor)
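# The 3-level test below stacks subgroups of size 2 and size 4 under the
# global group and checks that, when several averaging periods coincide at
# a step, only the highest-level (largest-group) averaging takes effect.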
  968. @sandcastle_skip_if(
  969. BACKEND not in DistTestCases.backend_feature["subgroup"],
  970. f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
  971. )
  972. @require_world_size(4)
  973. @skip_if_lt_x_gpu(4)
  974. def test_3_level_hierarchical_model_averager(self):
  975. rank = dist.get_rank()
  976. world_size = dist.get_world_size()
  977. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  978. device_id = rank_to_GPU[rank][0]
  979. model = nn.Linear(1, 5, bias=False).cuda(device_id)
  980. param = next(model.parameters())
  981. tensor = torch.ones_like(param.data) * rank
  982. # Set up such a hierarchical model averaging as follows:
  983. # after the first 10 warmup steps,
  984. # run model averaging every 2 steps within each subgroup of size 2,
985. # run model averaging every 4 steps within each subgroup of size 4,
  986. # and run the global model averaging every 8 steps.
  987. # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
  988. warmup_steps = 10
  989. subgroup_size1 = 2
  990. subgroup_avg_period1 = 2
  991. subgroup_size2 = 4
  992. subgroup_avg_period2 = 4
  993. global_avg_period = 8
  994. period_group_size_dict = OrderedDict(
  995. [(subgroup_avg_period1, subgroup_size1),
  996. (subgroup_avg_period2, subgroup_size2),
  997. (global_avg_period, world_size)])
  998. averager = hierarchicalSGD.HierarchicalModelAverager(
  999. period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
  1000. )
  1001. subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
  1002. subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
  1003. real_group_ranks_res1 = dist.get_process_group_ranks(subgroup1)
  1004. real_group_ranks_res2 = dist.get_process_group_ranks(subgroup2)
  1005. expect_group_ranks_res1 = (rank // subgroup_size1 * subgroup_size1 + np.array(list(range(subgroup_size1)))).tolist()
  1006. expect_group_ranks_res2 = (rank // subgroup_size2 * subgroup_size2 + np.array(list(range(subgroup_size2)))).tolist()
  1007. self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
  1008. self.assertEqual(real_group_ranks_res2, expect_group_ranks_res2)
  1009. expected_avg_tensor_within_subgroup1 = (
  1010. torch.ones_like(param.data) * sum(real_group_ranks_res1) / subgroup_size1
  1011. )
  1012. expected_avg_tensor_within_subgroup2 = (
  1013. torch.ones_like(param.data) * sum(real_group_ranks_res2) / subgroup_size2
  1014. )
  1015. expected_global_avg_tensor = (
  1016. torch.ones_like(param.data) * sum(range(world_size)) / world_size
  1017. )
  1018. for step in range(0, 25):
  1019. # Reset the parameters at every step.
  1020. param.data = copy.deepcopy(tensor)
  1021. for params in model.parameters():
  1022. # mock grad
  1023. params.grad = torch.ones_like(param.data)
  1024. averager.average_parameters(model.parameters())
1025. if step in (16, 24):
1026. # Global model averaging runs when `step` is divisible by 8.
1027. self.assertEqual(param.data, expected_global_avg_tensor)
1028. elif step in (12, 20):
1029. # Averaging within the size-4 subgroup runs when `step` is divisible by 4 but not by 8.
1030. self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
1031. elif step in (10, 14, 18, 22):
1032. # Averaging within the size-2 subgroup runs when `step` is divisible by 2 but not by 4 or 8.
1033. self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
  1034. else:
  1035. # No model averaging, so the parameters are not updated.
  1036. self.assertEqual(param.data, tensor)
  1037. # NCCL Batch SEND RECV
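# The batch_isend_irecv tests below follow this pattern (illustrative
# sketch; `tensor_out`, `tensor_in`, and `peer` are placeholders):
#     ops = [dist.P2POp(dist.isend, tensor_out, peer),
#            dist.P2POp(dist.irecv, tensor_in, peer)]
#     reqs = dist.batch_isend_irecv(ops)
#     for req in reqs:
#         req.wait()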
  1038. @skip_if_no_gpu
  1039. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1040. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1041. def test_batch_isend_irecv_nccl(self):
  1042. self._barrier()
  1043. rank = dist.get_rank()
  1044. world_size = dist.get_world_size()
  1045. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  1046. device_id = rank_to_GPU[rank][0]
  1047. torch.cuda.set_device(device_id)
  1048. p2p_op_list = []
  1049. recv_tensors = [None for _ in range(world_size)]
  1050. expected_tensors = [None for _ in range(world_size)]
  1051. for val in ["1", "0"]:
  1052. os.environ["NCCL_BLOCKING_WAIT"] = val
  1053. for src in range(0, world_size):
  1054. send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(src)
  1055. recv_tensors[src] = _build_tensor(src + 1, value=-1, device_id=device_id).fill_(-1)
  1056. expected_tensors[src] = _build_tensor(src + 1, value=-1, device_id=device_id).fill_(rank)
  1057. recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
  1058. p2p_op_list.append(recv_op)
  1059. send_op = dist.P2POp(dist.isend, send_tensor, src)
  1060. p2p_op_list.append(send_op)
  1061. reqs = dist.batch_isend_irecv(p2p_op_list)
  1062. for req in reqs:
  1063. req.wait()
  1064. for src in range(0, world_size):
  1065. self.assertEqual(recv_tensors[src], expected_tensors[src])
  1066. self._barrier()
  1067. @skip_if_no_gpu
  1068. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1069. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1070. def test_batch_isend_irecv_ring_exchange_nccl(self):
  1071. self._barrier()
  1072. rank = dist.get_rank()
  1073. world_size = dist.get_world_size()
  1074. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  1075. device_id = rank_to_GPU[rank][0]
  1076. torch.cuda.set_device(device_id)
  1077. p2p_op_list = []
  1078. send_tensor = _build_tensor(world_size, device_id=device_id)
  1079. recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id)
  1080. send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
  1081. recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size)
  1082. reqs = dist.batch_isend_irecv([send_op, recv_op])
  1083. for req in reqs:
  1084. req.wait()
  1085. self._barrier()
  1086. @skip_if_no_gpu
  1087. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1088. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1089. def test_batch_isend_irecv_self_nccl(self):
  1090. self._barrier()
  1091. # Ensure the process group has been fully initialized (needed by
  1092. # the first sub-group batch_isend_irecv call)
  1093. dist.barrier()
  1094. rank = dist.get_rank()
  1095. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1096. device_id = rank_to_GPU[rank][0]
  1097. p2p_op_list = []
  1098. if rank == 0:
  1099. send_tensor = _build_tensor(rank + 1, device_id=device_id)
  1100. recv_tensor = _build_tensor(rank + 1, value=-1, device_id=device_id)
  1101. recv_op = dist.P2POp(dist.irecv, recv_tensor, 0)
  1102. p2p_op_list.append(recv_op)
  1103. send_op = dist.P2POp(dist.isend, send_tensor, 0)
  1104. p2p_op_list.append(send_op)
  1105. reqs = dist.batch_isend_irecv(p2p_op_list)
  1106. for req in reqs:
  1107. req.wait()
  1108. self._barrier()
  1109. @skip_if_no_gpu
  1110. @skip_if_small_worldsize
  1111. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1112. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1113. def test_batch_isend_irecv_no_rank_zero_nccl(self):
  1114. self._barrier()
  1115. # Ensure the process group has been fully initialized (needed by
  1116. # the first sub-group batch_isend_irecv call)
  1117. dist.barrier()
  1118. rank = dist.get_rank()
  1119. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1120. device_id = rank_to_GPU[rank][0]
  1121. torch.cuda.set_device(device_id)
  1122. p2p_op_list = []
  1123. if rank == 1:
  1124. peer = 2
  1125. elif rank == 2:
  1126. peer = 1
  1127. if rank in [1, 2]:
  1128. send_tensor = _build_tensor(rank + 1, device_id=device_id)
  1129. recv_tensor = _build_tensor(peer + 1, value=-1, device_id=device_id)
  1130. recv_op = dist.P2POp(dist.irecv, recv_tensor, peer)
  1131. p2p_op_list.append(recv_op)
  1132. send_op = dist.P2POp(dist.isend, send_tensor, peer)
  1133. p2p_op_list.append(send_op)
  1134. reqs = dist.batch_isend_irecv(p2p_op_list)
  1135. for req in reqs:
  1136. req.wait()
  1137. self._barrier()
  1138. # GLOO Batch SEND RECV CPU
  1139. @sandcastle_skip_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
  1140. def test_batch_isend_irecv_gloo(self):
  1141. self._barrier()
  1142. rank = dist.get_rank()
  1143. p2p_op_list = []
  1144. for src in range(0, dist.get_world_size()):
  1145. if src == rank:
  1146. continue
  1147. send_tensor = _build_tensor(rank + 1)
  1148. recv_tensor = _build_tensor(src + 1, value=-1)
  1149. recv_op = dist.P2POp(dist.irecv, recv_tensor, src)
  1150. p2p_op_list.append(recv_op)
  1151. send_op = dist.P2POp(dist.isend, send_tensor, src)
  1152. p2p_op_list.append(send_op)
  1153. reqs = dist.batch_isend_irecv(p2p_op_list)
  1154. for req in reqs:
  1155. req.wait()
  1156. self._barrier()
  1157. # GLOO Batch SEND RECV CPU with provided tags
  1158. @sandcastle_skip_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
  1159. def test_batch_isend_irecv_gloo_tags(self):
  1160. self._barrier()
  1161. rank = dist.get_rank()
  1162. p2p_op_list = []
  1163. for src in range(0, dist.get_world_size()):
  1164. if src == rank:
  1165. continue
  1166. send_tensor = _build_tensor(rank + 1)
  1167. recv_tensor = _build_tensor(src + 1, value=-1)
  1168. recv_op = dist.P2POp(dist.irecv, recv_tensor, src, tag=src)
  1169. p2p_op_list.append(recv_op)
  1170. send_op = dist.P2POp(dist.isend, send_tensor, src, tag=rank)
  1171. p2p_op_list.append(send_op)
  1172. reqs = dist.batch_isend_irecv(p2p_op_list)
  1173. for req in reqs:
  1174. req.wait()
  1175. self._barrier()
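# The *_err tests below only exercise argument validation: the failing call
# and its assertion run on rank 0 alone.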
  1176. # NCCL Batch SEND RECV Tensor Error
  1177. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1178. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1179. def test_batch_isend_irecv_tensor_err(self):
  1180. self._barrier()
  1181. rank = dist.get_rank()
  1182. if rank == 0:
  1183. with self.assertRaisesRegex(
  1184. RuntimeError, "Tensors must be CUDA and dense"
  1185. ):
  1186. send_tensor = _build_tensor(rank + 1)
  1187. send_op = dist.P2POp(dist.isend, send_tensor, 1)
  1188. dist.batch_isend_irecv([send_op])
  1189. # NCCL Batch SEND RECV Op Error
  1190. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1191. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1192. def test_batch_isend_irecv_op_err(self):
  1193. self._barrier()
  1194. rank = dist.get_rank()
  1195. if rank == 0:
  1196. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1197. device_id = rank_to_GPU[rank][0]
  1198. with self.assertRaisesRegex(RuntimeError, "^Invalid ``op``"):
  1199. send_tensor = _build_tensor(rank + 1, device_id=device_id)
  1200. send_op = dist.P2POp(dist.broadcast, send_tensor, 1)
  1201. dist.batch_isend_irecv([send_op])
  1202. # NCCL Batch SEND RECV p2p_op_list Error
  1203. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1204. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1205. def test_batch_isend_irecv_op_list_err(self):
  1206. self._barrier()
  1207. rank = dist.get_rank()
  1208. if rank == 0:
  1209. with self.assertRaisesRegex(RuntimeError, "^Invalid ``p2p_op_list``"):
  1210. dist.batch_isend_irecv([1, 2])
  1211. # NCCL Batch SEND RECV Mixed Backend Error
  1212. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
  1213. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1214. def test_batch_isend_irecv_mixed_backend_err(self):
  1215. self._barrier()
  1216. rank = dist.get_rank()
  1217. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1218. device_id = rank_to_GPU[rank][0]
  1219. group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
  1220. group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
  1221. if rank == 0:
  1222. with self.assertRaisesRegex(
  1223. RuntimeError, "All ops need to use the same group"
  1224. ):
  1225. send_tensor = _build_tensor(rank + 1)
  1226. send_op_gloo = dist.P2POp(dist.isend, send_tensor, 1, group_gloo)
  1227. send_op_nccl = dist.P2POp(dist.isend, send_tensor, 1, group_nccl)
  1228. dist.batch_isend_irecv([send_op_gloo, send_op_nccl])
  1229. # NCCL SEND RECV
  1230. @skip_if_no_gpu
  1231. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Send Recv Only")
  1232. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1233. def _test_send_recv_nccl(self, profiler_ctx=None):
  1234. # TODO: now that nccl send/recv is supported, there does not seem to
  1235. # be a need to have nccl send/recv be tested separately.
  1236. rank = dist.get_rank()
  1237. world_size = dist.get_world_size()
  1238. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  1239. device_id = rank_to_GPU[rank][0]
  1240. torch.cuda.set_device(device_id)
  1241. tensor = _build_tensor(rank + 1, device_id=device_id)
  1242. profiler_cls = profiler_ctx if profiler_ctx is not None else suppress()
  1243. with profiler_cls as prof:
  1244. for src in range(0, world_size):
  1245. if src == rank:
  1246. # Send mode
  1247. for dst in range(0, world_size):
  1248. if dst == rank:
  1249. continue
  1250. dist.send(tensor, dst)
  1251. else:
  1252. # Recv mode
  1253. expected_tensor = _build_tensor(src + 1)
  1254. output_tensor = _build_tensor(
  1255. src + 1, value=-1, device_id=device_id
  1256. )
  1257. dist.recv(output_tensor, src)
  1258. self.assertEqual(output_tensor, expected_tensor)
  1259. self._barrier()
  1260. if profiler_ctx is not None:
  1261. backend = dist.get_backend()
  1262. if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
  1263. for event_name in [f"{backend}:send", f"{backend}:recv"]:
  1264. events = get_profiling_event(event_name, prof)
  1265. self.assertTrue(events)
  1266. # Event order is not deterministic, so simply assert their shape
  1267. # is found in the following list.
  1268. expected_shapes = [
  1269. [[rank + 1] * 3] for rank in range(dist.get_world_size())
  1270. ]
  1271. for event in events:
  1272. self.assertTrue(event.input_shapes in expected_shapes)
  1273. @skip_if_no_gpu
  1274. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Send Recv Only")
  1275. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1276. def test_send_recv_nccl(self):
  1277. self._test_send_recv_nccl()
  1278. @skip_if_no_gpu
  1279. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Send Recv Only")
  1280. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1281. def test_send_recv_nccl_autograd_profiler(self):
  1282. profiler_ctx = torch.autograd.profiler.profile(record_shapes=True)
  1283. self._test_send_recv_nccl(profiler_ctx)
  1284. @skip_if_no_gpu
  1285. @sandcastle_skip_if(BACKEND != "nccl", "NCCL Send Recv Only")
  1286. @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
  1287. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode causes hang")
  1288. @sandcastle_skip_if(
  1289. IS_MACOS or IS_WINDOWS,
  1290. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  1291. )
  1292. def test_send_recv_nccl_torch_profiler(self):
  1293. profiler_ctx = torch.profiler.profile(
  1294. activities=[
  1295. torch.profiler.ProfilerActivity.CPU,
  1296. torch.profiler.ProfilerActivity.CUDA,
  1297. ],
  1298. record_shapes=True,
  1299. )
  1300. self._test_send_recv_nccl(profiler_ctx)
  1301. # SEND RECV
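# Generic send/recv helper: each rank takes a turn as the sender to every
# other rank while the rest receive, optionally under a profiler whose
# send/recv events are then validated.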
  1302. def _test_send_recv(self, profiler_ctx):
  1303. rank = dist.get_rank()
  1304. send_size = rank + 1
  1305. tensor = _build_tensor(send_size)
  1306. ctx = profiler_ctx if profiler_ctx is not None else suppress()
  1307. with ctx as prof:
  1308. for src in range(0, dist.get_world_size()):
  1309. if src == rank:
  1310. # Send mode
  1311. for dst in range(0, dist.get_world_size()):
  1312. if dst == rank:
  1313. continue
  1314. dist.send(tensor, dst)
  1315. else:
  1316. # Recv mode
  1317. recv_size = src + 1
  1318. expected_tensor = _build_tensor(recv_size)
  1319. output_tensor = _build_tensor(recv_size, value=-1)
  1320. dist.recv(output_tensor, src)
  1321. self.assertEqual(output_tensor, expected_tensor)
  1322. if profiler_ctx is not None:
  1323. backend = dist.get_backend()
  1324. if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
  1325. for event_name in [f"{backend}:send", f"{backend}:recv"]:
  1326. events = get_profiling_event(event_name, prof)
  1327. # Each rank sends/recvs from all other ranks.
  1328. event_count = sum(e.count for e in events)
  1329. expected_event_count = dist.get_world_size() - 1
  1330. self.assertEqual(event_count, expected_event_count)
  1331. # Event order is not deterministic, so simply assert their shape
  1332. # is found in the following list.
  1333. expected_shapes = [
  1334. [[rank + 1] * 3] for rank in range(dist.get_world_size())
  1335. ]
  1336. for event in events:
  1337. self.assertTrue(event.is_async)
  1338. self.assertTrue(event.input_shapes in expected_shapes)
  1339. @sandcastle_skip_if(
  1340. BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
  1341. )
  1342. def test_send_recv(self):
  1343. self._test_send_recv(profiler_ctx=None)
  1344. @sandcastle_skip_if(
  1345. BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
  1346. )
  1347. def test_send_recv_autograd_profiler(self):
  1348. autograd_profiler_ctx = _create_autograd_profiler()
  1349. self._test_send_recv(profiler_ctx=autograd_profiler_ctx)
  1350. @sandcastle_skip_if(
  1351. BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
  1352. )
  1353. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode causes hang")
  1354. @sandcastle_skip_if(
  1355. IS_MACOS or IS_WINDOWS,
  1356. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  1357. )
  1358. def test_send_recv_torch_profiler(self):
  1359. torch_profiler_ctx = _create_torch_profiler()
  1360. return self._test_send_recv(profiler_ctx=torch_profiler_ctx)
  1361. # SEND RECV ANY SOURCE
  1362. def _test_send_recv_any_source(self, profiler_ctx):
  1363. rank = dist.get_rank()
  1364. send_recv_size = 10
  1365. tensor = _build_tensor(send_recv_size, value=rank)
  1366. recv_ranks = list()
  1367. irecv_ranks = list()
  1368. ctx = profiler_ctx if profiler_ctx is not None else suppress()
  1369. with ctx as prof:
  1370. for dst in range(0, dist.get_world_size()):
  1371. if dst == rank:
  1372. # Recv mode
1373. for other in range(0, dist.get_world_size()):
1374. if other == rank:
1375. continue
  1376. for recv in ["recv", "irecv"]:
  1377. output_tensor = _build_tensor(send_recv_size, value=-1)
  1378. if recv == "recv":
  1379. sender = dist.recv(output_tensor)
  1380. recv_ranks.append(sender)
  1381. elif recv == "irecv":
  1382. work = dist.irecv(output_tensor)
  1383. work.wait()
  1384. sender = work._source_rank()
  1385. irecv_ranks.append(sender)
1386. # The returned sender rank should match every element of the
1387. # received tensor, since each rank sends a tensor filled with
1388. # its own rank.
1389. self.assertTrue(output_tensor.eq(sender).all())
  1390. else:
  1391. # Send mode
  1392. dist.send(tensor, dst) # recv
  1393. dist.send(tensor, dst) # irecv
  1394. if profiler_ctx is not None:
  1395. backend = dist.get_backend()
  1396. if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
  1397. for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
  1398. events = get_profiling_event(event_name, prof)
  1399. # Each rank sends/recvs from other rank twice.
  1400. self.assertEqual(
  1401. sum(event.count for event in events),
  1402. 2 * (dist.get_world_size() - 1),
  1403. )
  1404. for event in events:
  1405. self.assertTrue(event.is_async)
  1406. self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
  1407. # Each rank would have 2 * (world_size - 1) sends, verify that
  1408. # globally we receive the same amount on the other end.
  1409. recv_ranks_tensor = torch.cat(
  1410. (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
  1411. )
  1412. global_recv_ranks = [
  1413. torch.empty_like(recv_ranks_tensor)
  1414. for _ in range(dist.get_world_size())
  1415. ]
  1416. dist.all_gather(global_recv_ranks, recv_ranks_tensor)
  1417. global_recv_ranks_list = []
  1418. for tensor in global_recv_ranks:
  1419. global_recv_ranks_list += tensor.tolist()
  1420. from itertools import groupby
  1421. global_recv_ranks_list.sort()
  1422. frequency = [
  1423. len(list(group)) for key, group in groupby(global_recv_ranks_list)
  1424. ]
  1425. self.assertEqual(dist.get_world_size(), len(frequency))
  1426. self.assertEqual(
  1427. [2 * (dist.get_world_size() - 1)] * dist.get_world_size(), frequency
  1428. )
  1429. self._barrier()
  1430. @sandcastle_skip_if(
  1431. BACKEND in DistTestCases.skip_collective["sendrecv anysource"], f"{BACKEND} does not support send/recv from any source"
  1432. )
  1433. def test_send_recv_any_source(self):
  1434. self._test_send_recv_any_source(profiler_ctx=None)
  1435. @sandcastle_skip_if(
  1436. BACKEND in DistTestCases.skip_collective["sendrecv anysource"], f"{BACKEND} does not support send/recv from any source"
  1437. )
  1438. def test_send_recv_any_source_autograd_profiler(self):
  1439. autograd_profiler_ctx = _create_autograd_profiler()
  1440. self._test_send_recv_any_source(profiler_ctx=autograd_profiler_ctx)
  1441. @sandcastle_skip_if(
  1442. BACKEND in DistTestCases.skip_collective["sendrecv anysource"], f"{BACKEND} does not support send/recv from any source"
  1443. )
1444. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode causes hang")
  1445. @sandcastle_skip_if(
  1446. IS_MACOS or IS_WINDOWS,
  1447. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  1448. )
  1449. def test_send_recv_any_source_torch_profiler(self):
  1450. torch_profiler_ctx = _create_torch_profiler()
  1451. return self._test_send_recv_any_source(profiler_ctx=torch_profiler_ctx)
  1452. # SEND RECV WITH TAG
  1453. def _test_send_recv_with_tag(self, profiler_ctx):
  1454. rank = dist.get_rank()
  1455. world_size = dist.get_world_size()
  1456. send_recv_size = 10
  1457. tensor = _build_tensor(send_recv_size, value=rank)
  1458. ctx = profiler_ctx if profiler_ctx is not None else suppress()
  1459. with ctx as prof:
  1460. for dst in range(0, world_size):
  1461. if dst == rank:
  1462. # Recv mode
  1463. for src in range(0, world_size):
  1464. if src == rank:
  1465. continue
  1466. output_tensor = _build_tensor(send_recv_size, value=-1)
  1467. dist.recv(output_tensor, src, tag=src)
  1468. self.assertTrue(output_tensor.eq(src).all())
  1469. else:
  1470. # Send mode
  1471. dist.send(tensor, dst, tag=rank)
  1472. if profiler_ctx is not None:
  1473. backend = dist.get_backend()
  1474. if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
  1475. for event_name in [f"{backend}:send", f"{backend}:recv"]:
  1476. events = get_profiling_event(event_name, prof)
  1477. # Each rank sends/recvs from all other ranks
  1478. event_count = sum(e.count for e in events)
  1479. expected_event_count = dist.get_world_size() - 1
  1480. self.assertEqual(event_count, expected_event_count)
  1481. for event in events:
  1482. self.assertTrue(event.is_async)
  1483. self.assertEqual(event.name, event_name)
  1484. self.assertEqual(event.input_shapes, [[send_recv_size] * 3])
  1485. @sandcastle_skip_if(
  1486. BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
  1487. )
  1488. def test_send_recv_with_tag(self):
  1489. self._test_send_recv_with_tag(profiler_ctx=None)
  1490. @sandcastle_skip_if(
  1491. BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
  1492. )
  1493. def test_send_recv_with_tag_autograd_profiler(self):
  1494. autograd_profiler_ctx = _create_autograd_profiler()
  1495. return self._test_send_recv_with_tag(profiler_ctx=autograd_profiler_ctx)
  1496. @sandcastle_skip_if(
  1497. BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
  1498. )
1499. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode causes hang")
  1500. @sandcastle_skip_if(
  1501. IS_MACOS or IS_WINDOWS,
  1502. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  1503. )
  1504. def test_send_recv_with_tag_torch_profiler(self):
  1505. torch_profiler_ctx = _create_torch_profiler()
  1506. return self._test_send_recv_with_tag(profiler_ctx=torch_profiler_ctx)
  1507. # ISEND
  1508. def _test_isend(self, profiler_ctx):
  1509. rank = dist.get_rank()
  1510. world_size = dist.get_world_size()
  1511. ctx = profiler_ctx if profiler_ctx is not None else suppress()
  1512. with ctx as prof:
  1513. if rank == 0:
  1514. requests = [
  1515. dist.isend(_build_tensor(dest, 10), dest)
  1516. for dest in range(1, world_size)
  1517. ]
  1518. for request in requests:
  1519. request.wait()
  1520. self.assertTrue(request.is_completed())
  1521. else:
  1522. tensor = _build_tensor(rank, -1)
  1523. dist.recv(tensor, 0)
  1524. self.assertEqual(tensor, _build_tensor(rank, 10))
  1525. self._barrier()
  1526. if profiler_ctx is not None:
  1527. backend = dist.get_backend()
  1528. if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
  1529. expected_event_name = (
  1530. f"{backend}:send" if rank == 0 else f"{backend}:recv"
  1531. )
  1532. events = get_profiling_event(expected_event_name, prof)
  1533. event_count = sum(e.count for e in events)
  1534. expected_count = dist.get_world_size() - 1 if rank == 0 else 1
  1535. self.assertEqual(expected_count, event_count)
  1536. # Event ordering is not guaranteed, so simply ensure the shapes are
  1537. # found in the following map.
  1538. expected_shapes = {
  1539. r: [[r] * 3] for r in range(1, dist.get_world_size())
  1540. }
  1541. for event in events:
  1542. self.assertTrue(event.is_async)
  1543. self.assertEqual(event.name, expected_event_name)
  1544. if rank == 0:
  1545. self.assertTrue(
  1546. event.input_shapes in expected_shapes.values()
  1547. )
  1548. else:
  1549. self.assertEqual(event.input_shapes, expected_shapes[rank])
  1550. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support isend")
  1551. def test_isend(self):
  1552. self._test_isend(profiler_ctx=None)
  1553. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support isend")
  1554. def test_isend_autograd_profiler(self):
  1555. autograd_profiler_ctx = _create_autograd_profiler()
  1556. self._test_isend(profiler_ctx=autograd_profiler_ctx)
  1557. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support isend")
1558. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode causes hang")
  1559. @sandcastle_skip_if(
  1560. IS_MACOS or IS_WINDOWS,
  1561. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  1562. )
  1563. def test_isend_torch_profiler(self):
  1564. torch_profiler_ctx = _create_torch_profiler()
  1565. self._test_isend(profiler_ctx=torch_profiler_ctx)
  1566. # IRECV
  1567. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support irecv")
  1568. def test_irecv(self):
  1569. rank = dist.get_rank()
  1570. world_size = dist.get_world_size()
  1571. if rank == 0:
  1572. expected_tensors = [
  1573. _build_tensor(src, -1) for src in range(1, world_size)
  1574. ]
  1575. requests = [
  1576. dist.irecv(expected_tensors[src - 1], src)
  1577. for src in range(1, world_size)
  1578. ]
  1579. for src in range(1, world_size):
  1580. requests[src - 1].wait()
  1581. self.assertTrue(requests[src - 1].is_completed())
  1582. self.assertEqual(expected_tensors[src - 1], _build_tensor(src, 10))
  1583. else:
  1584. tensor = _build_tensor(rank, 10)
  1585. dist.send(tensor, 0)
  1586. self._barrier()
  1587. # BROADCAST
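# Broadcast helper: broadcasts tensors of several dtypes from every rank in
# turn; with_options=True goes through ProcessGroup.broadcast with explicit
# BroadcastOptions instead of dist.broadcast.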
  1588. def _test_broadcast_helper(
  1589. self,
  1590. group,
  1591. group_id,
  1592. rank,
  1593. cuda=False,
  1594. rank_to_GPU=None,
  1595. with_options=False,
  1596. ):
  1597. for dtype, value, requires_cuda in [
  1598. (torch.float, -1e-10, False),
  1599. (torch.double, -1e-100, False),
  1600. (torch.half, -0.1, True),
  1601. (torch.int8, -2, False),
  1602. (torch.uint8, 129, False),
  1603. (torch.int, -1e5, False),
  1604. (torch.long, -1e15, False),
  1605. ]:
  1606. if requires_cuda and not cuda:
  1607. continue
  1608. for src in group:
  1609. expected_tensor = _build_tensor(src + 1, value, dtype)
  1610. if cuda:
  1611. expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
  1612. if rank == src:
  1613. if with_options:
  1614. opts = dist.BroadcastOptions()
  1615. opts.rootTensor = 0
  1616. opts.rootRank = src
  1617. self.call_dist_op(
  1618. ":broadcast",
  1619. True,
  1620. group_id.broadcast,
  1621. [expected_tensor],
  1622. opts,
  1623. )
  1624. else:
  1625. self.call_dist_op(
  1626. ":broadcast",
  1627. False,
  1628. dist.broadcast,
  1629. expected_tensor,
  1630. src,
  1631. group_id,
  1632. )
  1633. else:
  1634. tensor = _build_tensor(src + 1, -1, dtype)
  1635. if cuda:
  1636. tensor = tensor.cuda(rank_to_GPU[rank][0])
  1637. if with_options:
  1638. opts = dist.BroadcastOptions()
  1639. opts.rootTensor = 0
  1640. opts.rootRank = src
  1641. self.call_dist_op(
  1642. ":broadcast", True, group_id.broadcast, [tensor], opts
  1643. )
  1644. else:
  1645. self.call_dist_op(
  1646. ":broadcast",
  1647. False,
  1648. dist.broadcast,
  1649. tensor,
  1650. src,
  1651. group_id,
  1652. )
  1653. self.assertEqual(tensor.size(), expected_tensor.size())
  1654. self.assertEqual(
  1655. tensor.ne(expected_tensor).max(), torch.tensor(False)
  1656. )
  1657. self._barrier()
  1658. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1659. def test_broadcast(self):
  1660. group, group_id, rank = self._init_global_test()
  1661. self._test_broadcast_helper(group, group_id, rank)
  1662. @sandcastle_skip_if(
  1663. BACKEND != "gloo" and BACKEND != "nccl",
  1664. "Only Gloo and Nccl backend supports CUDA allReduce",
  1665. )
  1666. @skip_if_no_gpu
  1667. def test_broadcast_cuda(self):
  1668. group, group_id, rank = self._init_global_test()
  1669. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1670. device_id = rank_to_GPU[rank][0]
  1671. torch.cuda.set_device(device_id)
  1672. self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)
  1673. @skip_if_small_worldsize
  1674. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1675. def test_broadcast_group(self):
  1676. group, group_id, rank = self._init_group_test()
  1677. self._test_broadcast_helper(group, group_id, rank)
  1678. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1679. def test_broadcast_full_group(self):
  1680. group, group_id, rank = self._init_full_group_test()
  1681. self._test_broadcast_helper(group, group_id, rank)
  1682. @sandcastle_skip_if(
  1683. BACKEND != "nccl",
  1684. "Only NCCL backend supports high priority stream",
  1685. )
  1686. @skip_if_no_gpu
  1687. def test_nccl_high_priority_stream(self):
  1688. group, _, rank = self._init_global_test()
  1689. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1690. device_id = rank_to_GPU[rank][0]
  1691. torch.cuda.set_device(device_id)
  1692. new_port = str(MASTER_PORT + 1)
  1693. os.environ["MASTER_PORT"] = new_port
  1694. gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
  1695. store, rank, size = next(gen_iterator)
  1696. store = dist.PrefixStore(new_port, store)
  1697. opts = dist.ProcessGroupNCCL.Options()
  1698. opts.is_high_priority_stream = False
  1699. group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
  1700. self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
  1701. # REDUCE
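# Reduce helper: the source rank's tensor is filled with master_value and
# all other ranks' tensors with worker_value; the result is checked only on
# the source rank, which is the only rank dist.reduce guarantees a result on.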
  1702. def _test_reduce_helper(
  1703. self,
  1704. group,
  1705. group_id,
  1706. rank,
  1707. op,
  1708. master_value,
  1709. worker_value,
  1710. expected_value,
  1711. cuda=False,
  1712. rank_to_GPU=None,
  1713. ):
  1714. for src in group:
  1715. tensor = _build_tensor(src + 1).fill_(
  1716. master_value if rank == src else worker_value
  1717. )
  1718. if cuda:
  1719. tensor = tensor.cuda(rank_to_GPU[rank][0])
  1720. self.call_dist_op(
  1721. ":reduce",
  1722. False,
  1723. dist.reduce,
  1724. tensor,
  1725. src,
  1726. op,
  1727. group_id,
  1728. tensor_shapes=[tensor.shape],
  1729. )
  1730. if rank == src:
  1731. self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
  1732. self._barrier()
  1733. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1734. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1735. def test_reduce_sum(self):
  1736. group, group_id, rank = self._init_global_test()
  1737. self._test_reduce_helper(
  1738. group,
  1739. group_id,
  1740. rank,
  1741. dist.ReduceOp.SUM,
  1742. 2,
  1743. 10,
  1744. 2 + (10 * (len(group) - 1)),
  1745. )
  1746. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA reduce")
  1747. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1748. @skip_if_no_gpu
  1749. def test_reduce_sum_cuda(self):
  1750. group, group_id, rank = self._init_global_test()
  1751. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1752. device_id = rank_to_GPU[rank][0]
  1753. torch.cuda.set_device(device_id)
  1754. self._test_reduce_helper(
  1755. group,
  1756. group_id,
  1757. rank,
  1758. dist.ReduceOp.SUM,
  1759. 2,
  1760. 10,
  1761. 2 + 10 * (len(group) - 1),
  1762. True,
  1763. rank_to_GPU,
  1764. )
  1765. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1766. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1767. def test_reduce_product(self):
  1768. group, group_id, rank = self._init_global_test()
  1769. self._test_reduce_helper(
  1770. group,
  1771. group_id,
  1772. rank,
  1773. dist.ReduceOp.PRODUCT,
  1774. 2,
  1775. 10,
  1776. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  1777. )
  1778. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1779. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1780. def test_reduce_min(self):
  1781. group, group_id, rank = self._init_global_test()
  1782. self._test_reduce_helper(
  1783. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  1784. )
  1785. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1786. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1787. def test_reduce_max(self):
  1788. group, group_id, rank = self._init_global_test()
  1789. self._test_reduce_helper(
  1790. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  1791. )
  1792. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1793. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1794. @skip_if_small_worldsize
  1795. def test_reduce_group_sum(self):
  1796. group, group_id, rank = self._init_group_test()
  1797. self._test_reduce_helper(
  1798. group,
  1799. group_id,
  1800. rank,
  1801. dist.ReduceOp.SUM,
  1802. 2,
  1803. 10,
  1804. 2 + (10 * (len(group) - 1)),
  1805. )
  1806. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1807. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1808. @skip_if_small_worldsize
  1809. def test_reduce_group_product(self):
  1810. group, group_id, rank = self._init_group_test()
  1811. self._test_reduce_helper(
  1812. group,
  1813. group_id,
  1814. rank,
  1815. dist.ReduceOp.PRODUCT,
  1816. 2,
  1817. 10,
  1818. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  1819. )
  1820. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1821. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1822. @skip_if_small_worldsize
  1823. def test_reduce_group_min(self):
  1824. group, group_id, rank = self._init_group_test()
  1825. self._test_reduce_helper(
  1826. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  1827. )
  1828. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1829. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1830. @skip_if_small_worldsize
  1831. def test_reduce_group_max(self):
  1832. group, group_id, rank = self._init_group_test()
  1833. self._test_reduce_helper(
  1834. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  1835. )
  1836. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1837. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1838. def test_reduce_full_group_sum(self):
  1839. group, group_id, rank = self._init_full_group_test()
  1840. self._test_reduce_helper(
  1841. group,
  1842. group_id,
  1843. rank,
  1844. dist.ReduceOp.SUM,
  1845. 2,
  1846. 10,
  1847. 2 + (10 * (len(group) - 1)),
  1848. )
  1849. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1850. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1851. def test_reduce_full_group_product(self):
  1852. group, group_id, rank = self._init_full_group_test()
  1853. self._test_reduce_helper(
  1854. group,
  1855. group_id,
  1856. rank,
  1857. dist.ReduceOp.PRODUCT,
  1858. 2,
  1859. 10,
  1860. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  1861. )
  1862. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1863. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1864. def test_reduce_full_group_min(self):
  1865. group, group_id, rank = self._init_full_group_test()
  1866. self._test_reduce_helper(
  1867. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  1868. )
  1869. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1870. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1871. def test_reduce_full_group_max(self):
  1872. group, group_id, rank = self._init_full_group_test()
  1873. self._test_reduce_helper(
  1874. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  1875. )
  1876. # REDUCE TWICE
  1877. def _test_reduce_twice_helper(
  1878. self,
  1879. group,
  1880. group_id,
  1881. rank,
  1882. op,
  1883. master_value,
  1884. worker_value,
  1885. expected_value,
  1886. cuda=False,
  1887. rank_to_GPU=None,
  1888. ):
  1889. for src in group:
  1890. tensors = [
  1891. _build_tensor(src + 1).fill_(
  1892. master_value if rank == src else worker_value
  1893. )
  1894. for i in range(2)
  1895. ]
  1896. if cuda:
  1897. for i in range(2):
  1898. tensors[i] = tensors[i].cuda(rank_to_GPU[rank][0])
  1899. self.call_dist_op(
  1900. ":reduce",
  1901. False,
  1902. dist.reduce,
  1903. tensors[0],
  1904. src,
  1905. op,
  1906. group_id,
  1907. secondary_op_call=lambda: dist.reduce(
  1908. tensors[1], src, op, group_id
  1909. ),
  1910. tensor_shapes=[tensors[0].shape],
  1911. )
  1912. if rank == src:
  1913. for tensor in tensors:
  1914. self.assertEqual(tensor, _build_tensor(src + 1, expected_value))
  1915. self._barrier()
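# Editorial note: the helper above issues two back-to-back dist.reduce calls on
# two independent tensors via call_dist_op's secondary_op_call hook, so both
# collectives land inside a single profiler context. A minimal sketch of the
# same pattern (hedged, assuming an already-initialized default process group):
#
#     t0, t1 = torch.ones(4), torch.ones(4)
#     dist.reduce(t0, dst=0, op=dist.ReduceOp.SUM)
#     dist.reduce(t1, dst=0, op=dist.ReduceOp.SUM)
#     # only rank 0 is guaranteed to hold the reduced values in t0 and t1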
  1916. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  1917. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1918. def test_reduce_sum_twice(self):
  1919. group, group_id, rank = self._init_global_test()
  1920. self._test_reduce_twice_helper(
  1921. group,
  1922. group_id,
  1923. rank,
  1924. dist.ReduceOp.SUM,
  1925. 2,
  1926. 10,
  1927. 2 + (10 * (len(group) - 1)),
  1928. )
  1929. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA reduce")
  1930. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1931. @skip_if_no_gpu
  1932. def test_reduce_sum_cuda_twice(self):
  1933. group, group_id, rank = self._init_global_test()
  1934. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1935. device_id = rank_to_GPU[rank][0]
  1936. torch.cuda.set_device(device_id)
  1937. self._test_reduce_twice_helper(
  1938. group,
  1939. group_id,
  1940. rank,
  1941. dist.ReduceOp.SUM,
  1942. 2,
  1943. 10,
  1944. 2 + 10 * (len(group) - 1),
  1945. True,
  1946. rank_to_GPU,
  1947. )
  1948. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports reduce_scatter_v")
  1949. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["reduce"], f"{BACKEND} does not support reduce")
  1950. @skip_if_no_gpu
  1951. def test_reduce_scatter_v_cuda(self):
  1952. self._barrier()
  1953. group, group_id, rank = self._init_global_test()
  1954. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  1955. device_id = rank_to_GPU[rank][0]
  1956. input_split_sizes = []
  1957. for src in group:
  1958. input_split_sizes.append(src + 1)
  1959. start_len = sum(input_split_sizes[:rank])
  1960. end_len = start_len + input_split_sizes[rank]
  1961. sum_len = sum(input_split_sizes)
  1962. master_value = 2
  1963. worker_value = 10
  1964. for async_val in [True, False]:
  1965. tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
  1966. tensor[start_len:end_len].fill_(master_value)
  1967. out_tensor = torch.empty(input_split_sizes[rank], sum_len, sum_len, dtype=torch.float).fill_(-1).cuda(device_id)
  1968. req = dist.reduce_scatter(
  1969. out_tensor,
  1970. list(torch.split(tensor, input_split_sizes)),
  1971. dist.ReduceOp.SUM,
  1972. group_id,
  1973. async_val,
  1974. )
  1975. if async_val:
  1976. req.wait()
  1977. expected_value = 2 + (10 * (len(group) - 1))
  1978. expected_tensor = torch.empty(input_split_sizes[rank], sum_len, sum_len, dtype=torch.float)
  1979. expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)
  1980. self.assertEqual(out_tensor, expected_tensor)
  1981. self._barrier()
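# Editorial note: "reduce_scatter_v" above means reduce_scatter with uneven
# splits: every rank contributes the full sum_len-row tensor, but rank r only
# receives the reduced slice of input_split_sizes[r] == r + 1 rows. A minimal
# sketch (hedged, assuming an initialized NCCL group and a current CUDA device):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     split_sizes = [r + 1 for r in range(world_size)]
#     inp = torch.full((sum(split_sizes),), float(rank), device="cuda")
#     out = torch.empty(split_sizes[rank], device="cuda")
#     dist.reduce_scatter(out, list(torch.split(inp, split_sizes)), op=dist.ReduceOp.SUM)
#     # out holds the element-wise sum over ranks of this rank's slice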
1982. # Test reduce_scatter_tensor accepting a single tensor as input
  1983. def _reduce_scatter_tensor_helper(
  1984. self, tensor_out, tensor_in,
  1985. group_id, rank, cuda=True, rank_to_GPU=None
  1986. ):
  1987. if cuda:
  1988. tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
  1989. tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
  1990. tensor_shapes = [tensor_out.shape]
  1991. self.call_dist_op(
  1992. ":reduce_scatter_tensor",
  1993. False,
  1994. dist.reduce_scatter_tensor,
  1995. tensor_out,
  1996. tensor_in,
  1997. dist.ReduceOp.SUM,
  1998. group_id,
  1999. False,
  2000. expect_event=False,
  2001. tensor_shapes=tensor_shapes,
  2002. )
  2003. return tensor_out
  2004. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor")
  2005. @skip_if_no_gpu
  2006. def test_reduce_scatter_tensor_cuda(self):
  2007. group, group_id, rank = self._init_global_test()
  2008. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2009. size = 2
  2010. tensor_out = torch.zeros(size, dtype=torch.int64)
  2011. # Concatenated input
  2012. tensor_in = torch.arange(len(group) * size)
  2013. tensor_out = self._reduce_scatter_tensor_helper(tensor_out, tensor_in, group_id, rank, True, rank_to_GPU)
  2014. # Check result
  2015. expected_tensor = torch.arange(rank * size, (rank + 1) * size) * len(group)
  2016. self.assertEqual(tensor_out, expected_tensor)
  2017. self._barrier()
  2018. # Stacked input
  2019. tensor_in = torch.reshape(tensor_in, (len(group), size))
  2020. tensor_out = self._reduce_scatter_tensor_helper(tensor_out, tensor_in, group_id, rank, True, rank_to_GPU)
  2021. # Check result
2022. # Should be the same as the result in the concatenated case
  2023. self.assertEqual(tensor_out, expected_tensor)
  2024. self._barrier()
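# Editorial worked example for the expected result above, assuming
# world_size == 2 and size == 2: every rank feeds tensor_in == [0, 1, 2, 3],
# the element-wise sum across ranks is [0, 2, 4, 6], and rank r keeps its
# chunk torch.arange(r * 2, (r + 1) * 2) * 2, i.e. [0, 2] on rank 0 and
# [4, 6] on rank 1, identically for the concatenated and stacked inputs.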
  2025. @skip_if_no_gpu
  2026. @require_backend(DistTestCases.backend_feature["gpu"])
  2027. def test_all_reduce_result_cuda(self):
  2028. group, group_id, rank = self._init_global_test()
  2029. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2030. for src in group:
  2031. if rank == src:
  2032. tensor = _build_tensor(src + 1, 2)
  2033. else:
  2034. tensor = _build_tensor(src + 1, 10)
  2035. tensor = tensor.cuda(rank_to_GPU[rank][0])
  2036. opts = AllreduceOptions()
  2037. opts.reduceOp = dist.ReduceOp.SUM
  2038. if group_id == GroupMember.WORLD:
  2039. work = _get_default_group().allreduce([tensor], opts)
  2040. else:
  2041. work = group_id.allreduce([tensor], opts)
  2042. if BACKEND == "gloo":
2043. # Calling result() before the work has finished should throw an exception.
2044. # There is a race condition here: we cannot assume the work is still
2045. # unfinished by the time the next lines run.
  2046. try:
  2047. with self.assertRaisesRegex(
  2048. RuntimeError,
  2049. "Work needs to be completed before calling result",
  2050. ):
  2051. work.result()
  2052. except AssertionError:
2053. # The exception was not raised, so the work must already be complete; verify with is_completed().
  2054. self.assertTrue(work.is_completed())
  2055. work.wait()
  2056. result = work.result()
  2057. else:
  2058. # In case of NCCL we should be able to retrieve pointer to the result
  2059. # even before work is finished.
  2060. result = work.result()
  2061. work.wait()
  2062. expected_value = 2 + (10 * (len(group) - 1))
  2063. self.assertEqual(result, [_build_tensor(src + 1, expected_value)])
  2064. self._barrier()
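# Editorial note: the test above exercises backend-specific Work.result()
# semantics through the process-group object directly; the portable pattern
# through the public API is to pass async_op=True and wait on the returned
# work handle. A minimal sketch (hedged, assuming an initialized group):
#
#     t = torch.ones(1)
#     work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
#     work.wait()  # safe on every backend
#     # t now equals the world size on every rank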
  2065. def call_dist_op(
  2066. self,
  2067. profiling_title_postfix,
  2068. is_async,
  2069. op,
  2070. *args,
  2071. expect_event=True,
  2072. secondary_op_call=None,
  2073. profile_cuda=False,
  2074. tensor_shapes=None,
  2075. **kwargs,
  2076. ):
  2077. op_calls = [lambda: op(*args, **kwargs)]
  2078. if secondary_op_call is not None:
  2079. op_calls.append(secondary_op_call)
  2080. autograd_profiler_ctx = torch.autograd.profiler.profile(
  2081. use_cuda=profile_cuda, record_shapes=True
  2082. )
  2083. # TODO: move this test to use torch.profiler once kineto issues are
  2084. # fixed internally.
  2085. with autograd_profiler_ctx as prof:
  2086. works = [op_call() for op_call in op_calls]
  2087. if is_async:
  2088. for work in works:
  2089. work.wait()
  2090. if expect_event and dist.get_backend() in PROFILING_SUPPORTED_BACKENDS:
2091. # We are only interested in the backend's implementation, not the dispatcher wrapper.
  2092. events = get_profiling_event(
  2093. dist.get_backend() + profiling_title_postfix, autograd_profiler_ctx
  2094. )
  2095. # DETAIL debug mode can use a pg wrapper that issues more collectives
  2096. # under the hood
  2097. if dist.get_debug_level() != dist.DebugLevel.DETAIL:
  2098. self.assertEqual(len(events), len(op_calls))
  2099. for e in events:
  2100. self.assertTrue(e.is_async)
  2101. self.assertEqual(e.count, 1)
  2102. self.assertGreaterEqual(e.cpu_time, 0)
  2103. # Verify tensor shapes if given
  2104. # DETAIL debug mode can use a pg wrapper that issues more collectives
  2105. # under the hood
  2106. if (
  2107. tensor_shapes is not None
  2108. and dist.get_debug_level() != dist.DebugLevel.DETAIL
  2109. ):
  2110. self.assertEqual(
  2111. e.input_shapes,
  2112. tensor_shapes,
  2113. f"event shape: {e.input_shapes} vs tensor {tensor_shapes}",
  2114. )
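# Editorial note: call_dist_op runs the collective (plus an optional secondary
# call) under the autograd profiler and then checks that exactly one async
# event per call was recorded under the backend-prefixed title (for example
# "gloo:all_reduce") with the expected input shapes. A minimal sketch of the
# same profiling idea (hedged, assuming an initialized gloo group):
#
#     with torch.autograd.profiler.profile(record_shapes=True) as prof:
#         dist.all_reduce(torch.ones(4))
#     titles = [e.name for e in prof.function_events]
#     # one recorded event is expected to carry the "gloo:all_reduce" title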
  2115. # ALL REDUCE
  2116. def _test_all_reduce_helper(
  2117. self,
  2118. group,
  2119. group_id,
  2120. rank,
  2121. op,
  2122. master_value,
  2123. worker_value,
  2124. expected_value,
  2125. cuda=False,
  2126. rank_to_GPU=None,
  2127. dtype=torch.float,
  2128. async_op=False,
  2129. ):
  2130. for src in group:
  2131. curr_value = master_value if rank == src else worker_value
  2132. tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
  2133. if cuda:
  2134. tensor = tensor.cuda(rank_to_GPU[rank][0])
  2135. if tensor.dtype == torch.complex64:
  2136. tensor_shapes = [torch.view_as_real(tensor).shape]
  2137. else:
  2138. tensor_shapes = [tensor.shape]
  2139. self.call_dist_op(
  2140. ":all_reduce",
  2141. async_op,
  2142. dist.all_reduce,
  2143. tensor,
  2144. op,
  2145. group_id,
  2146. async_op=async_op,
  2147. tensor_shapes=tensor_shapes,
  2148. )
2149. # Currently, only the Gloo backend has profiling tested with CUDA enabled.
2150. # Run the CUDA profiling test for a single rank only to speed things up,
2151. # since running with a different src_rank does not affect correctness.
  2152. if (
  2153. src == 0
  2154. and cuda
  2155. and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
  2156. ):
  2157. self.call_dist_op(
  2158. ":all_reduce",
  2159. async_op,
  2160. dist.all_reduce,
  2161. tensor,
  2162. op,
  2163. group_id,
  2164. async_op=async_op,
  2165. profile_cuda=True,
  2166. tensor_shapes=tensor_shapes,
  2167. )
  2168. self._barrier()
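# Editorial worked example for the expected values fed to the helper above,
# assuming world_size == 3: for SUM the src rank contributes 2 and the other
# two ranks contribute 10 each, so 2 + 10 * (3 - 1) == 22; for PRODUCT,
# reduce(lambda x, y: x * y, [10] * 2, 2) == 200.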
  2169. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2170. def test_all_reduce_sum(self):
  2171. group, group_id, rank = self._init_global_test()
  2172. self._test_all_reduce_helper(
  2173. group,
  2174. group_id,
  2175. rank,
  2176. dist.ReduceOp.SUM,
  2177. 2,
  2178. 10,
  2179. 2 + (10 * (len(group) - 1)),
  2180. )
  2181. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2182. def test_all_reduce_sum_async(self):
  2183. group, group_id, rank = self._init_global_test()
  2184. self._test_all_reduce_helper(
  2185. group,
  2186. group_id,
  2187. rank,
  2188. dist.ReduceOp.SUM,
  2189. 2,
  2190. 10,
  2191. 2 + (10 * (len(group) - 1)),
  2192. async_op=True,
  2193. )
  2194. @sandcastle_skip_if(
  2195. BACKEND != "gloo" and BACKEND != "nccl",
  2196. "Only Gloo and NCCL backends will have CUDA allReduce tested",
  2197. )
  2198. @skip_if_no_gpu
  2199. def test_all_reduce_sum_cuda(self):
  2200. torch.cuda.set_device(self.rank)
  2201. group, group_id, rank = self._init_global_test()
  2202. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2203. self._test_all_reduce_helper(
  2204. group,
  2205. group_id,
  2206. rank,
  2207. dist.ReduceOp.SUM,
  2208. 2,
  2209. 10,
  2210. 2 + (10 * (len(group) - 1)),
  2211. True,
  2212. rank_to_GPU,
  2213. )
  2214. @sandcastle_skip_if(
  2215. BACKEND != "gloo" and BACKEND != "nccl",
  2216. "Only Gloo and NCCL backends will have CUDA allReduce tested",
  2217. )
  2218. @skip_if_no_gpu
  2219. def test_all_reduce_sum_cuda_async(self):
  2220. torch.cuda.set_device(self.rank)
  2221. group, group_id, rank = self._init_global_test()
  2222. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2223. self._test_all_reduce_helper(
  2224. group,
  2225. group_id,
  2226. rank,
  2227. dist.ReduceOp.SUM,
  2228. 2,
  2229. 10,
  2230. 2 + (10 * (len(group) - 1)),
  2231. True,
  2232. rank_to_GPU,
  2233. async_op=True,
  2234. )
  2235. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2236. def test_all_reduce_sum_complex(self):
  2237. group, group_id, rank = self._init_global_test()
  2238. self._test_all_reduce_helper(
  2239. group,
  2240. group_id,
  2241. rank,
  2242. dist.ReduceOp.SUM,
  2243. complex(2, 3),
  2244. complex(10, 11),
  2245. complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
  2246. dtype=torch.cfloat,
  2247. )
  2248. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2249. def test_all_reduce_complex_unsupported_ops(self):
  2250. unsupported_ops = [
  2251. dist.ReduceOp.MAX,
  2252. dist.ReduceOp.MIN,
  2253. dist.ReduceOp.PRODUCT,
  2254. dist.ReduceOp.BAND,
  2255. dist.ReduceOp.BOR,
  2256. dist.ReduceOp.BXOR,
  2257. ]
  2258. group, group_id, rank = self._init_global_test()
  2259. for unsupported_op in unsupported_ops:
  2260. with self.assertRaisesRegex(
  2261. RuntimeError, "all_reduce does not support"
  2262. ):
  2263. dist.all_reduce(
  2264. _build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
  2265. )
  2266. @sandcastle_skip_if(
  2267. BACKEND != "gloo" and BACKEND != "nccl",
  2268. "Only Gloo and NCCL backends will have CUDA allReduce tested",
  2269. )
  2270. @skip_if_no_gpu
  2271. def test_all_reduce_sum_cuda_complex(self):
  2272. torch.cuda.set_device(self.rank)
  2273. group, group_id, rank = self._init_global_test()
  2274. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2275. self._test_all_reduce_helper(
  2276. group,
  2277. group_id,
  2278. rank,
  2279. dist.ReduceOp.SUM,
  2280. complex(2, 3),
  2281. complex(10, 11),
  2282. complex(2, 3) + (complex(10, 11) * (len(group) - 1)),
  2283. True,
  2284. rank_to_GPU,
  2285. dtype=torch.cfloat,
  2286. )
  2287. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2288. def test_all_reduce_product(self):
  2289. group, group_id, rank = self._init_global_test()
  2290. self._test_all_reduce_helper(
  2291. group,
  2292. group_id,
  2293. rank,
  2294. dist.ReduceOp.PRODUCT,
  2295. 2,
  2296. 10,
  2297. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  2298. )
  2299. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2300. def test_all_reduce_min(self):
  2301. group, group_id, rank = self._init_global_test()
  2302. self._test_all_reduce_helper(
  2303. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  2304. )
  2305. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2306. def test_all_reduce_max(self):
  2307. group, group_id, rank = self._init_global_test()
  2308. self._test_all_reduce_helper(
  2309. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  2310. )
  2311. @skip_if_small_worldsize
  2312. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2313. def test_all_reduce_group_sum(self):
  2314. group, group_id, rank = self._init_group_test()
  2315. self._test_all_reduce_helper(
  2316. group,
  2317. group_id,
  2318. rank,
  2319. dist.ReduceOp.SUM,
  2320. 2,
  2321. 10,
  2322. 2 + (10 * (len(group) - 1)),
  2323. )
  2324. @skip_if_small_worldsize
  2325. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2326. def test_all_reduce_group_product(self):
  2327. group, group_id, rank = self._init_group_test()
  2328. self._test_all_reduce_helper(
  2329. group,
  2330. group_id,
  2331. rank,
  2332. dist.ReduceOp.PRODUCT,
  2333. 2,
  2334. 10,
  2335. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  2336. )
  2337. @skip_if_small_worldsize
  2338. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2339. def test_all_reduce_group_min(self):
  2340. group, group_id, rank = self._init_group_test()
  2341. self._test_all_reduce_helper(
  2342. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  2343. )
  2344. @skip_if_small_worldsize
  2345. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2346. def test_all_reduce_group_max(self):
  2347. group, group_id, rank = self._init_group_test()
  2348. self._test_all_reduce_helper(
  2349. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  2350. )
  2351. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2352. def test_all_reduce_full_group_sum(self):
  2353. group, group_id, rank = self._init_full_group_test()
  2354. self._test_all_reduce_helper(
  2355. group,
  2356. group_id,
  2357. rank,
  2358. dist.ReduceOp.SUM,
  2359. 2,
  2360. 10,
  2361. 2 + (10 * (len(group) - 1)),
  2362. )
  2363. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2364. def test_all_reduce_full_group_product(self):
  2365. group, group_id, rank = self._init_full_group_test()
  2366. self._test_all_reduce_helper(
  2367. group,
  2368. group_id,
  2369. rank,
  2370. dist.ReduceOp.PRODUCT,
  2371. 2,
  2372. 10,
  2373. reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
  2374. )
  2375. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2376. def test_all_reduce_full_group_min(self):
  2377. group, group_id, rank = self._init_full_group_test()
  2378. self._test_all_reduce_helper(
  2379. group, group_id, rank, dist.ReduceOp.MIN, 1010, 1, 1
  2380. )
  2381. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2382. def test_all_reduce_full_group_max(self):
  2383. group, group_id, rank = self._init_full_group_test()
  2384. self._test_all_reduce_helper(
  2385. group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
  2386. )
  2387. # SPARSE ALL REDUCE
  2388. def _test_sparse_all_reduce_sum(self, fn):
  2389. group, group_id, rank = self._init_global_test()
  2390. tests = simple_sparse_reduce_tests(
  2391. rank, dist.get_world_size(), num_inputs=1
  2392. )
  2393. for (inputs, outputs) in tests:
  2394. tensors = [fn(input) for input in inputs]
  2395. dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
  2396. self.assertEqual(tensors[0], outputs[0])
  2397. @sandcastle_skip_if(
  2398. BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
  2399. )
  2400. def test_sparse_all_reduce_sum(self):
  2401. self._test_sparse_all_reduce_sum(lambda t: t)
  2402. @sandcastle_skip_if(
  2403. BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
  2404. )
  2405. @skip_if_no_gpu
  2406. def test_sparse_all_reduce_sum_cuda(self):
  2407. self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
  2408. # ALL REDUCE - COALESCED
  2409. @staticmethod
  2410. def _all_reduce_coalesced_sum_test_cases(group_size):
  2411. return (
  2412. [2, 3, complex(2, 3)],
  2413. [10, 11, complex(10, 11)],
  2414. [
  2415. 2 + 10 * (group_size - 1),
  2416. 3 + 11 * (group_size - 1),
  2417. complex(2, 3) + complex(10, 11) * (group_size - 1),
  2418. ],
  2419. [torch.float, torch.float, torch.cfloat],
  2420. )
  2421. @staticmethod
  2422. def _all_reduce_coalesced_product_test_cases(group_size):
  2423. return (
  2424. [1, 2],
  2425. [3, 4],
  2426. [1 * 3 ** (group_size - 1), 2 * 4 ** (group_size - 1)],
  2427. [torch.float, torch.float],
  2428. )
  2429. @staticmethod
  2430. def _all_reduce_coalesced_min_test_cases(group_size):
  2431. return (
  2432. [1, 4],
  2433. [2, 3],
  2434. [1, 3],
  2435. [torch.float, torch.float],
  2436. )
  2437. @staticmethod
  2438. def _all_reduce_coalesced_max_test_cases(group_size):
  2439. return (
  2440. [1, 4],
  2441. [2, 3],
  2442. [2, 4],
  2443. [torch.float, torch.float],
  2444. )
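# Editorial note: each *_test_cases staticmethod above returns a 4-tuple of
# parallel lists (master_values, worker_values, expected_values, dtypes), one
# entry per tensor in the coalesced call. For example, with group_size == 2
# the SUM cases expect 2 + 10 == 12, 3 + 11 == 14 and
# complex(2, 3) + complex(10, 11) == complex(12, 14) for the two float
# tensors and the cfloat tensor respectively.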
  2445. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2446. def test_all_reduce_coalesced_max_complex_unsupported(self):
  2447. group, group_id, rank = self._init_global_test()
  2448. with self.assertRaisesRegex(RuntimeError, "all_reduce does not support"):
  2449. dist.all_reduce_coalesced(
  2450. [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id
  2451. )
  2452. def _test_all_reduce_coalesced_helper(
  2453. self,
  2454. group,
  2455. group_id,
  2456. rank,
  2457. op,
  2458. cuda=False,
  2459. rank_to_GPU=None,
  2460. ):
  2461. test_case_func = {
  2462. dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
  2463. dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
  2464. dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
  2465. dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
  2466. }[op]
  2467. master_values, worker_values, expected_values, dtypes = test_case_func(
  2468. len(group)
  2469. )
  2470. for src in group:
  2471. curr_values = master_values if rank == src else worker_values
  2472. tensors = [
  2473. _build_tensor(src + 1, val, dtype=dtype)
  2474. for dtype, val in zip(dtypes, curr_values)
  2475. ]
  2476. if cuda:
  2477. tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
  2478. tensor_shapes = []
  2479. for tensor in tensors:
  2480. if tensor.dtype == torch.complex64:
  2481. tensor_shapes.append(torch.view_as_real(tensor).shape)
  2482. else:
  2483. tensor_shapes.append(tensor.shape)
  2484. self.call_dist_op(
  2485. ":all_reduce",
  2486. False,
  2487. dist.all_reduce_coalesced,
  2488. tensors,
  2489. op,
  2490. group_id,
  2491. tensor_shapes=tensor_shapes,
  2492. )
  2493. expected_tensors = [
  2494. _build_tensor(src + 1, expected_value, dtype=dtype)
  2495. for dtype, expected_value in zip(dtypes, expected_values)
  2496. ]
  2497. self.assertEqual(tensors, expected_tensors)
  2498. self._barrier()
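# Editorial note: dist.all_reduce_coalesced reduces a list of tensors (possibly
# of different shapes and dtypes) in one flattened operation and writes the
# results back in place. A minimal sketch (hedged, assuming an initialized
# gloo group):
#
#     ts = [torch.ones(2), torch.full((3,), 2.0)]
#     dist.all_reduce_coalesced(ts, op=dist.ReduceOp.SUM)
#     # element-wise: ts[0] == world_size, ts[1] == 2 * world_size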
  2499. @require_backend({"gloo"})
  2500. def test_all_reduce_coalesced_sum(self):
  2501. group, group_id, rank = self._init_global_test()
  2502. self._test_all_reduce_coalesced_helper(
  2503. group,
  2504. group_id,
  2505. rank,
  2506. dist.ReduceOp.SUM,
  2507. cuda=False,
  2508. rank_to_GPU=None,
  2509. )
  2510. @require_backend({"gloo"})
  2511. def test_all_reduce_coalesced_product(self):
  2512. group, group_id, rank = self._init_global_test()
  2513. self._test_all_reduce_coalesced_helper(
  2514. group,
  2515. group_id,
  2516. rank,
  2517. dist.ReduceOp.PRODUCT,
  2518. cuda=False,
  2519. rank_to_GPU=None,
  2520. )
  2521. @require_backend({"gloo"})
  2522. def test_all_reduce_coalesced_min(self):
  2523. group, group_id, rank = self._init_global_test()
  2524. self._test_all_reduce_coalesced_helper(
  2525. group,
  2526. group_id,
  2527. rank,
  2528. dist.ReduceOp.MIN,
  2529. cuda=False,
  2530. rank_to_GPU=None,
  2531. )
  2532. @require_backend({"gloo"})
  2533. def test_all_reduce_coalesced_max(self):
  2534. group, group_id, rank = self._init_global_test()
  2535. self._test_all_reduce_coalesced_helper(
  2536. group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
  2537. )
  2538. @skip_if_small_worldsize
  2539. @require_backend({"gloo"})
  2540. def test_all_reduce_coalesced_group_sum(self):
  2541. group, group_id, rank = self._init_group_test()
  2542. self._test_all_reduce_coalesced_helper(
  2543. group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
  2544. )
  2545. @skip_if_small_worldsize
  2546. @require_backend({"gloo"})
  2547. def test_all_reduce_coalesced_group_product(self):
  2548. group, group_id, rank = self._init_group_test()
  2549. self._test_all_reduce_coalesced_helper(
  2550. group,
  2551. group_id,
  2552. rank,
  2553. dist.ReduceOp.PRODUCT,
  2554. cuda=False,
  2555. rank_to_GPU=None,
  2556. )
  2557. @skip_if_small_worldsize
  2558. @require_backend({"gloo"})
  2559. def test_all_reduce_coalesced_group_min(self):
  2560. group, group_id, rank = self._init_group_test()
  2561. self._test_all_reduce_coalesced_helper(
  2562. group, group_id, rank, dist.ReduceOp.MIN, cuda=False, rank_to_GPU=None
  2563. )
  2564. @skip_if_small_worldsize
  2565. @require_backend({"gloo"})
  2566. def test_all_reduce_coalesced_group_max(self):
  2567. group, group_id, rank = self._init_group_test()
  2568. self._test_all_reduce_coalesced_helper(
  2569. group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
  2570. )
  2571. @require_backend({"gloo"})
  2572. def test_all_reduce_coalesced_full_group_sum(self):
  2573. group, group_id, rank = self._init_full_group_test()
  2574. self._test_all_reduce_coalesced_helper(
  2575. group, group_id, rank, dist.ReduceOp.SUM, cuda=False, rank_to_GPU=None
  2576. )
  2577. @require_backend({"gloo"})
  2578. def test_all_reduce_coalesced_full_group_product(self):
  2579. group, group_id, rank = self._init_full_group_test()
  2580. self._test_all_reduce_coalesced_helper(
  2581. group,
  2582. group_id,
  2583. rank,
  2584. dist.ReduceOp.PRODUCT,
  2585. cuda=False,
  2586. rank_to_GPU=None,
  2587. )
  2588. @require_backend({"gloo"})
  2589. def test_all_reduce_coalesced_full_group_min(self):
  2590. group, group_id, rank = self._init_full_group_test()
  2591. self._test_all_reduce_coalesced_helper(
  2592. group,
  2593. group_id,
  2594. rank,
  2595. dist.ReduceOp.MIN,
  2596. cuda=False,
  2597. rank_to_GPU=None,
  2598. )
  2599. @require_backend({"gloo"})
  2600. def test_all_reduce_coalesced_full_group_max(self):
  2601. group, group_id, rank = self._init_full_group_test()
  2602. self._test_all_reduce_coalesced_helper(
  2603. group, group_id, rank, dist.ReduceOp.MAX, cuda=False, rank_to_GPU=None
  2604. )
  2605. # SCATTER
  2606. def _test_scatter_helper(
  2607. self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
  2608. ):
  2609. for dest in group:
  2610. tensor = _build_tensor(dest + 1, -1, dtype=dtype)
  2611. expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
  2612. tensors = (
  2613. [_build_tensor(dest + 1, i, dtype=dtype) for i in group]
  2614. if rank == dest
  2615. else []
  2616. )
  2617. if cuda:
  2618. tensor = tensor.cuda(rank_to_GPU[rank][0])
  2619. tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
  2620. if dtype == torch.complex64:
  2621. tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
  2622. else:
  2623. tensor_shapes = [t.shape for t in tensors]
  2624. self.call_dist_op(
  2625. ":scatter",
  2626. False,
  2627. dist.scatter,
  2628. tensor,
  2629. src=dest,
  2630. scatter_list=tensors,
  2631. group=group_id,
  2632. expect_event=False,
  2633. tensor_shapes=tensor_shapes,
  2634. )
  2635. self.assertEqual(tensor, expected_tensor)
  2636. self._barrier()
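# Editorial note: in dist.scatter only the src rank supplies scatter_list, and
# every rank (src included) receives scatter_list[rank] into its output tensor.
# A minimal sketch (hedged, assuming an initialized default process group):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     out = torch.zeros(1)
#     chunks = [torch.full((1,), float(i)) for i in range(world_size)] if rank == 0 else None
#     dist.scatter(out, scatter_list=chunks, src=0)
#     # out == rank on every rank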
  2637. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2638. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2639. def test_scatter_checks(self):
  2640. group, group_id, rank = self._init_global_test()
  2641. one = torch.ones([1])
  2642. # Specify scatter_list argument only on source rank.
  2643. output = one.clone() * -1
  2644. if rank == 0:
  2645. scatter_list = [one.clone() * i for i in group]
  2646. dist.scatter(output, src=0, scatter_list=scatter_list)
  2647. else:
  2648. dist.scatter(output, src=0)
  2649. self.assertEqual(output, one * rank)
  2650. # Don't specify src argument.
  2651. output = one.clone() * -1
  2652. if rank == 0:
  2653. scatter_list = [one.clone() * i for i in group]
  2654. dist.scatter(output, scatter_list=scatter_list)
  2655. else:
  2656. dist.scatter(output)
  2657. self.assertEqual(output, one * rank)
  2658. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2659. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2660. def test_scatter(self):
  2661. group, group_id, rank = self._init_global_test()
  2662. self._test_scatter_helper(group, group_id, rank)
2663. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA scatter")
  2664. @skip_if_no_gpu
  2665. def test_scatter_cuda(self):
  2666. group, group_id, rank = self._init_global_test()
  2667. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2668. self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
  2669. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2670. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2671. def test_scatter_complex(self):
  2672. group, group_id, rank = self._init_global_test()
  2673. self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
2674. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA scatter")
  2675. @skip_if_no_gpu
  2676. def test_scatter_cuda_complex(self):
  2677. group, group_id, rank = self._init_global_test()
  2678. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2679. self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
  2680. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2681. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2682. @skip_if_small_worldsize
  2683. def test_scatter_group(self):
  2684. group, group_id, rank = self._init_group_test()
  2685. self._test_scatter_helper(group, group_id, rank)
  2686. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2687. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2688. def test_scatter_full_group(self):
  2689. group, group_id, rank = self._init_full_group_test()
  2690. self._test_scatter_helper(group, group_id, rank)
  2691. # GATHER
  2692. def _test_gather_helper(self, group, group_id, rank, cuda=False, rank_to_GPU=None):
  2693. for dest in group:
  2694. tensor = _build_tensor(dest + 1, rank)
  2695. tensors = (
  2696. [_build_tensor(dest + 1, -1) for i in group] if rank == dest else []
  2697. )
  2698. if cuda:
  2699. tensor = tensor.cuda(rank_to_GPU[rank][0])
  2700. tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
  2701. self.call_dist_op(
  2702. ":gather",
  2703. False,
  2704. dist.gather,
  2705. tensor,
  2706. dst=dest,
  2707. gather_list=tensors,
  2708. group=group_id,
  2709. expect_event=False,
  2710. tensor_shapes=[tensors[0].shape] if len(tensors) > 0 else None,
  2711. )
  2712. if rank == dest:
  2713. expected_tensors = [_build_tensor(dest + 1, i) for i in group]
  2714. for t1, t2 in zip(tensors, expected_tensors):
  2715. self.assertEqual(t1, t2)
  2716. self._barrier()
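# Editorial note: dist.gather is the inverse of scatter: every rank sends its
# tensor and only the dst rank supplies (and gets filled) gather_list. A
# minimal sketch (hedged, assuming an initialized default process group):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     t = torch.full((1,), float(rank))
#     bucket = [torch.zeros(1) for _ in range(world_size)] if rank == 0 else None
#     dist.gather(t, gather_list=bucket, dst=0)
#     # on rank 0, bucket[i] == i for every i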
  2717. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2718. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2719. def test_gather_checks(self):
  2720. group, group_id, rank = self._init_global_test()
  2721. one = torch.ones([1])
  2722. # Specify gather_list argument only on destination rank.
  2723. if rank == 0:
  2724. gather_list = [one.clone() for _ in group]
  2725. dist.gather(one * rank, dst=0, gather_list=gather_list)
  2726. for i in group:
  2727. self.assertEqual(gather_list[i], one * i)
  2728. else:
  2729. dist.gather(one * rank, dst=0)
  2730. # Don't specify dst argument.
  2731. if rank == 0:
  2732. gather_list = [one.clone() for _ in group]
  2733. dist.gather(one * rank, gather_list=gather_list)
  2734. for i in group:
  2735. self.assertEqual(gather_list[i], one * i)
  2736. else:
  2737. dist.gather(one * rank)
  2738. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2739. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2740. def test_gather(self):
  2741. group, group_id, rank = self._init_global_test()
  2742. self._test_gather_helper(group, group_id, rank)
  2743. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA gather")
  2744. @skip_if_no_gpu
  2745. def test_gather_cuda(self):
  2746. group, group_id, rank = self._init_global_test()
  2747. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2748. self._test_gather_helper(group, group_id, rank, True, rank_to_GPU)
  2749. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2750. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2751. @skip_if_small_worldsize
  2752. def test_gather_group(self):
  2753. group, group_id, rank = self._init_group_test()
  2754. self._test_gather_helper(group, group_id, rank)
  2755. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2756. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  2757. def test_gather_full_group(self):
  2758. group, group_id, rank = self._init_full_group_test()
  2759. self._test_gather_helper(group, group_id, rank)
  2760. # ALL GATHER
  2761. def _test_all_gather_helper(
  2762. self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
  2763. ):
  2764. for dest in group:
  2765. tensor = _build_tensor(dest + 1, rank, dtype=dtype)
  2766. tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
  2767. allgather = dist.all_gather
  2768. if cuda:
  2769. tensor = tensor.cuda(rank_to_GPU[rank][0])
  2770. tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
  2771. if tensors[0].dtype == torch.complex64:
  2772. tensor_shapes = [torch.view_as_real(tensors[0]).shape]
  2773. else:
  2774. tensor_shapes = [tensors[0].shape]
  2775. self.call_dist_op(
  2776. ":all_gather",
  2777. False,
  2778. allgather,
  2779. tensors,
  2780. tensor,
  2781. group_id,
  2782. False,
  2783. tensor_shapes=tensor_shapes,
  2784. )
  2785. expected_tensors = [
  2786. _build_tensor(dest + 1, i, dtype=dtype) for i in group
  2787. ]
  2788. for t1, t2 in zip(tensors, expected_tensors):
  2789. self.assertEqual(t1, t2)
  2790. self._barrier()
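# Editorial note: dist.all_gather fills tensor_list[i] on every rank with the
# tensor contributed by rank i. A minimal sketch (hedged, assuming an
# initialized default process group):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     gathered = [torch.zeros(1) for _ in range(world_size)]
#     dist.all_gather(gathered, torch.full((1,), float(rank)))
#     # gathered[i] == i on every rank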
  2791. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2792. def test_all_gather(self):
  2793. group, group_id, rank = self._init_global_test()
  2794. self._test_all_gather_helper(group, group_id, rank)
  2795. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
  2796. @skip_if_no_gpu
  2797. def test_all_gather_cuda(self):
  2798. group, group_id, rank = self._init_global_test()
  2799. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2800. self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)
  2801. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2802. def test_all_gather_complex(self):
  2803. group, group_id, rank = self._init_global_test()
  2804. self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
  2805. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
  2806. @skip_if_no_gpu
  2807. def test_all_gather_cuda_complex(self):
  2808. group, group_id, rank = self._init_global_test()
  2809. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2810. self._test_all_gather_helper(
  2811. group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
  2812. )
  2813. @skip_if_small_worldsize
  2814. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2815. def test_all_gather_group(self):
  2816. group, group_id, rank = self._init_group_test()
  2817. self._test_all_gather_helper(group, group_id, rank)
  2818. @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
  2819. def test_all_gather_full_group(self):
  2820. group, group_id, rank = self._init_full_group_test()
  2821. self._test_all_gather_helper(group, group_id, rank)
  2822. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports all_gather_v")
  2823. @skip_if_no_gpu
  2824. def test_all_gather_v_cuda(self):
  2825. self._barrier()
  2826. group, group_id, rank = self._init_global_test()
  2827. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2828. device_id = rank_to_GPU[rank][0]
  2829. output_split_sizes = []
  2830. for dst in group:
  2831. output_split_sizes.append(dst + 1)
  2832. sum_len = sum(output_split_sizes)
  2833. value = 2
  2834. for async_val in [True, False]:
  2835. tensor = torch.empty(output_split_sizes[rank], sum_len, sum_len, dtype=torch.float).fill_(value).cuda(device_id)
  2836. out_tensor = _build_tensor(sum_len, -1, device_id=device_id)
  2837. req = dist.all_gather(
  2838. list(torch.split(out_tensor, output_split_sizes)),
  2839. tensor,
  2840. group_id,
  2841. async_val,
  2842. )
  2843. if async_val:
  2844. req.wait()
  2845. expected_value = value
  2846. expected_tensor = _build_tensor(sum_len, expected_value, device_id=device_id)
  2847. self.assertEqual(out_tensor, expected_tensor)
  2848. self._barrier()
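# Editorial note: "all_gather_v" above means all_gather into unevenly sized
# output chunks, expressed by passing the views returned by torch.split as the
# output list; rank r contributes r + 1 rows and every rank reassembles the
# full (sum_len, sum_len, sum_len) tensor from those uneven slices.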
2849. # Test all_gather accepting a single tensor as output
  2850. def _all_gather_into_tensor_helper(
  2851. self, tensor_out, tensor_in,
  2852. group_id, rank, cuda=True, rank_to_GPU=None
  2853. ):
  2854. if cuda:
  2855. tensor_in = tensor_in.cuda(rank_to_GPU[rank][0])
  2856. tensor_out = tensor_out.cuda(rank_to_GPU[rank][0])
  2857. if tensor_out.dtype == torch.complex64:
  2858. tensor_shapes = [torch.view_as_real(tensor_in).shape]
  2859. else:
  2860. tensor_shapes = [tensor_in.shape]
  2861. self.call_dist_op(
  2862. ":all_gather_into_tensor",
  2863. False,
  2864. dist.all_gather_into_tensor,
  2865. tensor_out,
  2866. tensor_in,
  2867. group_id,
  2868. False,
  2869. expect_event=False,
  2870. tensor_shapes=tensor_shapes,
  2871. )
  2872. return tensor_out
  2873. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor")
  2874. @skip_if_no_gpu
  2875. def test_all_gather_into_cat_tensor_cuda(self):
  2876. group, group_id, rank = self._init_global_test()
  2877. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2878. size = 2
  2879. tensor_in = torch.ones([size, size]) * rank
  2880. # Concatenated output
  2881. tensor_out = torch.ones([len(group) * size, size]) * (-1)
  2882. tensor_out = self._all_gather_into_tensor_helper(tensor_out, tensor_in, group_id, rank, True, rank_to_GPU)
  2883. # Check result
  2884. # Concatenate all blocks into a bigger tensor
  2885. expected_tensor = torch.cat([
  2886. torch.ones([size, size]) * i for i in group
  2887. ])
  2888. self.assertEqual(tensor_out, expected_tensor)
  2889. self._barrier()
  2890. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor")
  2891. @skip_if_no_gpu
  2892. def test_all_gather_into_stack_tensor_cuda(self):
  2893. group, group_id, rank = self._init_global_test()
  2894. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  2895. size = 2
  2896. tensor_in = torch.ones([size, size]) * rank
  2897. # Stacked output
  2898. tensor_out = torch.ones([len(group), size, size]) * (-1)
  2899. tensor_out = self._all_gather_into_tensor_helper(tensor_out, tensor_in, group_id, rank, True, rank_to_GPU)
  2900. # Check result
  2901. # Stack all blocks into a bigger tensor
  2902. expected_tensor = torch.stack([
  2903. torch.ones([size, size]) * i for i in group
  2904. ])
  2905. self.assertEqual(tensor_out, expected_tensor)
  2906. self._barrier()
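# Editorial note: dist.all_gather_into_tensor accepts either a concatenated
# output of shape (world_size * n, ...) or a stacked output of shape
# (world_size, n, ...), which is exactly what the two tests above exercise.
# A minimal sketch (hedged, assuming an initialized NCCL group and a current
# CUDA device):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     inp = torch.full((2, 2), float(rank), device="cuda")
#     out = torch.empty(world_size * 2, 2, device="cuda")
#     dist.all_gather_into_tensor(out, inp)
#     # rows [2 * i : 2 * i + 2] of out hold the value i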
  2907. def _run_all_gather_coalesced_and_verify(
  2908. self, output_tensor_lists, input_tensors, expected_tensors, group_id
  2909. ):
  2910. """
2911. Helper that runs all_gather_coalesced and returns True if the output
2912. matches expectations.
  2913. """
  2914. tensor_shapes = []
  2915. for input_tensor in input_tensors:
  2916. if input_tensor.dtype == torch.complex64:
  2917. tensor_shapes.append(torch.view_as_real(input_tensor).shape)
  2918. else:
  2919. tensor_shapes.append(input_tensor.shape)
  2920. self.call_dist_op(
  2921. ":all_gather",
  2922. False,
  2923. dist.all_gather_coalesced,
  2924. output_tensor_lists,
  2925. input_tensors,
  2926. group_id,
  2927. tensor_shapes=tensor_shapes,
  2928. )
  2929. for l1, l2 in zip(output_tensor_lists, expected_tensors):
  2930. for t1, t2 in zip(l1, l2):
  2931. if not torch.equal(t1, t2):
  2932. return False
  2933. return True
  2934. def _test_all_gather_coalesced_helper(
  2935. self, group, group_id, rank, dtype=torch.float
  2936. ):
2937. # TODO: we should probably go through the _rank_not_in_group
2938. # mechanism instead to disable sending tensors
  2939. if group_id is not None:
  2940. for test_case_id in range(2, 5):
  2941. # Make sure we create tensors of incompatible sizes, e.g.
  2942. # [1], [2x2], [3x3x3] ... to be sent in one batch
  2943. input_tensors = [
  2944. _build_multidim_tensor(
  2945. tensor_id, tensor_id, rank + tensor_id, dtype=dtype
  2946. )
  2947. for tensor_id in range(1, test_case_id)
  2948. ]
  2949. output_tensor_lists = [
  2950. [
  2951. _build_multidim_tensor(
  2952. tensor_id, tensor_id, -1, dtype=dtype
  2953. )
  2954. for tensor_id in range(1, test_case_id)
  2955. ]
  2956. for _ in group
  2957. ]
  2958. expected_tensors = [
  2959. [
  2960. _build_multidim_tensor(
  2961. tensor_id, tensor_id, rank_iter + tensor_id, dtype=dtype
  2962. )
  2963. for tensor_id in range(1, test_case_id)
  2964. ]
  2965. for rank_iter in group
  2966. ]
  2967. assert self._run_all_gather_coalesced_and_verify(
  2968. output_tensor_lists, input_tensors, expected_tensors, group_id
  2969. ), "output tensors do not match expected ouputs"
  2970. self._barrier()
  2971. @sandcastle_skip_if(
  2972. BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
  2973. f"{BACKEND} does not support all_gather_coalesced"
  2974. )
  2975. def test_all_gather_coalesced_simple(self):
  2976. group, group_id, rank = self._init_global_test()
  2977. self._test_all_gather_coalesced_helper(group, group_id, rank)
  2978. @sandcastle_skip_if(
  2979. BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
  2980. f"{BACKEND} does not support all_gather_coalesced"
  2981. )
  2982. def test_all_gather_coalesced_complex(self):
  2983. group, group_id, rank = self._init_global_test()
  2984. self._test_all_gather_coalesced_helper(
  2985. group, group_id, rank, dtype=torch.cfloat
  2986. )
  2987. @skip_if_small_worldsize
  2988. @sandcastle_skip_if(
  2989. BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
  2990. f"{BACKEND} does not support all_gather_coalesced"
  2991. )
  2992. def test_all_gather_coalesced_group(self):
  2993. group, group_id, rank = self._init_group_test()
  2994. self._test_all_gather_coalesced_helper(group, group_id, rank)
  2995. @sandcastle_skip_if(
  2996. BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
  2997. f"{BACKEND} does not support all_gather_coalesced"
  2998. )
  2999. def test_all_gather_coalesced_full_group(self):
  3000. group, group_id, rank = self._init_full_group_test()
  3001. self._test_all_gather_coalesced_helper(group, group_id, rank)
  3002. @sandcastle_skip_if(
  3003. BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
  3004. f"{BACKEND} does not support all_gather_coalesced"
  3005. )
  3006. def test_all_gather_coalesced_with_empty(self):
  3007. group, group_id, rank = self._init_global_test()
  3008. input_tensors = [
  3009. rank * torch.ones([2, 2]),
  3010. torch.ones([0]),
  3011. (rank + 1) * torch.ones([3, 3]),
  3012. torch.ones([0]),
  3013. torch.ones([0]),
  3014. ]
  3015. output_tensors_lists = [
  3016. [
  3017. -1 * torch.ones([2, 2]),
  3018. -1 * torch.ones([0]),
  3019. -1 * torch.ones([3, 3]),
  3020. -1 * torch.ones([0]),
  3021. -1 * torch.ones([0]),
  3022. ]
  3023. for _ in group
  3024. ]
  3025. expected_tensors = [
  3026. [
  3027. r * torch.ones([2, 2]),
  3028. torch.ones([0]),
  3029. (r + 1) * torch.ones([3, 3]),
  3030. torch.ones([0]),
  3031. torch.ones([0]),
  3032. ]
  3033. for r in group
  3034. ]
  3035. assert self._run_all_gather_coalesced_and_verify(
  3036. output_tensors_lists, input_tensors, expected_tensors, group_id
  3037. )
  3038. self._barrier()
  3039. # AllToAll
  3040. def _test_all_to_all_single_equal_split_helper(
  3041. self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
  3042. ):
  3043. if group_id is not None:
  3044. size = len(group)
  3045. in_tensor = torch.ones([size, size], dtype=dtype) * rank
  3046. expected_tensor = torch.cat(
  3047. [torch.ones([1, size], dtype=dtype) * i for i in group]
  3048. )
  3049. out_tensor = torch.ones([size, size], dtype=dtype) * -1
  3050. if cuda:
  3051. in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
  3052. expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
  3053. out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
  3054. if dtype == torch.complex64:
  3055. tensor_shapes = [torch.view_as_real(in_tensor).shape]
  3056. else:
  3057. tensor_shapes = [in_tensor.shape]
  3058. self.call_dist_op(
  3059. ":all_to_all",
  3060. False,
  3061. dist.all_to_all_single,
  3062. out_tensor,
  3063. in_tensor,
  3064. group=group_id,
  3065. tensor_shapes=tensor_shapes,
  3066. )
  3067. self.assertEqual(out_tensor, expected_tensor)
  3068. self._barrier()
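# Editorial worked example for the equal-split helper above, assuming
# world_size == 3: rank r sends a (3, 3) tensor filled with r, row i of that
# input goes to rank i, and every rank ends up with the rows [0, 0, 0],
# [1, 1, 1], [2, 2, 2] stacked, so row j of the output came from rank j.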
  3069. def _test_all_to_all_single_unequal_split_helper(
  3070. self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
  3071. ):
  3072. if group_id is not None:
  3073. size = len(group)
  3074. in_splits = [i + 1 for i in group]
  3075. out_splits = [rank + 1 for _ in group]
  3076. in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
  3077. out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
  3078. expected_tensor = torch.cat(
  3079. [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
  3080. )
  3081. if cuda:
  3082. in_tensor = in_tensor.cuda(rank_to_GPU[rank][0])
  3083. expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
  3084. out_tensor = out_tensor.cuda(rank_to_GPU[rank][0])
  3085. dist.all_to_all_single(
  3086. out_tensor, in_tensor, out_splits, in_splits, group=group_id
  3087. )
  3088. self.assertEqual(out_tensor, expected_tensor)
  3089. self._barrier()
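# Editorial note: in the unequal-split variant, in_splits[i] rows of this
# rank's input go to rank i and out_splits[i] rows arrive back from rank i, so
# the output has sum(out_splits) == (rank + 1) * world_size rows. A minimal
# sketch (hedged, assuming an initialized backend that supports
# all_to_all_single; tensors must live on CUDA for NCCL):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     in_splits = [i + 1 for i in range(world_size)]
#     out_splits = [rank + 1] * world_size
#     inp = torch.full((sum(in_splits),), float(rank))
#     out = torch.empty(sum(out_splits))
#     dist.all_to_all_single(out, inp, out_splits, in_splits)
#     # out holds (rank + 1) copies of 0, then of 1, ..., one block per peer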
  3090. def _test_all_to_all_helper(
  3091. self,
  3092. group,
  3093. group_id,
  3094. rank,
  3095. cuda=False,
  3096. rank_to_GPU=None,
  3097. dtype=torch.float,
  3098. ):
  3099. if group_id is not None:
  3100. size = len(group)
  3101. in_splits = [i + 1 for i in group]
  3102. in_tensors = [
  3103. torch.ones([in_splits[i], size], dtype=dtype) * rank
  3104. for i, _ in enumerate(group)
  3105. ]
  3106. out_tensors = [
  3107. torch.ones([(rank + 1), size], dtype=dtype) for _ in group
  3108. ]
  3109. expected_tensors = [
  3110. torch.ones([rank + 1, size], dtype=dtype) * i for i in group
  3111. ]
  3112. if cuda:
  3113. in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
  3114. expected_tensors = [
  3115. t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
  3116. ]
  3117. out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
  3118. dist.all_to_all(out_tensors, in_tensors, group=group_id)
  3119. for t1, t2 in zip(out_tensors, expected_tensors):
  3120. self.assertEqual(t1, t2)
  3121. self._barrier()
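# Editorial note: dist.all_to_all is the list form: output_tensors[i] receives
# the tensor that rank i addressed to this rank, and the per-pair tensors may
# differ in shape. A minimal sketch (hedged, assuming an initialized backend
# that supports all_to_all):
#
#     rank, world_size = dist.get_rank(), dist.get_world_size()
#     ins = [torch.full((i + 1,), float(rank)) for i in range(world_size)]
#     outs = [torch.empty(rank + 1) for _ in range(world_size)]
#     dist.all_to_all(outs, ins)
#     # outs[i] is filled with the value i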
  3122. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3123. def test_all_to_all_single_equal_split(self):
  3124. group, group_id, rank = self._init_global_test()
  3125. self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
  3126. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3127. @skip_if_no_gpu
  3128. def test_all_to_all_single_equal_split_cuda(self):
  3129. group, group_id, rank = self._init_global_test()
  3130. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3131. self._test_all_to_all_single_equal_split_helper(
  3132. group,
  3133. group_id,
  3134. rank,
  3135. True,
  3136. rank_to_GPU,
  3137. )
  3138. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3139. def test_all_to_all_single_equal_split_complex(self):
  3140. group, group_id, rank = self._init_global_test()
  3141. self._test_all_to_all_single_equal_split_helper(
  3142. group, group_id, rank, dtype=torch.cfloat
  3143. )
  3144. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3145. @skip_if_no_gpu
  3146. def test_all_to_all_single_equal_split_cuda_complex(self):
  3147. group, group_id, rank = self._init_global_test()
  3148. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3149. self._test_all_to_all_single_equal_split_helper(
  3150. group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
  3151. )
  3152. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3153. def test_all_to_all_single_unequal_split(self):
  3154. group, group_id, rank = self._init_global_test()
  3155. self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
  3156. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3157. @skip_if_no_gpu
  3158. def test_all_to_all_single_unequal_split_cuda(self):
  3159. group, group_id, rank = self._init_global_test()
  3160. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3161. self._test_all_to_all_single_unequal_split_helper(
  3162. group,
  3163. group_id,
  3164. rank,
  3165. True,
  3166. rank_to_GPU,
  3167. )
  3168. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3169. def test_all_to_all_single_unequal_split_complex(self):
  3170. group, group_id, rank = self._init_global_test()
  3171. self._test_all_to_all_single_unequal_split_helper(
  3172. group, group_id, rank, dtype=torch.cfloat
  3173. )
  3174. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3175. @skip_if_no_gpu
  3176. def test_all_to_all_single_unequal_split_cuda_complex(self):
  3177. group, group_id, rank = self._init_global_test()
  3178. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3179. self._test_all_to_all_single_unequal_split_helper(
  3180. group,
  3181. group_id,
  3182. rank,
  3183. True,
  3184. rank_to_GPU,
  3185. dtype=torch.cfloat,
  3186. )
  3187. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports all_to_all")
  3188. def test_all_to_all(self):
  3189. group, group_id, rank = self._init_global_test()
  3190. self._test_all_to_all_helper(group, group_id, rank)
  3191. @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
  3192. @skip_if_rocm
  3193. def test_all_to_all_cuda(self):
  3194. group, group_id, rank = self._init_global_test()
  3195. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3196. self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
  3197. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports all_to_all")
  3198. def test_all_to_all_complex(self):
  3199. group, group_id, rank = self._init_global_test()
  3200. self._test_all_to_all_helper(group, group_id, rank, dtype=torch.cfloat)
  3201. @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
  3202. @skip_if_rocm
  3203. def test_all_to_all_cuda_complex(self):
  3204. group, group_id, rank = self._init_global_test()
  3205. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3206. self._test_all_to_all_helper(
  3207. group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat
  3208. )
  3209. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3210. @skip_if_small_worldsize
  3211. def test_all_to_all_single_equal_split_group(self):
  3212. group, group_id, rank = self._init_group_test()
  3213. self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
  3214. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3215. @skip_if_no_gpu
  3216. @skip_if_small_worldsize
  3217. def test_all_to_all_single_equal_split_group_cuda(self):
  3218. group, group_id, rank = self._init_group_test()
  3219. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3220. self._test_all_to_all_single_equal_split_helper(
  3221. group,
  3222. group_id,
  3223. rank,
  3224. True,
  3225. rank_to_GPU,
  3226. )
  3227. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3228. @skip_if_small_worldsize
  3229. def test_all_to_all_single_unequal_split_group(self):
  3230. group, group_id, rank = self._init_group_test()
  3231. self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
  3232. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3233. @skip_if_no_gpu
  3234. @skip_if_small_worldsize
  3235. def test_all_to_all_single_unequal_split_group_cuda(self):
3236. group, group_id, rank = self._init_group_test()
  3237. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3238. self._test_all_to_all_single_unequal_split_helper(
  3239. group,
  3240. group_id,
  3241. rank,
  3242. True,
  3243. rank_to_GPU,
  3244. )
  3245. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports all_to_all")
  3246. @skip_if_small_worldsize
  3247. def test_all_to_all_group(self):
  3248. group, group_id, rank = self._init_group_test()
  3249. self._test_all_to_all_helper(group, group_id, rank)
3250. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all")
  3251. @skip_if_small_worldsize
  3252. @skip_if_rocm
  3253. def test_all_to_all_group_cuda(self):
  3254. group, group_id, rank = self._init_group_test()
  3255. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3256. self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
  3257. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3258. def test_all_to_all_single_equal_split_full_group(self):
  3259. group, group_id, rank = self._init_full_group_test()
  3260. self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
  3261. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3262. @skip_if_no_gpu
  3263. def test_all_to_all_single_equal_split_full_group_cuda(self):
  3264. group, group_id, rank = self._init_full_group_test()
  3265. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3266. self._test_all_to_all_single_equal_split_helper(
  3267. group,
  3268. group_id,
  3269. rank,
  3270. True,
  3271. rank_to_GPU,
  3272. )
  3273. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports CPU all_to_all_single")
  3274. def test_all_to_all_single_unequal_split_full_group(self):
  3275. group, group_id, rank = self._init_full_group_test()
  3276. self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
  3277. @sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single")
  3278. @skip_if_no_gpu
  3279. def test_all_to_all_single_unequal_split_full_group_cuda(self):
  3280. group, group_id, rank = self._init_full_group_test()
  3281. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3282. self._test_all_to_all_single_unequal_split_helper(
  3283. group,
  3284. group_id,
  3285. rank,
  3286. True,
  3287. rank_to_GPU,
  3288. )
  3289. @sandcastle_skip_if(BACKEND != "mpi", "Only MPI supports all_to_all")
  3290. def test_all_to_all_full_group(self):
  3291. group, group_id, rank = self._init_full_group_test()
  3292. self._test_all_to_all_helper(group, group_id, rank)
  3293. @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
  3294. @skip_if_rocm
  3295. def test_all_to_all_full_group_cuda(self):
  3296. group, group_id, rank = self._init_full_group_test()
  3297. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3298. self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
  3299. # BARRIER
  3300. def _test_barrier_helper(
  3301. self, group, group_id, rank, cuda=False, rank_to_GPU=None
  3302. ):
  3303. WAIT_TIME = 0.3 # seconds
  3304. for dest in group:
  3305. expected_time = torch.DoubleTensor(1).fill_(0.0)
  3306. if cuda:
  3307. expected_time = expected_time.cuda(rank_to_GPU[rank][0])
  3308. if dest == rank:
  3309. expected_time.fill_(time.time() + WAIT_TIME)
  3310. dist.broadcast(expected_time, dest, group_id)
  3311. time.sleep(WAIT_TIME + 0.1) # sleep a little bit longer
  3312. dist.barrier(group_id)
  3313. else:
  3314. dist.broadcast(expected_time, dest, group_id)
  3315. dist.barrier(group_id)
  3316. self.assertGreaterAlmostEqual(
  3317. float(time.time()),
  3318. float(expected_time[0]),
  3319. "destination rank: %d, my rank: %d" % (dest, rank)
  3320. + " (if you see this failure, please report in #14554)",
  3321. )
  3322. # Use higher timeout for the instance where the test runs
  3323. # against a subgroup and uses a CUDA tensor for expected time.
  3324. # The CUDA initialization for the participating processes can
  3325. # take long enough for the barrier timeout to trigger on the
  3326. # process that doesn't participate in the group.
  3327. self._barrier(timeout=20)
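# Illustrative note on the timing check above: the source rank broadcasts
# expected_time = t0 + WAIT_TIME and then sleeps WAIT_TIME + 0.1s before entering
# the barrier, so the other ranks can only leave dist.barrier() after roughly
# t0 + 0.4s. The assertion time.time() >= expected_time therefore only passes if
# the barrier really blocked until the slow rank arrived.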
  3328. @skip_if_no_gpu
3329. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support GPU barrier")
  3330. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  3331. def test_barrier_cuda(self):
  3332. group, group_id, rank = self._init_global_test()
  3333. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3334. self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
  3335. @skip_if_small_worldsize
  3336. @skip_if_no_gpu
3337. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support GPU barrier")
  3338. def test_barrier_group_cuda(self):
  3339. group, group_id, rank = self._init_group_test()
  3340. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3341. self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
  3342. @skip_if_small_worldsize
  3343. @skip_if_no_gpu
3344. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support GPU barrier")
  3345. def test_barrier_full_group_cuda(self):
  3346. group, group_id, rank = self._init_full_group_test()
  3347. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3348. self._test_barrier_helper(group, group_id, rank, True, rank_to_GPU)
  3349. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["cpu barrier"], f"{BACKEND} does not support CPU barrier")
  3350. def test_barrier(self):
  3351. group, group_id, rank = self._init_global_test()
  3352. self._test_barrier_helper(group, group_id, rank)
  3353. @skip_if_small_worldsize
  3354. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["cpu barrier"], f"{BACKEND} does not support CPU barrier")
  3355. def test_barrier_group(self):
  3356. group, group_id, rank = self._init_group_test()
  3357. self._test_barrier_helper(group, group_id, rank)
  3358. @sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["cpu barrier"], f"{BACKEND} does not support CPU barrier")
  3359. def test_barrier_full_group(self):
  3360. group, group_id, rank = self._init_full_group_test()
  3361. self._test_barrier_helper(group, group_id, rank)
  3362. def _test_broadcast_multigpu_helper(self, group, group_id, rank, rank_to_GPU):
  3363. for src in group:
  3364. expected_tensor = _build_tensor(src + 1)
  3365. tensors = [
  3366. _build_tensor(src + 1, -1).cuda(device=i) for i in rank_to_GPU[rank]
  3367. ]
  3368. if rank == src:
  3369. tensors[0] = expected_tensor.cuda(device=rank_to_GPU[rank][0])
  3370. dist.broadcast_multigpu(tensors, src, group_id)
  3371. for tensor in tensors:
  3372. self.assertEqual(tensor, expected_tensor)
  3373. self._barrier()
  3374. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support broadcast multigpu")
  3375. @sandcastle_skip_if(BACKEND == "nccl", "NCCL broadcast multigpu skipped")
  3376. @skip_if_no_gpu
  3377. def test_broadcast_multigpu(self):
  3378. group, group_id, rank = self._init_global_test()
  3379. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3380. self._test_broadcast_multigpu_helper(group, group_id, rank, rank_to_GPU)
  3381. def _test_all_reduce_multigpu_helper(
  3382. self,
  3383. group,
  3384. group_id,
  3385. rank,
  3386. rank_to_GPU,
  3387. op,
  3388. master_value,
  3389. worker_value,
  3390. expected_value,
  3391. dtype=torch.float,
  3392. ):
  3393. for src in group:
  3394. curr_value = master_value if rank == src else worker_value
  3395. tensors = [
  3396. _build_tensor(src + 1, curr_value, dtype=dtype).cuda(device=i)
  3397. for i in rank_to_GPU[rank]
  3398. ]
  3399. self.call_dist_op(
  3400. ":all_reduce",
  3401. False,
  3402. dist.all_reduce_multigpu,
  3403. tensors,
  3404. op,
  3405. group_id,
  3406. )
  3407. expected_tensor = _build_tensor(src + 1, expected_value, dtype=dtype)
  3408. for tensor in tensors:
  3409. self.assertEqual(tensor, expected_tensor)
  3410. self._barrier()
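# Worked example of the expected value used by the callers below (illustrative):
# with ReduceOp.SUM every rank contributes its value once per local GPU, so the
# result is (master_value + worker_value * (num_ranks - 1)) * gpus_per_rank,
# e.g. 2 ranks with 2 GPUs each and values 2/10 give (2 + 10 * 1) * 2 == 24.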
3411. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support all_reduce multigpu")
  3412. @sandcastle_skip_if(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL")
  3413. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  3414. @skip_if_no_gpu
  3415. def test_all_reduce_multigpu(self):
  3416. group, group_id, rank = self._init_global_test()
  3417. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3418. self._test_all_reduce_multigpu_helper(
  3419. group,
  3420. group_id,
  3421. rank,
  3422. rank_to_GPU,
  3423. dist.ReduceOp.SUM,
  3424. 2,
  3425. 10,
  3426. (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
  3427. )
3428. @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support all_reduce multigpu")
  3429. @sandcastle_skip_if(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL")
  3430. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  3431. @skip_if_no_gpu
  3432. def test_all_reduce_multigpu_complex(self):
  3433. group, group_id, rank = self._init_global_test()
  3434. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3435. self._test_all_reduce_multigpu_helper(
  3436. group,
  3437. group_id,
  3438. rank,
  3439. rank_to_GPU,
  3440. dist.ReduceOp.SUM,
  3441. complex(2, 3),
  3442. complex(10, 11),
  3443. (complex(2, 3) + complex(10, 11) * (len(group) - 1))
  3444. * len(rank_to_GPU[0]),
  3445. dtype=torch.cfloat,
  3446. )
  3447. def _test_reduce_multigpu_helper(
  3448. self,
  3449. group,
  3450. group_id,
  3451. rank,
  3452. rank_to_GPU,
  3453. op,
  3454. master_value,
  3455. worker_value,
  3456. expected_value,
  3457. ):
  3458. for src in group:
  3459. tensor_value = master_value if rank == src else worker_value
  3460. tensors = [
  3461. _build_tensor(src + 1, tensor_value).cuda(device=i)
  3462. for i in rank_to_GPU[rank]
  3463. ]
  3464. self.call_dist_op(
  3465. ":reduce",
  3466. False,
  3467. dist.reduce_multigpu,
  3468. tensors,
  3469. src,
  3470. op,
  3471. group_id,
  3472. expect_event=len(tensors) == 1,
  3473. tensor_shapes=[tensors[0].shape],
  3474. )
  3475. if rank == src:
  3476. expected_tensor = _build_tensor(src + 1, expected_value)
  3477. self.assertEqual(tensors[0], expected_tensor)
  3478. self._barrier()
  3479. @sandcastle_skip_if(
3480. BACKEND != "nccl", "Only NCCL backend supports reduce multigpu"
  3481. )
  3482. @skip_if_no_gpu
  3483. def test_reduce_multigpu(self):
  3484. group, group_id, rank = self._init_global_test()
  3485. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3486. device_id = rank_to_GPU[rank][0]
  3487. torch.cuda.set_device(device_id)
  3488. self._test_reduce_multigpu_helper(
  3489. group,
  3490. group_id,
  3491. rank,
  3492. rank_to_GPU,
  3493. dist.ReduceOp.SUM,
  3494. 2,
  3495. 10,
  3496. (2 + 10 * (len(group) - 1)) * len(rank_to_GPU[0]),
  3497. )
  3498. def _test_all_gather_multigpu_helper(
  3499. self, group, group_id, rank, rank_to_GPU, dtype=torch.float
  3500. ):
  3501. for dest in group:
  3502. tensors = [
  3503. _build_tensor(dest + 1, dtype=dtype).cuda(device=i)
  3504. for i in rank_to_GPU[rank]
  3505. ]
  3506. # construct expected output along with
3507. # a placeholder to receive the all_gather results
  3508. output_tensors = []
  3509. expected_output = []
  3510. output_per_gpu = (
  3511. [_build_tensor(dest + 1, -1, dtype=dtype)]
  3512. * len(rank_to_GPU[0])
  3513. * len(group)
  3514. )
  3515. expected_per_gpu = (
  3516. [_build_tensor(dest + 1, dtype=dtype)]
  3517. * len(rank_to_GPU[0])
  3518. * len(group)
  3519. )
  3520. for gpu in rank_to_GPU[rank]:
  3521. output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
  3522. expected_output.append(
  3523. [t.cuda(device=gpu) for t in expected_per_gpu]
  3524. )
  3525. self.call_dist_op(
  3526. ":all_gather",
  3527. False,
  3528. dist.all_gather_multigpu,
  3529. output_tensors,
  3530. tensors,
  3531. group_id,
  3532. expect_event=len(expected_output) == 1,
  3533. )
  3534. self.assertEqual(output_tensors, expected_output)
  3535. self._barrier()
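# Layout note (illustrative): for each local GPU, all_gather_multigpu fills a flat
# list of world_size * gpus_per_rank output tensors, one per GPU in the job; e.g.
# 2 ranks with 2 GPUs each means each of the 2 local output lists holds 4 tensors.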
  3536. @sandcastle_skip_if(
3537. BACKEND != "nccl", "Only NCCL backend supports allgather multigpu"
  3538. )
  3539. @skip_if_no_gpu
  3540. def test_all_gather_multigpu(self):
  3541. group, group_id, rank = self._init_global_test()
  3542. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3543. device_id = rank_to_GPU[rank][0]
  3544. torch.cuda.set_device(device_id)
  3545. self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU)
  3546. @sandcastle_skip_if(
3547. BACKEND != "nccl", "Only NCCL backend supports allgather multigpu"
  3548. )
  3549. @skip_if_no_gpu
  3550. def test_all_gather_multigpu_complex(self):
  3551. group, group_id, rank = self._init_global_test()
  3552. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  3553. device_id = rank_to_GPU[rank][0]
  3554. torch.cuda.set_device(device_id)
  3555. self._test_all_gather_multigpu_helper(
  3556. group, group_id, rank, rank_to_GPU, dtype=torch.cfloat
  3557. )
  3558. def _model_step(self, model):
  3559. for param in model.parameters():
  3560. if param.grad is not None:
  3561. with torch.no_grad():
  3562. param += param.grad
  3563. param.grad = None
  3564. def _model_step_with_zero_grad(self, model):
  3565. for param in model.parameters():
  3566. if param.grad is not None:
  3567. with torch.no_grad():
  3568. param += param.grad
  3569. param.grad.requires_grad_(False)
  3570. param.grad.zero_()
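# The two helpers above differ only in how they clear gradients after applying
# param += param.grad: _model_step drops the gradient (sets it to None), while
# _model_step_with_zero_grad keeps the gradient tensor and zeroes it in place.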
  3571. def _prepare_dummy_data(self, local_bs):
  3572. # global_bs for DDP should be divisible by WORLD_SIZE
  3573. world_size = int(os.environ["WORLD_SIZE"])
  3574. global_bs = world_size * local_bs
  3575. input_cpu = torch.randn(global_bs, 2)
  3576. target = torch.randn(global_bs, 4)
  3577. loss = nn.MSELoss()
  3578. return global_bs, input_cpu, target, loss
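# Example of the shapes produced above (illustrative): with WORLD_SIZE=2 and
# local_bs=2, global_bs=4, input_cpu has shape (4, 2) and target has shape (4, 4);
# each rank later trains on its own local_bs-sized slice of this global batch.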
  3579. # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
  3580. def _test_DDP_helper(
  3581. self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
  3582. ):
  3583. model.train()
  3584. output = model(input_var)
  3585. l = loss(output, target) * scale_factor
  3586. l.backward()
  3587. if memory_format is not None:
  3588. self.assertTrue(output.is_contiguous(memory_format=memory_format))
  3589. def _assert_equal_param(self, param_gpu, param_DDP):
  3590. self.assertEqual(len(param_gpu), len(param_DDP))
  3591. for p_gpu, p_DDP in zip(param_gpu, param_DDP):
  3592. self.assertEqual(p_gpu, p_DDP)
  3593. def _test_DDP_niter(
  3594. self,
  3595. model_base,
  3596. model_DDP,
  3597. input,
  3598. target,
  3599. loss,
  3600. local_bs,
  3601. rank,
  3602. batch_size,
  3603. test_save,
  3604. offset=None,
  3605. world_size=0,
  3606. zero_grad=False,
  3607. memory_format=None,
  3608. n_iter=5,
  3609. ):
  3610. for idx in range(n_iter):
  3611. # single cpu/gpu training
  3612. self._test_DDP_helper(
  3613. model_base, input, target, loss, memory_format=memory_format
  3614. )
  3615. if offset is None:
  3616. offset = rank * local_bs
3617. # DDP training: each rank trains on its own local_bs-sized slice of input_cpu
  3618. self._test_DDP_helper(
  3619. model_DDP,
  3620. input[offset : offset + local_bs],
  3621. target[offset : offset + local_bs],
  3622. loss,
  3623. world_size * local_bs / batch_size if world_size != 0 else 1,
  3624. memory_format=memory_format,
  3625. )
  3626. # Update weights and run a second iteration to shake out errors
  3627. if zero_grad:
  3628. self._model_step_with_zero_grad(model_base)
  3629. self._model_step_with_zero_grad(model_DDP)
  3630. else:
  3631. self._model_step(model_base)
  3632. self._model_step(model_DDP)
  3633. self._assert_equal_param(
  3634. list(model_base.parameters()), list(model_DDP.module.parameters())
  3635. )
  3636. # Shuffle the input so that DDP input is different
  3637. input = input[torch.randperm(batch_size)]
  3638. # save the model in the middle and reload
  3639. if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
  3640. with tempfile.NamedTemporaryFile() as tmp:
  3641. if sys.platform == "win32":
  3642. torch.save(model_DDP, tmp)
  3643. tmp.seek(0)
  3644. model_DDP = torch.load(tmp)
  3645. else:
  3646. torch.save(model_DDP, tmp.name)
  3647. model_DDP = torch.load(tmp.name)
  3648. with tempfile.TemporaryFile() as tmp_file:
  3649. torch.save(model_DDP, tmp_file)
  3650. tmp_file.seek(0)
  3651. saved_model = torch.load(tmp_file)
  3652. for k in model_DDP.state_dict():
  3653. self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
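# Note on the loss scale factor above (illustrative): when batch_size equals
# world_size * local_bs (as with the dummy data from _prepare_dummy_data), the
# factor world_size * local_bs / batch_size is 1.0, so the DDP loss computed on a
# local shard is directly comparable to the single-process baseline loss.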
  3654. def _test_DistributedDataParallel(
  3655. self,
  3656. gpu_subset,
  3657. rank,
  3658. output_device=None,
  3659. gradient_as_bucket_view=False,
  3660. static_graph=False,
  3661. set_static_graph_twice=False,
  3662. ):
  3663. # Run a simple end to end DDP model, use result of single node model
  3664. # as baseline
  3665. # cpu training setup
  3666. model = DDP_NET
  3667. # single gpu training setup
  3668. model_gpu = copy.deepcopy(model)
  3669. model_gpu.cuda(gpu_subset[0])
  3670. # DDP training setup
  3671. model_DDP = copy.deepcopy(model)
  3672. model_DDP.cuda(gpu_subset[0])
  3673. model_DDP = nn.parallel.DistributedDataParallel(
  3674. model_DDP,
  3675. device_ids=gpu_subset,
  3676. gradient_as_bucket_view=gradient_as_bucket_view,
  3677. static_graph=static_graph,
  3678. )
  3679. if set_static_graph_twice:
  3680. model_DDP._set_static_graph()
3681. # test that the model is serializable/deserializable
  3682. with tempfile.NamedTemporaryFile() as tmp:
  3683. if sys.platform == "win32":
  3684. torch.save(model_DDP, tmp)
  3685. tmp.seek(0)
  3686. model_DDP = torch.load(tmp)
  3687. else:
  3688. torch.save(model_DDP, tmp.name)
  3689. model_DDP = torch.load(tmp.name)
  3690. # dummy data initialization
  3691. local_bs = len(gpu_subset)
  3692. global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
  3693. # check two model parameters over 5 iterations
  3694. self._test_DDP_niter(
  3695. model_gpu,
  3696. model_DDP,
  3697. input_cpu.cuda(gpu_subset[0]),
  3698. target.cuda(gpu_subset[0]),
  3699. loss,
  3700. local_bs,
  3701. rank,
  3702. global_bs,
  3703. True,
  3704. )
  3705. self._barrier()
  3706. def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
  3707. # Run a simple end to end DDP-CPU model, use result of single node
  3708. # model as baseline
  3709. group, group_id, rank = self._init_global_test()
  3710. # cpu training setup
  3711. model_base = DDP_NET
  3712. # DDP-CPU training setup
  3713. model_DDP = copy.deepcopy(model_base)
  3714. model_DDP = nn.parallel.DistributedDataParallel(
  3715. model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
  3716. )
  3717. # dummy data initialization
  3718. local_bs = 2
  3719. global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
  3720. # check two model parameters over 5 iterations
  3721. self._test_DDP_niter(
  3722. model_base,
  3723. model_DDP,
  3724. input_cpu,
  3725. target,
  3726. loss,
  3727. local_bs,
  3728. rank,
  3729. global_bs,
  3730. False,
  3731. zero_grad=True,
  3732. )
  3733. self._barrier()
  3734. return model_DDP
  3735. @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
  3736. def test_DistributedDataParallelCPU(self):
  3737. self._test_DistributedDataParallelCPU()
  3738. @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
  3739. def test_DistributedDataParallelCPU_grad_is_view(self):
  3740. self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)
  3741. @sandcastle_skip_if(
  3742. BACKEND not in DistTestCases.backend_feature["ddp"],
  3743. f"The {BACKEND} backend does not support DistributedDataParallel"
  3744. )
  3745. def test_DistributedDataParallel_requires_grad(self):
  3746. # a module without gradients shouldn't be accepted
  3747. self.assertRaises(
  3748. RuntimeError, lambda: nn.parallel.DistributedDataParallel(nn.Module())
  3749. )
  3750. self._barrier()
  3751. @sandcastle_skip_if(
  3752. BACKEND not in DistTestCases.backend_feature["ddp"],
  3753. f"The {BACKEND} backend does not support DistributedDataParallel"
  3754. )
  3755. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  3756. def test_ddp_zero_output_features(self):
  3757. class ToyModel(nn.Module):
  3758. def __init__(self):
  3759. super().__init__()
  3760. self.net1 = nn.Linear(10, 10)
  3761. self.relu = nn.ReLU()
  3762. self.net2 = nn.Linear(10, 0)
  3763. model = ToyModel().to(self.rank)
  3764. ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
  3765. @sandcastle_skip_if(
  3766. BACKEND == "nccl",
  3767. "Gloo-only test"
  3768. )
  3769. def test_ddp_create_graph(self):
  3770. class Model(nn.Module):
  3771. def __init__(self):
  3772. super().__init__()
  3773. self.p = nn.Parameter(torch.tensor(1.))
  3774. def forward(self):
  3775. return self.p.pow(2)
  3776. model = Model()
  3777. ddp_model = torch.nn.parallel.DistributedDataParallel(model)
  3778. for _ in range(6):
3779. # Verify DDP doesn't throw when run with create_graph=True.
  3780. # Although we do warn about potential issues, please see
  3781. # https://github.com/pytorch/pytorch/issues/63929 for details.
  3782. ddp_model().backward(create_graph=True)
  3783. # grad tensors should require grad.
  3784. self.assertTrue(
  3785. all([param.requires_grad for param in ddp_model.parameters()])
  3786. )
  3787. @sandcastle_skip_if(
  3788. BACKEND not in DistTestCases.backend_feature["ddp"],
  3789. f"The {BACKEND} backend does not support DistributedDataParallel"
  3790. )
  3791. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  3792. def test_DistributedDataParallel_non_default_stream(self):
  3793. stream = torch.cuda.Stream(self.rank)
  3794. rank = self.rank
  3795. with torch.cuda.stream(stream):
  3796. net = torch.nn.parallel.DistributedDataParallel(
  3797. torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
  3798. )
  3799. for i in range(1000):
  3800. # Clear gradients manually
  3801. grad = net.module.weight.grad
  3802. if grad is not None:
  3803. grad.requires_grad_(False)
  3804. grad.zero_()
  3805. # Forward + BW
  3806. batch = torch.tensor([rank]).float().cuda(rank)
  3807. loss = net(batch).sum()
  3808. loss.backward()
  3809. # For each worker, the gradient on the weight should be worker_rank.
  3810. grad = net.module.weight.grad
  3811. avg = grad.clone()
3812. # All-reducing the per-rank gradients and dividing by the world size
3813. # should reproduce the same gradient average. If not, then one of the
3814. # workers has not correctly written back the averaged gradient before this all-reduce call.
  3815. dist.all_reduce(avg)
  3816. world_size = int(os.environ["WORLD_SIZE"])
  3817. avg.div_(world_size)
  3818. expected_grad = sum(i for i in range(world_size)) / world_size
  3819. self.assertEqual(
  3820. avg[0, 0],
  3821. expected_grad,
  3822. msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
  3823. )
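# Gradient arithmetic used above (illustrative): each rank feeds the scalar `rank`
# into a bias-free Linear(1, 1), so before DDP averaging its local gradient equals
# its rank; after averaging, every rank holds sum(range(world_size)) / world_size.
# The extra all-reduce + divide simply verifies that all ranks agree on that value.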
  3824. @sandcastle_skip_if(
  3825. BACKEND not in DistTestCases.backend_feature["cuda"],
  3826. f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
  3827. )
  3828. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  3829. def test_ddp_comm_hook_logging(self):
  3830. hooks = [
  3831. default.allreduce_hook,
  3832. default.fp16_compress_hook,
  3833. powerSGD.powerSGD_hook,
  3834. powerSGD.batched_powerSGD_hook,
  3835. quantization_hooks.quantization_pertensor_hook,
  3836. quantization_hooks.quantization_perchannel_hook,
  3837. ]
  3838. cpp_builtin_hooks = [
  3839. dist.BuiltinCommHookType.ALLREDUCE,
  3840. dist.BuiltinCommHookType.FP16_COMPRESS,
  3841. ]
  3842. for hook in hooks:
  3843. ddp_model = torch.nn.parallel.DistributedDataParallel(
  3844. torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
  3845. device_ids=[self.rank],
  3846. )
  3847. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3848. # Hook not registered yet, so should be empty
  3849. self.assertEqual(ddp_logging_data.get("comm_hook"), None)
  3850. ddp_model.register_comm_hook(None, hook)
  3851. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3852. self.assertEqual(ddp_logging_data.get("comm_hook"), hook.__qualname__)
  3853. for hook in cpp_builtin_hooks:
  3854. ddp_model = torch.nn.parallel.DistributedDataParallel(
  3855. torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
  3856. device_ids=[self.rank],
  3857. )
  3858. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3859. # Hook not registered yet, so should be empty
  3860. self.assertEqual(ddp_logging_data.get("comm_hook"), None)
  3861. ddp_model._register_builtin_comm_hook(hook)
  3862. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3863. self.assertEqual(ddp_logging_data.get("comm_hook"), str(hook))
  3864. # No hook registered
  3865. ddp_model = torch.nn.parallel.DistributedDataParallel(
  3866. torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
  3867. device_ids=[self.rank],
  3868. )
  3869. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3870. # Hook not registered yet, so should be empty
  3871. self.assertEqual(ddp_logging_data.get("comm_hook"), None)
  3872. # After second forward pass, hook should still be empty string
  3873. for i in range(2):
  3874. inp = torch.ones(1, 1, device=self.rank)
  3875. loss = ddp_model(inp).sum()
  3876. loss.backward()
  3877. ddp_logging_data = ddp_model._get_ddp_logging_data()
  3878. # Note: DETAIL debug mode logs DDP logging data to stdout and
  3879. # thus accesses std::map, which fills in a default value for the
  3880. # type if it didn't exist.
  3881. self.assertEqual(ddp_logging_data.get("comm_hook", ""), "")
  3882. def _test_ddp_hook_with_optimizer_parity(
  3883. self, grad_as_bucket_view, static_graph, optim_cls,
  3884. optimize_subset, *functional_optim_args, **functional_optim_kwargs
  3885. ):
  3886. rank = self.rank
  3887. torch.cuda.set_device(rank)
  3888. torch.manual_seed(rank)
  3889. torch.cuda.manual_seed(rank)
  3890. models_to_test = [
  3891. (LargeNet(), torch.randn(1, 1000).cuda()),
  3892. ]
  3893. if HAS_TORCHVISION:
  3894. models_to_test.append(
  3895. (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
  3896. )
  3897. for (model, inp) in models_to_test:
  3898. # Enable determinism in cudnn operators
  3899. with torch.backends.cudnn.flags(
  3900. enabled=True, deterministic=True, benchmark=False
  3901. ):
  3902. # Create DDP model that runs optimizer in fused fashion.
  3903. ddp_model_with_optimizer_hook = (
  3904. torch.nn.parallel.DistributedDataParallel(
  3905. copy.deepcopy(model).cuda(),
  3906. device_ids=[self.rank],
  3907. gradient_as_bucket_view=grad_as_bucket_view,
  3908. static_graph=static_graph,
  3909. )
  3910. )
  3911. # Create DDP model with no hook that does optimizer after
  3912. # backward.
  3913. ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
  3914. copy.deepcopy(model).cuda(),
  3915. device_ids=[self.rank],
  3916. gradient_as_bucket_view=grad_as_bucket_view,
  3917. static_graph=static_graph,
  3918. )
  3919. hook_params = ddp_model_with_optimizer_hook.parameters()
  3920. no_hook_params = ddp_model_with_no_hook.parameters()
  3921. if optimize_subset:
  3922. hook_params = list(hook_params)
  3923. no_hook_params = list(no_hook_params)
  3924. self.assertGreater(len(hook_params), 0)
  3925. hook_params = [hook_params[0]]
  3926. no_hook_params = [no_hook_params[0]]
  3927. # Register a fused optimizer that will run optimizer in step
  3928. # with allreduce.
  3929. if optimize_subset:
  3930. # API where optim_params is specified.
  3931. ddp_model_with_optimizer_hook._register_fused_optim(
  3932. optim_cls,
  3933. *functional_optim_args,
  3934. optim_params=hook_params,
  3935. **functional_optim_kwargs,
  3936. )
  3937. else:
  3938. # API where optim_params is omitted
  3939. ddp_model_with_optimizer_hook._register_fused_optim(
  3940. optim_cls,
  3941. *functional_optim_args,
  3942. **functional_optim_kwargs,
  3943. )
  3944. optimizer_no_hook = optim_cls(
  3945. no_hook_params,
  3946. *functional_optim_args,
  3947. **functional_optim_kwargs,
  3948. )
  3949. # Verify parameters are equal initially.
  3950. for hook_param, allreduce_param in zip(
  3951. ddp_model_with_optimizer_hook.parameters(),
  3952. ddp_model_with_no_hook.parameters(),
  3953. ):
  3954. self.assertEqual(hook_param, allreduce_param)
  3955. # Save old parameters to later verify optimizer modified them.
  3956. opt_hook_init_params = copy.deepcopy(
  3957. list(ddp_model_with_optimizer_hook.parameters())
  3958. )
  3959. # Run optimizer with hook model.
  3960. for i in range(6):
  3961. ddp_model_with_optimizer_hook.zero_grad()
  3962. out = ddp_model_with_optimizer_hook(inp)
  3963. loss = out.sum()
  3964. loss.backward()
  3965. dist.barrier()
  3966. # Run regular model.
  3967. for i in range(6):
  3968. ddp_model_with_no_hook.zero_grad()
  3969. out = ddp_model_with_no_hook(inp)
  3970. loss = out.sum()
  3971. loss.backward()
  3972. optimizer_no_hook.step()
  3973. dist.barrier()
  3974. # Now verify parameters are equal.
  3975. for hook_param, allreduce_param in zip(
  3976. ddp_model_with_optimizer_hook.parameters(),
  3977. ddp_model_with_no_hook.parameters(),
  3978. ):
  3979. self.assertEqual(hook_param, allreduce_param)
  3980. # Verify optimizer modified appropriate parameter set,
  3981. # otherwise they'd be trivially equal above.
  3982. if optimize_subset:
  3983. self.assertNotEqual(
  3984. opt_hook_init_params[0],
  3985. list(ddp_model_with_optimizer_hook.parameters())[0]
  3986. )
  3987. # Untouched params should be equal
  3988. self.assertEqual(
  3989. opt_hook_init_params[1:],
  3990. list(ddp_model_with_optimizer_hook.parameters())[1:]
  3991. )
  3992. else:
  3993. self.assertNotEqual(
  3994. opt_hook_init_params,
  3995. list(ddp_model_with_optimizer_hook.parameters()),
  3996. )
  3997. dist.barrier()
  3998. @sandcastle_skip_if(
  3999. BACKEND == "nccl" or BACKEND == "ucc",
  4000. "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
  4001. )
  4002. @skip_if_lt_x_gpu(2)
  4003. @parametrize("grad_as_bucket_view", [True, False])
  4004. @parametrize("static_graph", [True, False])
  4005. @parametrize("optimize_subset", [True, False])
  4006. def test_ddp_hook_with_optimizer_parity_adamw(
  4007. self,
  4008. grad_as_bucket_view,
  4009. static_graph,
  4010. optimize_subset,
  4011. ):
  4012. adamw_lr = 1e-2
  4013. adamw_betas = (0.9, 0.99)
  4014. adamw_eps = 1e-6
  4015. self._test_ddp_hook_with_optimizer_parity(
  4016. grad_as_bucket_view,
  4017. static_graph,
  4018. torch.optim.AdamW,
  4019. optimize_subset,
  4020. adamw_lr,
  4021. betas=adamw_betas,
  4022. eps=adamw_eps,
  4023. )
  4024. @sandcastle_skip_if(
  4025. BACKEND == "nccl" or BACKEND == "ucc",
  4026. "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
  4027. )
  4028. @skip_if_lt_x_gpu(2)
  4029. @parametrize("optimize_subset", [True, False])
  4030. def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
  4031. adam_lr = 1e-2
  4032. adam_betas = (0.9, 0.99)
  4033. adam_eps = 1e-6
  4034. self._test_ddp_hook_with_optimizer_parity(
  4035. True, # grad as bucket view
  4036. False, # static graph
  4037. torch.optim.Adam,
  4038. optimize_subset,
  4039. adam_lr,
  4040. betas=adam_betas,
  4041. eps=adam_eps,
  4042. )
  4043. @sandcastle_skip_if(
  4044. BACKEND == "nccl" or BACKEND == "ucc",
  4045. "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259"
  4046. )
  4047. @skip_if_lt_x_gpu(2)
  4048. @parametrize("optimize_subset", [True, False])
  4049. def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
  4050. sgd_lr = 1e-2
  4051. sgd_momentum = 0.9
  4052. sgd_weight_decay = 0.01
  4053. # Not testing grad_as_bucket_view and static_graph as they are
  4054. # tested in AdamW test above.
  4055. self._test_ddp_hook_with_optimizer_parity(
  4056. True, # grad as bucket view
  4057. False, # static_graph
  4058. torch.optim.SGD,
  4059. optimize_subset,
  4060. sgd_lr,
  4061. momentum=sgd_momentum,
  4062. weight_decay=sgd_weight_decay,
  4063. )
  4064. def _test_ddp_apply_optim_in_backward(
  4065. self,
  4066. optim_cls,
  4067. optim_kwargs,
  4068. gradient_as_bucket_view=True,
  4069. ):
4070. # Need to seed to ensure inputs are unique across ranks. Otherwise,
  4071. # allreduce won't have any effect.
  4072. torch.manual_seed(self.rank)
  4073. torch.cuda.manual_seed(self.rank)
  4074. torch.cuda.set_device(self.rank)
  4075. # Test a simple linear as well as a ResNet model.
  4076. models_to_test = [
  4077. nn.Sequential(
  4078. nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
  4079. ).cuda()
  4080. ]
  4081. if HAS_TORCHVISION:
  4082. models_to_test.append(
  4083. torchvision.models.resnet50().cuda()
  4084. )
  4085. for j, model in enumerate(models_to_test):
  4086. model_optim_in_bwd = copy.deepcopy(model)
  4087. model = nn.parallel.DistributedDataParallel(
  4088. model,
  4089. device_ids=[self.rank],
  4090. gradient_as_bucket_view=gradient_as_bucket_view,
  4091. )
  4092. optim = optim_cls(model.parameters(), **optim_kwargs)
  4093. # Note: have to apply_optimizer_in_backward before wrapping with DDP.
  4094. _apply_optimizer_in_backward(
  4095. optimizer_class=optim_cls,
  4096. params=model_optim_in_bwd.parameters(),
  4097. optimizer_kwargs=optim_kwargs,
  4098. )
  4099. model_optim_in_bwd = nn.parallel.DistributedDataParallel(
  4100. model_optim_in_bwd,
  4101. device_ids=[self.rank],
  4102. gradient_as_bucket_view=gradient_as_bucket_view,
  4103. )
  4104. for p1, p2 in zip(
  4105. model.parameters(), model_optim_in_bwd.parameters()
  4106. ):
  4107. self.assertEqual(p1, p2, "Parameters not initially equal!")
  4108. # Enable determinism in cudnn operators
  4109. with torch.backends.cudnn.flags(
  4110. enabled=True, deterministic=True, benchmark=False
  4111. ):
  4112. for i in range(100):
  4113. inp = (
  4114. torch.randn(1, 3, 1000, 1000, device='cuda')
  4115. if j == 1 else torch.randn(10, 3, device='cuda')
  4116. )
  4117. model(inp).sum().backward()
  4118. optim.step()
  4119. model_optim_in_bwd(inp).sum().backward() # runs optimizer as well
  4120. for p1, p2 in zip(
  4121. model.parameters(), model_optim_in_bwd.parameters()
  4122. ):
  4123. self.assertEqual(p1, p2, f"Params not equal at iteration {i}")
  4124. self.assertTrue(
  4125. p2.grad is None, f"Optim in backward grad is not None at {i}"
  4126. )
4127. # Use set_to_none for the regular optimizer to match the
4128. # optim-in-backward case.
  4129. optim.zero_grad(set_to_none=True)
  4130. @skip_if_lt_x_gpu(2)
  4131. def test_ddp_apply_optim_in_backward(self):
  4132. for optim_cls in [torch.optim.SGD, torch.optim.Adam]:
  4133. with self.subTest(optim_cls=optim_cls):
  4134. self._test_ddp_apply_optim_in_backward(
  4135. optim_cls=optim_cls,
  4136. optim_kwargs={"lr": 0.03}
  4137. )
  4138. @skip_if_lt_x_gpu(2)
  4139. def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
  4140. self._test_ddp_apply_optim_in_backward(
  4141. optim_cls=torch.optim.SGD,
  4142. optim_kwargs={"lr": 0.03},
  4143. gradient_as_bucket_view=False,
  4144. )
  4145. @skip_if_lt_x_gpu(2)
  4146. def test_ddp_apply_optim_in_backward_ignored_params(self):
  4147. torch.cuda.set_device(self.rank)
  4148. torch.manual_seed(self.rank)
  4149. torch.cuda.manual_seed(self.rank)
  4150. model = TwoLinLayerNet()
  4151. model_clone = copy.deepcopy(model)
  4152. # Parameters to ignore are in the format {module_name}.{param_name}
  4153. params_to_ignore = ["a.weight"]
  4154. torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
  4155. model, params_to_ignore
  4156. )
  4157. _apply_optimizer_in_backward(
  4158. optimizer_class=torch.optim.SGD,
  4159. params=model.parameters(),
  4160. optimizer_kwargs={"lr": 0.03}
  4161. )
  4162. net = torch.nn.parallel.DistributedDataParallel(
  4163. model.cuda(self.rank),
  4164. device_ids=[self.rank],
  4165. )
  4166. inp = torch.randn(1, 10)
  4167. a, b = net(inp)
  4168. (a.transpose(0, 1) @ b).sum().backward()
  4169. # a.weight did not go through allreduce, so optimizer acted on local
  4170. # gradient, which should be different across ranks. Remaining params
  4171. # should be equal.
  4172. models = [None for _ in range(dist.get_world_size())]
  4173. dist.all_gather_object(models, model)
  4174. rank0_model, remainder = models[0], models[1:]
  4175. for m in remainder:
  4176. self.assertNotEqual(rank0_model.a.weight, m.a.weight)
  4177. self.assertEqual(
  4178. list(rank0_model.b.parameters()), list(m.b.parameters())
  4179. )
  4180. self.assertEqual(rank0_model.a.bias, m.a.bias)
  4181. def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
  4182. rank = self.rank
  4183. m = torch.nn.Linear(1, 5)
  4184. try:
  4185. process_group = state.process_group
  4186. except AttributeError:
  4187. process_group = state
  4188. net_with_hook = torch.nn.parallel.DistributedDataParallel(
  4189. copy.deepcopy(m).to(rank),
  4190. device_ids=[rank],
  4191. process_group=process_group,
  4192. )
  4193. net_with_hook.register_comm_hook(state=state, hook=hook)
  4194. net_without_hook = torch.nn.parallel.DistributedDataParallel(
  4195. copy.deepcopy(m).to(rank),
  4196. device_ids=[rank],
  4197. process_group=process_group,
  4198. )
  4199. for i in range(100):
  4200. # Clear gradients manually.
  4201. for g in [
  4202. net_without_hook.module.weight.grad,
  4203. net_with_hook.module.weight.grad,
  4204. ]:
  4205. if g is not None:
  4206. g.requires_grad_(False)
  4207. g.zero_()
  4208. # Forward + BW
  4209. batch = torch.tensor([rank]).float().cuda(rank)
  4210. loss = net_without_hook(batch).sum()
  4211. loss.backward()
  4212. # For each worker, the gradient on the weight should be worker_rank.
  4213. grad = net_without_hook.module.weight.grad
  4214. avg = grad.clone()
  4215. expected_grad = (
  4216. sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
  4217. )
  4218. loss_hook = net_with_hook(batch).sum()
  4219. loss_hook.backward()
  4220. grad_hook = net_with_hook.module.weight.grad
  4221. avg_hook = grad_hook.clone()
  4222. if i < num_validated_iters:
  4223. # Verify hook grad with expected.
  4224. self.assertEqual(
  4225. avg_hook[0, 0].item(),
  4226. expected_grad,
  4227. msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
  4228. )
  4229. # Verify hook grad with vanilla allreduce
  4230. self.assertEqual(
  4231. avg_hook[0, 0],
  4232. avg[0, 0],
  4233. msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
  4234. )
  4235. @sandcastle_skip_if(
  4236. BACKEND not in DistTestCases.backend_feature["cuda"],
  4237. f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
  4238. )
  4239. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  4240. def test_ddp_hook_parity_allreduce(self):
  4241. self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook)
  4242. @sandcastle_skip_if(
  4243. BACKEND not in DistTestCases.backend_feature["cuda"],
  4244. f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
  4245. )
  4246. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  4247. def test_ddp_hook_parity_allreduce_process_group(self):
  4248. # process_group is passed in to both DDP and comm. hook
  4249. world_size = dist.get_world_size()
  4250. rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
  4251. gpus = [rank_to_GPU[int(r)][0] for r in range(world_size)]
  4252. process_group = torch.distributed.new_group(gpus)
  4253. self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook)
  4254. @sandcastle_skip_if(
  4255. BACKEND not in DistTestCases.backend_feature["cuda"],
  4256. f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
  4257. )
  4258. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  4259. def test_ddp_hook_parity_powerSGD(self):
  4260. for warm_start in [True, False]:
  4261. powersgd_state = powerSGD.PowerSGDState(
  4262. process_group=None,
  4263. matrix_approximation_rank=1,
  4264. start_powerSGD_iter=2,
  4265. warm_start=warm_start,
  4266. )
  4267. self._test_ddp_hook_parity(
  4268. state=powersgd_state, hook=powerSGD.powerSGD_hook
  4269. )
  4270. @sandcastle_skip_if(
  4271. BACKEND not in DistTestCases.backend_feature["cuda"],
  4272. f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
  4273. )
  4274. @sandcastle_skip_if(
  4275. NO_MULTIPROCESSING_SPAWN,
  4276. "Disabled for environments that \
  4277. don't support multiprocessing with spawn start method",
  4278. )
  4279. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  4280. def test_ddp_hook_parity_post_localSGD(self):
4281. # Although we start running local SGD at iteration 10, since we still use the global process group to run it,
4282. # post-localSGD still allreduces gradients globally for the remaining iterations.
  4283. state = post_localSGD.PostLocalSGDState(
  4284. process_group=None, subgroup=dist.group.WORLD, start_localSGD_iter=10
  4285. )
  4286. self._test_ddp_hook_parity(
  4287. state=state, hook=post_localSGD.post_localSGD_hook
  4288. )
  4289. # Only validate the warmup iterations before local SGD is applied,
  4290. # because when `post_local_gradient_allreduce` is disabled, the gradients will not be synchronized at all.
  4291. # Note that in practice a model averager has to be applied to run model averaging,
  4292. # so local gradient averaging is not necessary.
  4293. start_localSGD_iter = 10
  4294. state = post_localSGD.PostLocalSGDState(
  4295. process_group=None,
  4296. subgroup=dist.group.WORLD,
  4297. start_localSGD_iter=start_localSGD_iter,
  4298. post_local_gradient_allreduce=False,
  4299. )
  4300. self._test_ddp_hook_parity(
  4301. state=state, hook=post_localSGD.post_localSGD_hook, num_validated_iters=start_localSGD_iter
  4302. )
4303. # When `subgroup` is None, it is equivalent to the subgroup on each node.
  4304. # For this single-node test environment, the intra-node process group is equivalent to
  4305. # the global process group.
  4306. if self.world_size == dist.get_world_size():
  4307. state = post_localSGD.PostLocalSGDState(
  4308. process_group=None, subgroup=None, start_localSGD_iter=10
  4309. )
  4310. self._test_ddp_hook_parity(
  4311. state=state, hook=post_localSGD.post_localSGD_hook
  4312. )
  4313. # Since we start local SGD later than the total number of 100 iterations,
4314. # no local SGD is actually executed, and we don't even need to provide a subgroup for this case.
  4315. state = post_localSGD.PostLocalSGDState(
  4316. process_group=None, subgroup=None, start_localSGD_iter=1000
  4317. )
  4318. self._test_ddp_hook_parity(
  4319. state=state, hook=post_localSGD.post_localSGD_hook
  4320. )
  4321. def _prepare_single_device_module(
  4322. self,
  4323. rank,
  4324. process_group,
  4325. devices,
  4326. device_ids,
  4327. global_batch_size,
  4328. gradient_as_bucket_view=False,
  4329. ):
  4330. model = Net()
  4331. device = devices[0] if devices else torch.device("cuda:%d" % rank)
  4332. ddp_model = DistributedDataParallel(
  4333. copy.deepcopy(model).to(device),
  4334. device_ids=device_ids,
  4335. process_group=process_group,
  4336. bucket_cap_mb=0.001,
  4337. gradient_as_bucket_view=gradient_as_bucket_view,
  4338. )
  4339. model.to(device)
  4340. input = torch.randn(global_batch_size, 2).to(device)
  4341. target = torch.randn(global_batch_size, 4).to(device)
  4342. return model, ddp_model, input, target
  4343. def _prepare_cpu_module(
  4344. self,
  4345. process_group,
  4346. global_batch_size,
  4347. gradient_as_bucket_view=False,
  4348. ):
  4349. model = Net()
  4350. ddp_model = DistributedDataParallel(
  4351. copy.deepcopy(model),
  4352. process_group=process_group,
  4353. bucket_cap_mb=0.001,
  4354. gradient_as_bucket_view=gradient_as_bucket_view,
  4355. )
  4356. input = torch.randn(global_batch_size, 2)
  4357. target = torch.randn(global_batch_size, 4)
  4358. return model, ddp_model, input, target
  4359. def _test_accumulate_gradients_no_sync(
  4360. self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
  4361. ):
  4362. """
4363. This is the recommended way to implement gradient accumulation.
4364. If a ``ddp_comm_hook`` is specified, it is also registered on the
4365. ``ddp_model``. The hook fed into this function should not change
  4366. the resulting gradients.
  4367. """
  4368. group, group_id, rank = self._init_global_test()
  4369. world_size = get_world_size()
  4370. # FIXME: Add testing for gloo/CUDA
  4371. if BACKEND == "mpi" or BACKEND == "gloo":
  4372. global_batch_size = world_size
  4373. local_batch_size = 1
  4374. model, ddp_model, input, target = self._prepare_cpu_module(
  4375. group_id, global_batch_size, gradient_as_bucket_view
  4376. )
  4377. if BACKEND == "nccl":
  4378. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  4379. int_devices = rank_to_GPU[rank][:1]
  4380. devices = [torch.device("cuda:" + str(i)) for i in int_devices]
  4381. global_batch_size = world_size
  4382. local_batch_size = len(devices)
  4383. model, ddp_model, input, target = self._prepare_single_device_module(
  4384. rank,
  4385. group_id,
  4386. devices,
  4387. devices,
  4388. global_batch_size,
  4389. gradient_as_bucket_view,
  4390. )
  4391. if ddp_comm_hook is not None:
  4392. ddp_model.register_comm_hook(group_id, ddp_comm_hook)
  4393. def step_model(model, input, target):
  4394. model.train()
  4395. output = model(input)
  4396. loss = F.mse_loss(output, target.to(output.device))
  4397. loss.backward()
4398. # ensure gradient accumulation works under no_grad => no grads are accumulated.
  4399. with torch.no_grad():
  4400. with ddp_model.no_sync():
  4401. ddp_model.train()
  4402. ddp_model(input)
  4403. # check two model parameters over num_iters iterations
  4404. for iteration in range(num_iters):
  4405. step_model(model, input, target)
  4406. ddp_input = input[
  4407. rank * local_batch_size : (rank + 1) * local_batch_size
  4408. ]
  4409. ddp_target = target[
  4410. rank * local_batch_size : (rank + 1) * local_batch_size
  4411. ]
  4412. if iteration % 2 == 0:
  4413. # accumulate grads locally
  4414. with ddp_model.no_sync():
  4415. step_model(ddp_model, ddp_input, ddp_target)
  4416. else:
  4417. # sync grads
  4418. step_model(ddp_model, ddp_input, ddp_target)
  4419. for i, j in zip(model.parameters(), ddp_model.parameters()):
  4420. if not i.requires_grad:
  4421. continue
  4422. if iteration % 2 == 0:
  4423. self.assertNotEqual(i.grad, j.grad)
  4424. else:
  4425. self.assertEqual(i.grad, j.grad)
  4426. # Shuffle the input so that DDP input is different
  4427. torch.manual_seed(1337 + iteration)
  4428. input = input[torch.randperm(global_batch_size)]
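# Iteration pattern above (illustrative): on even iterations the DDP step runs
# inside no_sync(), so gradients stay local and are expected to differ from the
# single-process baseline (assertNotEqual); on odd iterations gradients are
# all-reduced again and must match the baseline (assertEqual).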
  4429. @sandcastle_skip_if(
  4430. BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
  4431. "get_future is only supported on mpi, nccl and gloo",
  4432. )
  4433. @nccl_skip_if_lt_x_gpu(BACKEND, 2)
  4434. def test_accumulate_gradients_no_sync(self):
  4435. """
  4436. Runs _test_accumulate_gradients_no_sync using default inputs
  4437. """
  4438. self._test_accumulate_gradients_no_sync()
  4439. @sandcastle_skip_if(
  4440. BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
  4441. "get_future is only supported on mpi, nccl and gloo",
  4442. )
  4443. @nccl_skip_if_lt_x_gpu(BACKEND, 2)
  4444. def test_accumulate_gradients_no_sync_grad_is_view(self):
  4445. """
4446. Runs _test_accumulate_gradients_no_sync with gradient_as_bucket_view=True
  4447. """
  4448. self._test_accumulate_gradients_no_sync(gradient_as_bucket_view=True)
  4449. @sandcastle_skip_if(
  4450. BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
  4451. "get_future is only supported on mpi, nccl and gloo",
  4452. )
  4453. @nccl_skip_if_lt_x_gpu(BACKEND, 2)
  4454. def test_accumulate_gradients_no_sync_allreduce_hook(self):
  4455. """
4456. Runs multiple iterations of _test_accumulate_gradients_no_sync
4457. using an allreduce hook and validates that the future result is properly
4458. passed as gradients to the reducer.
  4459. """
  4460. world_size = get_world_size()
  4461. def allreduce_hook(
  4462. group_id: object, bucket: dist.GradBucket
  4463. ) -> torch.futures.Future[torch.Tensor]:
  4464. tensors = [bucket.buffer() / world_size]
  4465. return (
  4466. group_id.allreduce(tensors)
  4467. .get_future()
  4468. .then(lambda fut: fut.value()[0])
  4469. )
  4470. self._test_accumulate_gradients_no_sync(
  4471. num_iters=4, ddp_comm_hook=allreduce_hook
  4472. )
  4473. @sandcastle_skip_if(
  4474. BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
  4475. "get_future is only supported on mpi, nccl and gloo",
  4476. )
  4477. @nccl_skip_if_lt_x_gpu(BACKEND, 2)
  4478. def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
  4479. """
4480. Runs multiple iterations of _test_accumulate_gradients_no_sync using an allreduce
4481. hook that also uses then callbacks. The first then callback multiplies the result
4482. by 2, and the second callback divides the result by 2 * world_size. It validates
4483. that the final result is properly passed as gradients to the reducer.
  4484. """
  4485. world_size = get_world_size()
  4486. def allreduce_with_then_hook(
  4487. group_id: object, bucket: dist.GradBucket
  4488. ) -> torch.futures.Future[torch.Tensor]:
  4489. fut = group_id.allreduce([bucket.buffer()]).get_future()
  4490. def mult(fut):
  4491. # Multiply the result by 2.
  4492. return 2 * fut.wait()[0]
  4493. def div(fut):
  4494. # Divide the result by 2 * world_size.
  4495. return fut.wait() / (2 * world_size)
  4496. return fut.then(mult).then(div)
  4497. self._test_accumulate_gradients_no_sync(
  4498. num_iters=4, ddp_comm_hook=allreduce_with_then_hook
  4499. )
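# Worked example for the hook above (illustrative): with world_size == 2 and
# per-rank bucket gradients g0 and g1, the allreduce produces g0 + g1, mult turns
# it into 2 * (g0 + g1), and div returns (g0 + g1) / 2, i.e. the usual gradient
# average, so the end result matches plain DDP allreduce behavior.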
  4500. @sandcastle_skip_if(
  4501. BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
  4502. "get_future is only supported on mpi, nccl and gloo",
  4503. )
  4504. @nccl_skip_if_lt_x_gpu(BACKEND, 2)
  4505. def test_get_future(self):
  4506. def mult(fut):
  4507. return [t * 3 for t in fut.wait()]
  4508. def add(fut):
  4509. return [t + 1 for t in fut.wait()]
  4510. group, group_id, rank = self._init_global_test()
  4511. input = _build_tensor(3, 2)
  4512. if BACKEND == "nccl":
  4513. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  4514. device_id = rank_to_GPU[rank][0]
  4515. input = input.to(device_id)
  4516. fut = group_id.allreduce([input]).get_future()
  4517. res = fut.then(mult).then(add).wait()
  4518. expected = _build_tensor(3, 2 * len(group) * 3 + 1)
  4519. self.assertEqual(res[0], expected)
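# Arithmetic behind `expected` above (illustrative): the input tensor is filled
# with 2, the allreduce sums it across len(group) ranks, mult multiplies by 3 and
# add adds 1, giving 2 * len(group) * 3 + 1 (e.g. 13 for a 2-rank job).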
  4520. @sandcastle_skip_if(
  4521. BACKEND not in DistTestCases.backend_feature["ddp"],
  4522. f"The {BACKEND} backend does not support DistributedDataParallel"
  4523. )
  4524. @skip_if_no_gpu
  4525. def test_DistributedDataParallel(self):
  4526. group, group_id, rank = self._init_global_test()
  4527. rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
  4528. gpus = list(rank_to_GPU[rank])
  4529. for use_bucket_view, static_graph in itertools.product(
  4530. (False, True), (False, True)
  4531. ):
  4532. self._test_DistributedDataParallel(
  4533. gpu_subset=gpus,
  4534. rank=rank,
  4535. gradient_as_bucket_view=use_bucket_view,
  4536. static_graph=static_graph,
  4537. )
  4538. # test set static graph twice
  4539. self._test_DistributedDataParallel(
  4540. gpu_subset=gpus,
  4541. rank=rank,
  4542. gradient_as_bucket_view=use_bucket_view,
  4543. static_graph=static_graph,
  4544. set_static_graph_twice=True,
  4545. )
  4546. # test output_device
  4547. self._test_DistributedDataParallel(
  4548. gpu_subset=gpus,
  4549. rank=rank,
  4550. output_device=torch.device("cuda"),
  4551. gradient_as_bucket_view=use_bucket_view,
  4552. static_graph=static_graph,
  4553. )
  4554. # test device_ids
  4555. gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
  4556. self._test_DistributedDataParallel(
  4557. gpu_subset=gpus_list,
  4558. rank=rank,
  4559. output_device=torch.device("cuda"),
  4560. gradient_as_bucket_view=use_bucket_view,
  4561. static_graph=static_graph,
  4562. )
  4563. def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
  4564. torch.manual_seed(31415)
  4565. # Creates model and optimizer in default precision
  4566. model = copy.deepcopy(DDP_NET).cuda()
  4567. optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
  4568. # Creates a GradScaler once at the beginning of training.
  4569. scaler = GradScaler()
  4570. ddp_model = nn.parallel.DistributedDataParallel(
  4571. model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
  4572. )
  4573. input = torch.randn(dist.get_world_size() * 2, 2).cuda()
  4574. target = torch.randn(dist.get_world_size() * 2, 4).cuda()
  4575. loss_fn = nn.MSELoss()
  4576. # verify grads are none before training
  4577. for p in ddp_model.parameters():
  4578. self.assertTrue(p is not None)
  4579. self.assertTrue(p.grad is None)
  4580. for idx in range(20):
  4581. optimizer.zero_grad()
  4582. # Runs the forward pass with autocasting.
  4583. with autocast():
  4584. output = ddp_model(input)
  4585. loss = loss_fn(output, target)
  4586. # Scales loss. Calls backward() on scaled loss to create scaled gradients.
  4587. # Backward passes under autocast are not recommended.
  4588. # Backward ops run in the same dtype autocast chose for corresponding forward ops.
  4589. scaler.scale(loss).backward()
  4590. # verify grads are not none and are valid during training
  4591. for p in ddp_model.parameters():
  4592. if p.requires_grad:
  4593. self.assertTrue(p.grad is not None)
  4594. self.assertFalse(p.grad.isnan().any())
  4595. self.assertFalse(p.grad.isinf().any())
  4596. # scaler.step() first unscales the gradients of the optimizer's assigned params.
  4597. # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
  4598. # otherwise, optimizer.step() is skipped.
  4599. scaler.step(optimizer)
  4600. # Updates the scale for next iteration.
  4601. scaler.update()
  4602. # Shuffle the input so that DDP input is different
  4603. torch.manual_seed(1337 + idx)
  4604. input = input[torch.randperm(dist.get_world_size() * 2)]
  4605. return ddp_model
  4606. @sandcastle_skip_if(
  4607. BACKEND not in DistTestCases.backend_feature["ddp"],
  4608. f"The {BACKEND} backend does not support DistributedDataParallel"
  4609. )
  4610. @skip_if_no_gpu
  4611. def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
  4612. torch.cuda.set_device(self.rank)
  4613. ddp_model_grad_not_view = self._test_DistributedDataParallel_with_amp(
  4614. grad_is_view=False
  4615. )
  4616. ddp_model_grad_is_view = self._test_DistributedDataParallel_with_amp(
  4617. grad_is_view=True
  4618. )
  4619. for i, j in zip(
  4620. ddp_model_grad_not_view.parameters(),
  4621. ddp_model_grad_is_view.parameters(),
  4622. ):
  4623. self.assertEqual(i, j)
  4624. def _test_DistributedDataParallel_SyncBatchNorm(
  4625. self,
  4626. gpu_subset,
  4627. rank,
  4628. local_bs,
  4629. global_bs,
  4630. offset,
  4631. output_device=None,
  4632. affine=True,
  4633. ):
  4634. # Run a simple end to end DDP model, use result of single node model
  4635. # as baseline
  4636. # cpu training setup
  4637. model = BN_NET if affine else BN_NET_NO_AFFINE
  4638. # single gpu training setup
  4639. model_gpu = copy.deepcopy(model)
  4640. model_gpu.cuda(gpu_subset[0])
  4641. # DDP training setup
  4642. model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
  4643. model_DDP.cuda(gpu_subset[0])
  4644. model_DDP = nn.parallel.DistributedDataParallel(
  4645. model_DDP, device_ids=gpu_subset
  4646. )
4647. # test that the model is serializable/deserializable
  4648. with tempfile.NamedTemporaryFile() as tmp:
  4649. if sys.platform == "win32":
  4650. torch.save(model_DDP, tmp)
  4651. tmp.seek(0)
  4652. model_DDP = torch.load(tmp)
  4653. else:
  4654. torch.save(model_DDP, tmp.name)
  4655. model_DDP = torch.load(tmp.name)
  4656. # data initialization
  4657. input_cpu = torch.randn(global_bs, 2)
  4658. target = torch.randn(global_bs, 4)
  4659. loss = nn.MSELoss()
  4660. # check two model parameters over 5 iterations
  4661. self._test_DDP_niter(
  4662. model_gpu,
  4663. model_DDP,
  4664. input_cpu.cuda(gpu_subset[0]),
  4665. target.cuda(gpu_subset[0]),
  4666. loss,
  4667. local_bs,
  4668. rank,
  4669. global_bs,
  4670. True,
  4671. offset,
  4672. dist.get_world_size(),
  4673. 5 if affine else 2,
  4674. )
  4675. self._barrier()
  4676. def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
  4677. learning_rate = 0.03
  4678. net = torch.nn.parallel.DistributedDataParallel(
  4679. copy.deepcopy(DDP_NET).cuda(),
  4680. device_ids=[self.rank],
  4681. gradient_as_bucket_view=grad_is_view,
  4682. )
  4683. averager = create_averager()
  4684. opt = torch.optim.SGD(net.parameters(), lr=learning_rate)
  4685. net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
  4686. copy.deepcopy(DDP_NET).cuda(),
  4687. device_ids=[self.rank],
  4688. gradient_as_bucket_view=grad_is_view,
  4689. )
  4690. # Process group cannot be pickled in some environments,
4691. # so we cannot deep-copy an averager. See:
  4692. # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
  4693. averager2 = create_averager()
  4694. post_localSGD_opt = self._create_post_localSGD_optimizer(
  4695. net_using_post_localSGD_opt,
  4696. learning_rate,
  4697. averager2
  4698. )
  4699. input = torch.randn(dist.get_world_size() * 2, 2).cuda()
  4700. target = torch.randn(dist.get_world_size() * 2, 4).cuda()
  4701. loss_fn = nn.MSELoss()
  4702. for _ in range(20):
  4703. self._perform_a_train_step(opt, net, loss_fn, input, target)
  4704. averager.average_parameters(net.parameters())
  4705. self._perform_a_train_step(
  4706. post_localSGD_opt,
  4707. net_using_post_localSGD_opt,
  4708. loss_fn,
  4709. input,
  4710. target
  4711. )
  4712. for p1, p2 in zip(net.parameters(), net_using_post_localSGD_opt.parameters()):
  4713. self.assertEqual(p1.data, p2.data)
  4714. # Also check if the built-in step counters are the same to prevent a bug like #74737.
  4715. self.assertEqual(averager.step, averager2.step)
  4716. def _create_periodic_model_averager(self):
  4717. return averagers.PeriodicModelAverager(period=4, warmup_steps=10)
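# Assuming the usual PeriodicModelAverager semantics, the averager above averages
# model parameters every 4 optimizer steps once the first 10 warmup steps have
# passed; no averaging happens during the warmup phase.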
  4718. def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
  4719. return post_localSGD_optimizer.PostLocalSGDOptimizer(
  4720. optim=torch.optim.SGD(net.parameters(), lr=learning_rate),
  4721. averager=averager,
  4722. )
  4723. def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
  4724. optimizer.zero_grad()
  4725. output = net(input)
  4726. loss = loss_fn(output, target)
  4727. loss.backward()
  4728. optimizer.step()
  4729. def _test_post_localSGD_optimizer_step_reload(self, create_averager, chkpt_file):
  4730. learning_rate = 0.03
  4731. net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel(
  4732. copy.deepcopy(DDP_NET).cuda(),
  4733. device_ids=[self.rank]
  4734. )
  4735. averager = create_averager()
  4736. post_localSGD_opt = self._create_post_localSGD_optimizer(
  4737. net_using_post_localSGD_opt,
  4738. learning_rate,
  4739. averager
  4740. )
  4741. averager2 = create_averager()
  4742. dummy_post_localSGD_opt = self._create_post_localSGD_optimizer(
  4743. net_using_post_localSGD_opt,
  4744. learning_rate,
  4745. averager2
  4746. )
  4747. input = torch.randn(dist.get_world_size() * 2, 2).cuda()
  4748. target = torch.randn(dist.get_world_size() * 2, 4).cuda()
  4749. loss_fn = nn.MSELoss()
  4750. for _ in range(20):
  4751. self._perform_a_train_step(
  4752. post_localSGD_opt,
  4753. net_using_post_localSGD_opt,
  4754. loss_fn,
  4755. input,
  4756. target
  4757. )
  4758. if self.rank == 0:
  4759. torch.save({'optimizer_state_dict': post_localSGD_opt.state_dict()}, chkpt_file)
  4760. dist.barrier()
  4761. map_location = {'cuda:%d' % 0: 'cuda:%d' % self.rank}
  4762. checkpoint = torch.load(chkpt_file, map_location=map_location)
  4763. dummy_post_localSGD_opt.load_state_dict(checkpoint['optimizer_state_dict'])
  4764. # Check that we didn't hit the trivial case
  4765. self.assertNotEqual(averager2.step, 0)
4766. # Check that the dummy averager was initialized to the correct value
  4767. self.assertEqual(averager.step, averager2.step)
4768. # Remove the 'step' entry from the checkpoint
4769. # and make sure it is no longer in the state dict.
  4770. del checkpoint['optimizer_state_dict']['step']
  4771. self.assertNotIn('step', checkpoint['optimizer_state_dict'])
4772. # Check that loading a checkpoint without a 'step' entry triggers a warning
  4773. with self.assertWarnsRegex(
  4774. expected_warning=UserWarning,
  4775. expected_regex="Loaded state dict does not contain a step counter for an averager. "
  4776. "Setting step counter to 0."
  4777. ):
  4778. dummy_post_localSGD_opt.load_state_dict(checkpoint['optimizer_state_dict'])
  4779. self.assertEqual(averager2.step, 0)
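# A minimal sketch of the checkpoint round-trip tested above. The path and
# function name are placeholders; `post_localSGD_opt` is assumed to be built
# the same way as in _create_post_localSGD_optimizer:
def _example_checkpoint_post_localSGD_opt(post_localSGD_opt, rank, path="/tmp/ckpt.pt"):
    import torch
    import torch.distributed as dist

    if rank == 0:
        torch.save({"optimizer_state_dict": post_localSGD_opt.state_dict()}, path)
    dist.barrier()  # ensure the file exists before other ranks read it
    # Remap tensors saved from cuda:0 onto this rank's device.
    map_location = {"cuda:0": f"cuda:{rank}"}
    checkpoint = torch.load(path, map_location=map_location)
    # load_state_dict() also restores the averager's built-in step counter; if
    # 'step' is missing, the optimizer warns and resets the counter to 0.
    post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])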
  4780. @skip_if_lt_x_gpu(2)
  4781. @sandcastle_skip_if(
  4782. BACKEND not in DistTestCases.backend_feature["ddp"],
  4783. f"The {BACKEND} backend does not support DistributedDataParallel"
  4784. )
  4785. def test_post_localSGD_optimizer_parity(self):
  4786. torch.cuda.set_device(self.rank)
  4787. self._test_post_localSGD_optimizer_parity(
  4788. self._create_periodic_model_averager,
  4789. grad_is_view=False,
  4790. )
  4791. @skip_if_lt_x_gpu(2)
  4792. @sandcastle_skip_if(
  4793. BACKEND not in DistTestCases.backend_feature["ddp"],
  4794. f"The {BACKEND} backend does not support DistributedDataParallel"
  4795. )
  4796. def test_post_localSGD_optimizer_parity_grad_is_view(self):
  4797. torch.cuda.set_device(self.rank)
  4798. self._test_post_localSGD_optimizer_parity(
  4799. self._create_periodic_model_averager,
  4800. grad_is_view=True,
  4801. )
  4802. def _create_hierarchical_model_averager(self):
  4803. period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
  4804. return hierarchicalSGD.HierarchicalModelAverager(
  4805. period_group_size_dict=period_group_size_dict, warmup_steps=4
  4806. )
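# In period_group_size_dict above, each key is an averaging period (in steps)
# and each value is the size of the subgroup averaged at that period: with
# {2: 2, 4: world_size}, parameters are averaged within pairs of ranks every 2
# steps and across all ranks every 4 steps (once the 4 warm-up steps are done).
# A hedged standalone sketch, assuming a hypothetical 8-rank job:
def _example_hierarchical_averager():
    from collections import OrderedDict
    from torch.distributed.algorithms.model_averaging import (
        hierarchical_model_averager as hierarchicalSGD,
    )

    return hierarchicalSGD.HierarchicalModelAverager(
        # every 2 steps: average within groups of 2 ranks;
        # every 8 steps: average across all 8 ranks.
        period_group_size_dict=OrderedDict([(2, 2), (8, 8)]),
        warmup_steps=4,
    )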
  4807. @skip_if_lt_x_gpu(4)
  4808. @skip_if_odd_worldsize
  4809. @sandcastle_skip_if(
  4810. BACKEND not in DistTestCases.backend_feature["ddp"],
  4811. f"The {BACKEND} backend does not support DistributedDataParallel"
  4812. )
  4813. def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
  4814. torch.cuda.set_device(self.rank)
  4815. self._test_post_localSGD_optimizer_parity(
  4816. self._create_hierarchical_model_averager,
  4817. grad_is_view=False,
  4818. )
  4819. @skip_if_lt_x_gpu(4)
  4820. @skip_if_odd_worldsize
  4821. @sandcastle_skip_if(
  4822. BACKEND not in DistTestCases.backend_feature["ddp"],
  4823. f"The {BACKEND} backend does not support DistributedDataParallel"
  4824. )
  4825. def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(self):
  4826. torch.cuda.set_device(self.rank)
  4827. self._test_post_localSGD_optimizer_parity(
  4828. self._create_hierarchical_model_averager,
  4829. grad_is_view=True,
  4830. )
  4831. @skip_if_lt_x_gpu(2)
  4832. @sandcastle_skip_if(
  4833. BACKEND not in DistTestCases.backend_feature["ddp"],
  4834. f"The {BACKEND} backend does not support DistributedDataParallel"
  4835. )
  4836. def test_post_localSGD_optimizer_step_reload(self):
  4837. torch.cuda.set_device(self.rank)
  4838. with _rank_temp_file() as tmp_file:
  4839. self._test_post_localSGD_optimizer_step_reload(
  4840. self._create_periodic_model_averager,
  4841. tmp_file
  4842. )
  4843. @sandcastle_skip_if(
  4844. BACKEND not in DistTestCases.backend_feature["ddp"],
  4845. f"The {BACKEND} backend does not support DistributedDataParallel"
  4846. )
  4847. @skip_if_no_gpu
  4848. def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
  4849. self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last)
  4850. self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last_3d)
  4851. def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(self, memory_format):
  4852. group, group_id, rank = self._init_global_test()
  4853. num_processes = dist.get_world_size()
  4854. local_bs = 2
  4855. bs_offset = int(rank * 2)
  4856. global_bs = int(num_processes * 2)
  4857. model = ONLY_SBN_NET
  4858. model_gpu = copy.deepcopy(model).cuda(rank)
  4859. model_DDP = nn.parallel.DistributedDataParallel(
  4860. model_gpu, device_ids=[rank]
  4861. )
  4862. shapes = [global_bs, 2, 4, 4] + ([] if memory_format is torch.channels_last else [4])
  4863. input_gpu = (
  4864. torch.randn(*shapes, dtype=torch.float)
  4865. .cuda(rank)
  4866. .to(memory_format=memory_format)
  4867. )
  4868. target_gpu = (
  4869. torch.randn(*shapes, dtype=torch.float)
  4870. .cuda(rank)
  4871. .to(memory_format=memory_format)
  4872. )
  4873. loss = nn.MSELoss()
  4874. # check two model parameters over 5 iterations
  4875. self._test_DDP_niter(
  4876. model_gpu,
  4877. model_DDP,
  4878. input_gpu,
  4879. target_gpu,
  4880. loss,
  4881. local_bs,
  4882. rank,
  4883. global_bs,
  4884. True,
  4885. bs_offset,
  4886. dist.get_world_size(),
  4887. memory_format=memory_format,
  4888. )
  4889. self._barrier()
  4890. @sandcastle_skip_if(
  4891. BACKEND not in DistTestCases.backend_feature["ddp"],
  4892. f"The {BACKEND} backend does not support DistributedDataParallel"
  4893. )
  4894. @skip_if_no_gpu
  4895. def test_DistributedDataParallel_SyncBatchNorm(self):
  4896. group, group_id, rank = self._init_global_test()
  4897. world_size = dist.get_world_size()
  4898. # DDP does not support replicating BN layers within a process, hence
  4899. # testing with one module replica per process
  4900. gpus = [rank]
  4901. local_bs = 2
  4902. bs_offset = int(rank * 2)
  4903. global_bs = int(world_size * 2)
  4904. self._test_DistributedDataParallel_SyncBatchNorm(
  4905. gpu_subset=gpus,
  4906. rank=rank,
  4907. local_bs=local_bs,
  4908. global_bs=global_bs,
  4909. offset=bs_offset,
  4910. )
  4911. # test output_device
  4912. self._test_DistributedDataParallel_SyncBatchNorm(
  4913. gpu_subset=gpus,
  4914. rank=rank,
  4915. local_bs=local_bs,
  4916. global_bs=global_bs,
  4917. offset=bs_offset,
  4918. output_device=torch.device("cuda"),
  4919. )
  4920. # test device_ids
  4921. gpus = [torch.device("cuda:" + str(i)) for i in gpus]
  4922. self._test_DistributedDataParallel_SyncBatchNorm(
  4923. gpu_subset=gpus,
  4924. rank=rank,
  4925. local_bs=local_bs,
  4926. global_bs=global_bs,
  4927. offset=bs_offset,
  4928. output_device=torch.device("cuda"),
  4929. )
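# A minimal sketch of the conversion pattern these SyncBatchNorm tests
# exercise, assuming an initialized default process group (the toy Sequential
# model is a placeholder):
def _example_sync_batchnorm(rank):
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(2, 2, 3), nn.BatchNorm2d(2))
    # Recursively replace every BatchNorm*D layer with SyncBatchNorm so batch
    # statistics are computed across all ranks in the process group.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(rank)
    return nn.parallel.DistributedDataParallel(model, device_ids=[rank])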
  4930. @sandcastle_skip_if(
  4931. BACKEND not in DistTestCases.backend_feature["ddp"],
  4932. f"The {BACKEND} backend does not support DistributedDataParallel"
  4933. )
  4934. @skip_if_no_gpu
  4935. def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
  4936. group, group_id, rank = self._init_global_test()
  4937. world_size = dist.get_world_size()
  4938. # DDP does not support replicating BN layers within a process, hence
  4939. # testing with one module replica per process
  4940. gpus = [rank]
  4941. local_bs = 2
  4942. bs_offset = int(rank * 2)
  4943. global_bs = int(world_size * 2)
  4944. self._test_DistributedDataParallel_SyncBatchNorm(
  4945. gpu_subset=gpus,
  4946. rank=rank,
  4947. local_bs=local_bs,
  4948. global_bs=global_bs,
  4949. offset=bs_offset,
  4950. affine=False,
  4951. )
  4952. @sandcastle_skip_if(
  4953. BACKEND not in DistTestCases.backend_feature["ddp"],
  4954. f"The {BACKEND} backend does not support DistributedDataParallel"
  4955. )
  4956. @skip_if_no_gpu
  4957. def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
  4958. group, group_id, rank = self._init_global_test()
  4959. # DDP does not support replicating BN layers within a process, hence
  4960. # testing with one module replica per process
  4961. gpus = [rank]
  4962. model = nn.BatchNorm1d(2)
  4963. # single gpu training setup
  4964. model_gpu = copy.deepcopy(model)
  4965. model_gpu.cuda(gpus[0])
  4966. # DDP training setup
  4967. model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
  4968. model_DDP.cuda(gpus[0])
  4969. model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
  4970. local_bs = len(gpus) * 2
  4971. global_bs = dist.get_world_size() * local_bs
  4972. input_cpu = torch.randn(global_bs, 2)
  4973. target = torch.randn(global_bs, 2)
  4974. loss = nn.MSELoss()
4975. # Disable cudnn:
4976. # SyncBatchNorm then goes through the native_batch_norm kernel, which avoids
4977. # the numerical issue created by the divergent code path.
  4978. with torch.backends.cudnn.flags(False):
  4979. # check two model parameters over 5 iterations
  4980. self._test_DDP_niter(
  4981. model_gpu,
  4982. model_DDP,
  4983. input_cpu.cuda(gpus[0]),
  4984. target.cuda(gpus[0]),
  4985. loss,
  4986. local_bs,
  4987. rank,
  4988. global_bs,
  4989. True,
  4990. )
  4991. self._barrier()
  4992. @sandcastle_skip_if(
  4993. BACKEND not in DistTestCases.backend_feature["ddp"],
  4994. f"The {BACKEND} backend does not support DistributedDataParallel"
  4995. )
  4996. @skip_if_no_gpu
  4997. @require_world_size(2)
  4998. def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
  4999. group, group_id, rank = self._init_global_test()
  5000. # DDP does not support replicating BN layers within a process, hence
  5001. # testing with one module replica per process
  5002. gpus = [rank]
  5003. model = nn.BatchNorm1d(2)
  5004. # single gpu training setup
  5005. model_gpu = copy.deepcopy(model)
  5006. model_gpu.cuda(gpus[0])
  5007. # DDP training setup
  5008. model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model))
  5009. model_DDP.cuda(gpus[0])
  5010. model_DDP = nn.parallel.DistributedDataParallel(model_DDP, device_ids=gpus)
  5011. local_bs = 1
  5012. global_bs = dist.get_world_size()
  5013. input_cpu = torch.randn(global_bs, 2)
  5014. target = torch.randn(global_bs, 2)
  5015. loss = nn.MSELoss()
5016. # Disable cudnn:
5017. # SyncBatchNorm then goes through the native_batch_norm kernel, which avoids
5018. # the numerical issue created by the divergent code path.
  5019. with torch.backends.cudnn.flags(False):
  5020. # check two model parameters over 5 iterations
  5021. self._test_DDP_niter(
  5022. model_gpu,
  5023. model_DDP,
  5024. input_cpu.cuda(gpus[0]),
  5025. target.cuda(gpus[0]),
  5026. loss,
  5027. local_bs,
  5028. rank,
  5029. global_bs,
  5030. True,
  5031. )
  5032. self._barrier()
  5033. @sandcastle_skip_if(
  5034. BACKEND not in DistTestCases.backend_feature["ddp"],
  5035. f"The {BACKEND} backend does not support DistributedDataParallel"
  5036. )
  5037. @skip_if_no_gpu
  5038. def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
  5039. self,
  5040. ):
  5041. group, group_id, rank = self._init_global_test()
  5042. model = nn.parallel.DistributedDataParallel(
  5043. ONLY_SBN_NET.cuda(rank), device_ids=[rank]
  5044. )
  5045. input_var = []
  5046. for i in range(dist.get_world_size()):
  5047. input_var_rank = torch.cat(
  5048. [
  5049. torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
  5050. torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
  5051. ],
  5052. dim=1,
  5053. )
  5054. input_var.append(input_var_rank)
  5055. all_input_var = torch.cat(
  5056. [
  5057. x.permute(1, 0, 2).contiguous().view(ONLY_SBN_NET.num_features, -1)
  5058. for x in input_var
  5059. ],
  5060. dim=1,
  5061. ).cuda(rank)
  5062. for i in range(100):
  5063. y = model(input_var[rank].cuda(rank))
  5064. y.mean().backward()
  5065. running_mean, running_var = (
  5066. model.module.running_mean,
  5067. model.module.running_var,
  5068. )
  5069. torch.testing.assert_close(running_mean, all_input_var.mean(1))
  5070. torch.testing.assert_close(running_var, all_input_var.var(1))
  5071. @sandcastle_skip_if(
  5072. BACKEND not in DistTestCases.backend_feature["ddp"],
  5073. f"The {BACKEND} backend does not support DistributedDataParallel"
  5074. )
  5075. @skip_if_no_gpu
  5076. def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
  5077. group, group_id, rank = self._init_global_test()
  5078. # only do single GPU per process
  5079. gpus = [rank]
  5080. # cpu training setup
  5081. model = BN_NET
  5082. num_processes = dist.get_world_size()
  5083. local_bs = rank + 2
  5084. bs_offset = int((rank + 3) * rank / 2)
  5085. global_bs = int((num_processes + 3) * num_processes / 2)
  5086. self._test_DistributedDataParallel_SyncBatchNorm(
  5087. gpu_subset=gpus,
  5088. rank=rank,
  5089. local_bs=local_bs,
  5090. global_bs=global_bs,
  5091. offset=bs_offset,
  5092. )
  5093. def _test_ddp_logging_data(self, is_gpu):
  5094. rank = dist.get_rank()
  5095. model_DDP = copy.deepcopy(DDP_NET)
  5096. if is_gpu:
  5097. model_DDP = nn.parallel.DistributedDataParallel(
  5098. model_DDP.cuda(rank), device_ids=[rank]
  5099. )
  5100. else:
  5101. model_DDP = nn.parallel.DistributedDataParallel(model_DDP)
  5102. # dummy data initialization
  5103. local_bs = 2
  5104. batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
  5105. if is_gpu:
  5106. input = input.cuda(rank)
  5107. target = target.cuda(rank)
  5108. model_DDP._set_ddp_runtime_logging_sample_rate(2)
  5109. for idx in range(20):
  5110. offset = rank * local_bs
  5111. # DDP training, DDP scatters subsets of input to nodes/GPUs
  5112. self._test_DDP_helper(
  5113. model_DDP,
  5114. input[offset : offset + local_bs],
  5115. target[offset : offset + local_bs],
  5116. loss,
  5117. 1,
  5118. )
  5119. self._model_step_with_zero_grad(model_DDP)
  5120. # Verify DDP logging data is sampled as expected
5121. # If it has run more than 10 iterations and this is
  5122. # the sampled iteration for measuring run time stats,
  5123. # the run time stats for this idx-th iteration will not
  5124. # be zeros.
  5125. ddp_logging_data = model_DDP._get_ddp_logging_data()
  5126. if idx > 0 and (idx < 10 or idx % 2 == 0):
  5127. self.assertGreaterEqual(
  5128. ddp_logging_data.get("forward_compute_time"), 1
  5129. )
  5130. self.assertGreaterEqual(
  5131. ddp_logging_data.get("backward_compute_time"), 1
  5132. )
  5133. self.assertGreaterEqual(
  5134. ddp_logging_data.get("backward_comm_time"), 1
  5135. )
  5136. self.assertGreaterEqual(
  5137. ddp_logging_data.get("backward_compute_time"),
  5138. ddp_logging_data.get("backward_compute_comm_overlap_time"),
  5139. )
  5140. self.assertGreaterEqual(
  5141. ddp_logging_data.get("backward_comm_time"),
  5142. ddp_logging_data.get("backward_compute_comm_overlap_time"),
  5143. )
  5144. self.assertEqual(ddp_logging_data.get("iteration"), idx)
  5145. elif idx > 0:
  5146. # if the idx-th iteration is not sampled to set runtime stats,
  5147. # ddp_logging_data.iteration will not be updated to current
  5148. # iteration.
  5149. self.assertNotEqual(ddp_logging_data.get("iteration"), idx)
  5150. # Shuffle the input so that DDP input is different
  5151. input = input[torch.randperm(batch_size)]
  5152. return model_DDP
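# A small sketch of inspecting the DDP logging fields used above. Note that
# _set_ddp_runtime_logging_sample_rate and _get_ddp_logging_data are private
# DistributedDataParallel APIs and may change between releases:
def _example_inspect_ddp_logging(ddp_model):
    # Sample runtime stats on every 2nd iteration once the first 10 are done.
    ddp_model._set_ddp_runtime_logging_sample_rate(2)
    data = ddp_model._get_ddp_logging_data()
    # `data` maps field names to values; a few of the fields asserted below:
    for key in ("world_size", "rank", "module_name", "bucket_sizes", "dtypes"):
        print(key, data.get(key))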
  5153. @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
  5154. def test_ddp_logging_data_cpu(self):
  5155. def parse_env(var):
  5156. return os.environ[var] if var in os.environ else "N/A"
  5157. dist.set_debug_level(dist.DebugLevel.INFO)
  5158. group, group_id, rank = self._init_global_test()
  5159. model_DDP = self._test_ddp_logging_data(is_gpu=False)
  5160. ddp_logging_data = model_DDP._get_ddp_logging_data()
  5161. self.assertEqual(ddp_logging_data.get("world_size"), dist.get_world_size())
  5162. self.assertEqual(ddp_logging_data.get("rank"), dist.get_rank())
  5163. self.assertEqual(ddp_logging_data.get("module_name"), "Net")
  5164. self.assertEqual(ddp_logging_data.get("device_ids"), "")
5165. # output_device is -1 by default if it is not set, e.g.
  5166. # output_device of CPU training is -1.
  5167. self.assertEqual(ddp_logging_data.get("output_device"), -1)
  5168. self.assertEqual(ddp_logging_data.get("broadcast_buffers"), 1)
  5169. self.assertEqual(ddp_logging_data.get("bucket_cap_bytes"), 25 * 1024 * 1024)
  5170. self.assertEqual(ddp_logging_data.get("find_unused_parameters"), 0)
  5171. self.assertEqual(ddp_logging_data.get("gradient_as_bucket_view"), 0)
  5172. self.assertEqual(
  5173. ddp_logging_data.get("backend_name"), dist.get_backend(group_id)
  5174. )
  5175. self.assertEqual(ddp_logging_data.get("iteration"), 18)
  5176. params = list(model_DDP.parameters())
  5177. num_params = 0
  5178. param_size = 0
  5179. params = list(filter(lambda parameter: parameter.requires_grad, params))
  5180. for p in params:
  5181. num_params += 1
  5182. param_size += p.numel() * p.element_size()
  5183. self.assertEqual(ddp_logging_data.get("dtypes"), "float")
  5184. self.assertEqual(
  5185. ddp_logging_data.get("total_parameter_size_bytes"), param_size
  5186. )
  5187. self.assertEqual(ddp_logging_data.get("num_parameter_tensors"), num_params)
  5188. self.assertEqual(ddp_logging_data.get("bucket_sizes"), str(param_size))
  5189. self.assertEqual(
  5190. ddp_logging_data.get("master_port"), parse_env("MASTER_PORT")
  5191. )
  5192. self.assertEqual(
  5193. ddp_logging_data.get("master_addr"), parse_env("MASTER_ADDR")
  5194. )
  5195. self.assertEqual(
  5196. ddp_logging_data.get("torch_distributed_debug"),
  5197. parse_env("TORCH_DISTRIBUTED_DEBUG"),
  5198. )
  5199. self.assertEqual(
  5200. ddp_logging_data.get("cuda_visible_devices"),
  5201. parse_env("CUDA_VISIBLE_DEVICES"),
  5202. )
  5203. if ddp_logging_data.get("backend_name") == "gloo":
  5204. self.assertEqual(
  5205. ddp_logging_data.get("gloo_socket_ifname"),
  5206. parse_env("GLOO_SOCKET_IFNAME"),
  5207. )
  5208. self.assertEqual(
  5209. ddp_logging_data.get("gloo_device_transport"),
  5210. parse_env("GLOO_DEVICE_TRANSPORT"),
  5211. )
  5212. default_gloo_threads = 2
  5213. self.assertEqual(
  5214. ddp_logging_data.get("gloo_num_threads"),
  5215. default_gloo_threads,
  5216. )
  5217. self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
  5218. self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
  5219. self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)
  5220. self.assertEqual(ddp_logging_data.get("nccl_debug"), None)
  5221. self.assertEqual(ddp_logging_data.get("nccl_nthreads"), None)
  5222. self.assertEqual(ddp_logging_data.get("nccl_ib_timeout"), None)
  5223. # test runtime logging fields
  5224. # Note: DETAIL debug mode logs DDP logging data to stdout and
  5225. # thus accesses std::map, which fills in a default value for the
  5226. # type if it didn't exist.
  5227. self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
  5228. self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
  5229. self.assertEqual(
  5230. ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
  5231. )
  5232. grad_ready_order = ddp_logging_data.get("prev_iteration_grad_ready_order_indices")
  5233. expected_order = list(reversed([str(x) for x in range(3)]))
  5234. self.assertEqual(grad_ready_order, ", ".join(expected_order))
  5235. bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
  5236. self.assertEqual(bucket_indices, " ".join(expected_order))
5237. # It is hard to assert exact latencies, but we can check that each latency is
5238. # a valid value within the expected range.
  5239. self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
  5240. self.assertGreaterEqual(
  5241. ddp_logging_data.get("avg_backward_compute_time"), 1
  5242. )
  5243. self.assertGreaterEqual(ddp_logging_data.get("avg_backward_comm_time"), 1)
  5244. self.assertGreaterEqual(
  5245. ddp_logging_data.get("avg_backward_compute_time"),
  5246. ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
  5247. )
  5248. self.assertGreaterEqual(
  5249. ddp_logging_data.get("avg_backward_comm_time"),
  5250. ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
  5251. )
  5252. # Test host-side times are roughly in the order that we expect
  5253. fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
  5254. bwd_comp_start_host_side_time = ddp_logging_data.get("backward_compute_time_start")
  5255. bwd_comp_end_host_side_time = ddp_logging_data.get("backward_compute_time_end")
  5256. bwd_comm_start_host_side_time = ddp_logging_data.get("backward_comm_time_start")
  5257. bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
  5258. self.assertGreaterEqual(bwd_comm_end_host_side_time, bwd_comm_start_host_side_time)
  5259. self.assertGreaterEqual(bwd_comm_start_host_side_time, bwd_comp_start_host_side_time)
  5260. self.assertGreaterEqual(bwd_comp_end_host_side_time, bwd_comp_start_host_side_time)
  5261. self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
  5262. # test larger net with mixed data types, verify multiple bucket sizes
  5263. model = LargeNet()
  5264. model.float()
  5265. model.fc1.double()
  5266. model_DDP = nn.parallel.DistributedDataParallel(model, bucket_cap_mb=1.5)
  5267. ddp_logging_data = model_DDP._get_ddp_logging_data()
  5268. params = list(model_DDP.parameters())
  5269. self.assertEqual(
  5270. ddp_logging_data.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
  5271. )
  5272. bucket_sizes = [
  5273. params[1].numel() * params[1].element_size(),
  5274. params[0].numel() * params[0].element_size(),
  5275. ]
  5276. self.assertEqual(
  5277. ddp_logging_data.get("bucket_sizes"),
  5278. ", ".join(str(x) for x in bucket_sizes),
  5279. )
  5280. self.assertEqual(ddp_logging_data.get("dtypes"), "double, float")
  5281. @sandcastle_skip_if(
  5282. BACKEND not in DistTestCases.backend_feature["ddp"],
  5283. f"The {BACKEND} backend does not support DistributedDataParallel"
  5284. )
  5285. @skip_if_no_gpu
  5286. def test_ddp_logging_data_gpu(self):
  5287. group, group_id, rank = self._init_global_test()
  5288. model_DDP = self._test_ddp_logging_data(is_gpu=True)
  5289. ddp_logging_data = model_DDP._get_ddp_logging_data()
  5290. self.assertEqual(ddp_logging_data.get("device_ids"), str(rank))
  5291. self.assertEqual(ddp_logging_data.get("output_device"), rank)
  5292. grad_ready_order = ddp_logging_data.get("prev_iteration_grad_ready_order_indices")
  5293. expected_order = list(reversed([str(x) for x in range(3)]))
  5294. self.assertEqual(grad_ready_order, ", ".join(expected_order))
  5295. bucket_indices = ddp_logging_data.get("rebuilt_per_bucket_param_indices")
  5296. self.assertEqual(bucket_indices, " ".join(expected_order))
  5297. # test runtime logging fields
5298. # It is hard to assert exact latencies, but we can check that each latency is
5299. # a valid value within the expected range.
  5300. self.assertGreaterEqual(ddp_logging_data.get("avg_forward_compute_time"), 1)
  5301. self.assertGreaterEqual(
  5302. ddp_logging_data.get("avg_backward_compute_comm_overlap_time"), 1
  5303. )
  5304. self.assertGreaterEqual(
  5305. ddp_logging_data.get("avg_backward_compute_time"),
  5306. ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
  5307. )
  5308. self.assertGreaterEqual(
  5309. ddp_logging_data.get("avg_backward_comm_time"),
  5310. ddp_logging_data.get("avg_backward_compute_comm_overlap_time"),
  5311. )
  5312. # Test host-side times are roughly in the order that we expect
  5313. fwd_host_side_time = ddp_logging_data.get("forward_compute_time_start")
  5314. bwd_comp_start_host_side_time = ddp_logging_data.get("backward_compute_time_start")
  5315. bwd_comp_end_host_side_time = ddp_logging_data.get("backward_compute_time_end")
  5316. bwd_comm_start_host_side_time = ddp_logging_data.get("backward_comm_time_start")
  5317. bwd_comm_end_host_side_time = ddp_logging_data.get("backward_comm_time_end")
  5318. self.assertGreaterEqual(bwd_comm_end_host_side_time, bwd_comm_start_host_side_time)
  5319. self.assertGreaterEqual(bwd_comm_start_host_side_time, bwd_comp_start_host_side_time)
  5320. self.assertGreaterEqual(bwd_comp_end_host_side_time, bwd_comp_start_host_side_time)
  5321. self.assertGreaterEqual(bwd_comp_start_host_side_time, fwd_host_side_time)
  5322. @sandcastle_skip_if(BACKEND == "nccl", "nccl does not support DDP on CPU models")
  5323. def test_static_graph_api_cpu(self):
  5324. model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
  5325. expected_err = "should be called before training loop starts"
  5326. with self.assertRaisesRegex(RuntimeError, expected_err):
  5327. local_bs = 2
  5328. batch_size, input, target, loss = self._prepare_dummy_data(local_bs)
  5329. offset = dist.get_rank() * local_bs
  5330. # DDP training, DDP scatters subsets of input to nodes/GPUs
  5331. self._test_DDP_helper(
  5332. model_DDP,
  5333. input[offset : offset + local_bs],
  5334. target[offset : offset + local_bs],
  5335. loss,
  5336. 1,
  5337. )
  5338. model_DDP._set_static_graph()
  5339. # Verify error was logged in ddp_logging_data.
  5340. verify_ddp_error_logged(model_DDP, expected_err)
  5341. @skipIfNoTorchVision
  5342. def test_SyncBatchNorm_process_group(self):
5343. # When using `convert_sync_batchnorm` to convert an `nn.Module`,
5344. # it needs to recursively pass the `process_group` into the module when the `SyncBatchNorm`
5345. # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
  5346. process_ids = 0
  5347. process_group = torch.distributed.new_group([process_ids])
  5348. res50_model = torchvision.models.resnet50()
  5349. res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(
  5350. copy.deepcopy(res50_model), process_group
  5351. )
  5352. process_group_sync = res50_model_sync.layer1[0].bn1.process_group
  5353. self.assertEqual(process_group_sync, process_group)
  5354. def _run_reduction_test(
  5355. self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
  5356. ):
  5357. if reduction_fn != dist.all_reduce and dst is None:
  5358. raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
  5359. if dst is not None:
  5360. reduction_fn(tensor, dst, op)
  5361. # Only destination rank tensor is expected to have final result.
  5362. if dist.get_rank() == dst:
  5363. self.assertEqual(tensor, expected_tensor)
  5364. else:
  5365. reduction_fn(tensor, op)
  5366. self.assertEqual(tensor, expected_tensor)
  5367. @require_backend({"nccl"})
  5368. @require_backends_available({"nccl"})
  5369. @skip_if_lt_x_gpu(2)
  5370. def test_nccl_backend_bool_allreduce(self):
  5371. torch.cuda.set_device(self.rank)
  5372. # Run all_reduce with PRODUCT
  5373. element = self.rank % 2 == 0
  5374. for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
  5375. input_tensor = torch.tensor([element, element]).to(self.rank)
  5376. self._run_reduction_test(
  5377. input_tensor, torch.tensor([False, False]).to(self.rank), op
  5378. )
5379. # Ensure that when all ranks contribute True (cast to 1), the reduction
5380. # result is correct.
  5381. input_tensor = torch.tensor([True, True]).to(self.rank)
  5382. expected_tensor = input_tensor.clone()
  5383. self._run_reduction_test(input_tensor, expected_tensor, op)
  5384. # Run all_reduce with SUM
  5385. for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
  5386. input_tensor = torch.tensor([element, element]).to(self.rank)
  5387. self._run_reduction_test(
  5388. input_tensor, torch.tensor([True, True]).to(self.rank), op
  5389. )
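# For bool tensors the reduction ops act as logical ops: PRODUCT and MIN
# behave like AND, SUM and MAX behave like OR. For example, with two ranks
# where rank 0 contributes [True, True] and rank 1 contributes [False, False]:
#   PRODUCT / MIN -> [False, False]   (AND)
#   SUM     / MAX -> [True,  True]    (OR)
# which is what the assertions above and in the reduce test below check.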
  5390. # TODO: NCCL backend does not work correctly for bitwise reduction ops
  5391. # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
  5392. # these once it is supported.
  5393. @require_backend({"nccl"})
  5394. @require_backends_available({"nccl"})
  5395. @skip_if_lt_x_gpu(2)
  5396. def test_nccl_backend_bool_allgather(self):
  5397. torch.cuda.set_device(self.rank)
  5398. inp = {0: [True, True], 1: [False, True]}
  5399. input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
  5400. # Preserve a copy of the tensor to compare against after allgather.
  5401. input_tensor_copy = input_tensor.clone()
  5402. tensor_list = [
  5403. torch.tensor([False, False]).to(self.rank)
  5404. for _ in range(dist.get_world_size())
  5405. ]
  5406. dist.all_gather(tensor_list, input_tensor)
  5407. self.assertEqual(len(tensor_list), dist.get_world_size())
  5408. for i, t in enumerate(tensor_list):
  5409. expected = torch.tensor(inp[i % 2]).to(self.rank)
  5410. self.assertEqual(t, expected)
  5411. # Ensure that the input tensor is not modified, since this collective
  5412. # does not modify its input.
  5413. self.assertEqual(input_tensor_copy, input_tensor)
  5414. @require_backend({"nccl"})
  5415. @require_backends_available({"nccl"})
  5416. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  5417. def test_nccl_backend_bool_reduce(self):
  5418. torch.cuda.set_device(self.rank)
  5419. inp = {0: [True, True], 1: [False, False]}
  5420. # Run reduce() with product op
  5421. for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
  5422. input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
  5423. expected = torch.tensor([False, False]).to(self.rank)
  5424. self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
5425. # Ensure that when all ranks contribute True (cast to 1), the reduction
5426. # result is correct.
  5427. input_tensor = torch.tensor([True, True]).to(self.rank)
  5428. expected_tensor = input_tensor.clone()
  5429. self._run_reduction_test(
  5430. input_tensor, expected_tensor, op, dist.reduce, dst=0
  5431. )
  5432. for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
  5433. input_tensor = torch.tensor(inp[self.rank % 2]).to(self.rank)
  5434. expected = (
  5435. torch.tensor([True, True]).to(self.rank)
  5436. if self.rank == 0
  5437. else input_tensor.clone()
  5438. )
  5439. self._run_reduction_test(input_tensor, expected, op, dist.reduce, dst=0)
  5440. @require_backend({"nccl"})
  5441. @require_backends_available({"nccl"})
  5442. @skip_if_lt_x_gpu(2)
  5443. def test_nccl_backend_bool_broadcast(self):
  5444. tensor_size = 10
  5445. bcast_tensor = torch.tensor(
  5446. [
  5447. (random.random() < 0.5 if self.rank == 0 else False)
  5448. for _ in range(tensor_size)
  5449. ]
  5450. ).to(self.rank)
  5451. dist.broadcast(bcast_tensor, src=0)
  5452. # Now allgather and ensure the tensors are equal.
  5453. tensor_list = [
  5454. torch.tensor([False for _ in range(tensor_size)]).to(self.rank)
  5455. for _ in range(dist.get_world_size())
  5456. ]
  5457. dist.all_gather(tensor_list, bcast_tensor)
  5458. expected = tensor_list[0]
  5459. for tensor in tensor_list[1:]:
  5460. self.assertEqual(tensor, expected)
  5461. @sandcastle_skip_if(
  5462. BACKEND not in DistTestCases.backend_feature["ddp"],
  5463. f"The {BACKEND} backend does not support DistributedDataParallel"
  5464. )
  5465. @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
  5466. def test_DistributedSampler_padding(self):
  5467. # Tests padding of distributed sampler.
  5468. world_size = dist.get_world_size()
  5469. # Simulates the 'casual' dataset size
  5470. dataset_size = 100 + world_size + 1
  5471. dataset = [torch.ones(1).to(self.rank) * i for i in range(dataset_size)]
  5472. # Simulates the 'tiny' dataset size
  5473. dataset_tiny_size = max(world_size // 2 - 1, 1)
  5474. dataset_tiny = [
  5475. torch.ones(1).to(self.rank) * i for i in range(dataset_tiny_size)
  5476. ]
  5477. # Specifying drop_last=True will cause the tail of the data to be dropped.
  5478. dist_sampler = DistributedSampler(dataset=dataset, drop_last=True)
  5479. local_num_samples, local_dataset_size = (
  5480. dist_sampler.num_samples,
  5481. dist_sampler.total_size,
  5482. )
5483. # The effective dataset size should be the greatest integer <= dataset_size
5484. # that is divisible by the world_size. This ensures each rank processes the
5485. # same number of samples.
  5486. effective_dataset_size = (
  5487. math.ceil((dataset_size - world_size) / world_size)
  5488. if dataset_size % world_size != 0
  5489. else dataset_size / world_size
  5490. )
  5491. self.assertEqual(local_num_samples, effective_dataset_size)
  5492. self.assertEqual(local_dataset_size, local_num_samples * world_size)
  5493. indices_list = list(iter(dist_sampler))
  5494. self.assertEqual(len(indices_list), local_num_samples)
  5495. def validate_global_samples(local_num_samples):
  5496. # Ensure that each rank processes the same number of samples.
  5497. world_samples = [
  5498. torch.LongTensor([0]).to(self.rank) for _ in range(world_size)
  5499. ]
  5500. dist.all_gather(
  5501. world_samples, torch.tensor([local_num_samples]).to(self.rank)
  5502. )
  5503. world_samples = [sample.item() for sample in world_samples]
  5504. self.assertEqual(len(set(world_samples)), 1)
  5505. validate_global_samples(local_num_samples)
  5506. # drop_last=False is the default and will add additional indices to be sampled,
  5507. # increasing the effective dataset size.
  5508. dist_sampler_added_samples = DistributedSampler(dataset=dataset)
  5509. local_num_samples, local_dataset_size = (
  5510. dist_sampler_added_samples.num_samples,
  5511. dist_sampler_added_samples.total_size,
  5512. )
  5513. # The effective dataset size is the smallest integer that is >= dataset_size
  5514. # and divisible by the world size.
  5515. self.assertEqual(local_num_samples, math.ceil(dataset_size / world_size))
  5516. self.assertEqual(local_dataset_size, local_num_samples * world_size)
  5517. indices_list = list(iter(dist_sampler_added_samples))
  5518. self.assertEqual(len(indices_list), local_num_samples)
  5519. # Ensure that each rank processes the same number of samples.
  5520. validate_global_samples(local_num_samples)
  5521. # Ensure additional samples are padded even when
5522. # an extremely small dataset is given.
  5523. dist_sampler_added_samples_tiny = DistributedSampler(dataset=dataset_tiny)
  5524. local_num_samples, local_dataset_size = (
  5525. dist_sampler_added_samples_tiny.num_samples,
  5526. dist_sampler_added_samples_tiny.total_size,
  5527. )
  5528. self.assertEqual(
  5529. local_num_samples, math.ceil(dataset_tiny_size / world_size)
  5530. )
  5531. self.assertEqual(local_dataset_size, local_num_samples * world_size)
  5532. indices_list = list(iter(dist_sampler_added_samples_tiny))
  5533. self.assertEqual(len(indices_list), local_num_samples)
  5534. validate_global_samples(local_num_samples)
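# A concrete worked example of the sampler arithmetic checked above, using a
# hypothetical world_size of 4 and dataset_size of 105 (i.e. 100 + 4 + 1):
def _example_sampler_padding_math():
    import math

    world_size, dataset_size = 4, 105
    # drop_last=True: shrink to the largest multiple of world_size (1 sample dropped).
    num_samples_drop = math.ceil((dataset_size - world_size) / world_size)  # 26
    total_size_drop = num_samples_drop * world_size                         # 104
    # drop_last=False (default): pad up to the next multiple of world_size
    # (3 indices are repeated so that every rank still gets 27 samples).
    num_samples_pad = math.ceil(dataset_size / world_size)                  # 27
    total_size_pad = num_samples_pad * world_size                           # 108
    return num_samples_drop, total_size_drop, num_samples_pad, total_size_pad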
  5535. def _test_allgather_object(self, subgroup=None):
  5536. # Only set device for NCCL backend since it must use GPUs.
  5537. gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
  5538. backend = os.environ["BACKEND"]
  5539. if backend == "nccl":
  5540. # Case where rank != GPU device.
  5541. next_rank = (self.rank + 1) % int(self.world_size)
  5542. torch.cuda.set_device(next_rank)
  5543. # If GPU test, add object with GPU tensor
  5544. if backend == "nccl":
  5545. gather_objects.append(Foo(torch.randn(3, 3, device=0)))
  5546. output_gathered = [None for _ in range(dist.get_world_size())]
  5547. dist.all_gather_object(
  5548. output_gathered,
  5549. gather_objects[self.rank % len(gather_objects)],
  5550. group=subgroup,
  5551. )
  5552. for i, val in enumerate(output_gathered):
  5553. expected = gather_objects[i % len(gather_objects)]
  5554. self.assertEqual(val, expected)
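# A minimal sketch of all_gather_object with arbitrary picklable objects,
# assuming the default process group (with NCCL, torch.cuda.set_device(rank)
# must be called first, as done above):
def _example_all_gather_object(rank):
    import torch.distributed as dist

    payload = {"rank": rank, "msg": f"hello from {rank}"}
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, payload)
    # On every rank, gathered[i] now holds the payload contributed by rank i.
    return gathered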
  5555. @require_backend(DistTestCases.backend_feature["gpu"])
  5556. @require_n_gpus_for_nccl_backend(
  5557. int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
  5558. )
  5559. @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
  5560. def test_all_gather_object_default_pg(self):
  5561. return self._test_allgather_object()
  5562. @require_backend(DistTestCases.backend_feature["gpu"])
  5563. @require_n_gpus_for_nccl_backend(
  5564. int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
  5565. )
  5566. @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
  5567. def test_all_gather_object_subgroup(self):
  5568. default = _get_default_group()
  5569. backend = dist.get_backend(default)
  5570. subgroup = dist.new_group(backend=backend)
  5571. return self._test_allgather_object(subgroup=subgroup)
  5572. def _test_gather_object(self, pg=None):
  5573. # Ensure stateful objects can be gathered
  5574. gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
  5575. my_rank = dist.get_rank(pg)
  5576. backend = os.environ["BACKEND"]
  5577. if backend == "nccl":
  5578. # Case where rank != GPU device.
  5579. next_rank = (self.rank + 1) % int(self.world_size)
  5580. torch.cuda.set_device(next_rank)
  5581. # If GPU test, add object with GPU tensor
  5582. if backend == "nccl":
  5583. gather_objects.append(Foo(torch.randn(3, 3, device=my_rank)))
  5584. output_gathered = [None for _ in range(dist.get_world_size(pg))]
  5585. gather_on_rank = 0
  5586. dist.gather_object(
  5587. gather_objects[self.rank % len(gather_objects)],
  5588. object_gather_list=output_gathered
  5589. if my_rank == gather_on_rank
  5590. else None,
  5591. dst=gather_on_rank,
  5592. group=pg
  5593. )
  5594. if my_rank != gather_on_rank:
  5595. self.assertEqual(
  5596. output_gathered, [None for _ in range(dist.get_world_size())]
  5597. )
  5598. else:
  5599. for i, val in enumerate(output_gathered):
  5600. expected = gather_objects[i % len(gather_objects)]
  5601. self.assertEqual(val, expected)
  5602. # Validate errors when objects can't be pickled.
  5603. class Bar:
  5604. pass
  5605. b = Bar()
  5606. gather_objects = [b for _ in range(dist.get_world_size())]
  5607. with self.assertRaisesRegex(AttributeError, "Can't pickle local object"):
  5608. dist.all_gather_object(
  5609. [None for _ in range(dist.get_world_size())],
  5610. gather_objects[self.rank],
  5611. group=pg
  5612. )
  5613. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  5614. @require_backend(DistTestCases.backend_feature["gpu"])
  5615. @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
  5616. def test_gather_object(self):
  5617. return self._test_gather_object()
  5618. @sandcastle_skip_if(BACKEND == "ucc", "CPU tensor ops not supported by UCP TL")
  5619. @require_backend(DistTestCases.backend_feature["gpu"])
  5620. @with_dist_debug_levels(levels=["DETAIL", "OFF", "INFO"])
  5621. def test_gather_object_subgroup(self):
  5622. default = _get_default_group()
  5623. backend = dist.get_backend(default)
  5624. subgroup = dist.new_group(backend=backend)
  5625. return self._test_gather_object(subgroup)
  5626. def validate_net_equivalence(self, net):
  5627. # Helper to validate synchronization of nets across ranks.
  5628. net_module_states = list(net.module.state_dict().values())
  5629. # Check that all tensors in module's state_dict() are equal.
  5630. for t in net_module_states:
  5631. tensor_list = [
  5632. torch.zeros_like(t) for _ in range(dist.get_world_size())
  5633. ]
  5634. dist.all_gather(tensor_list, t)
  5635. for tensor in tensor_list:
  5636. self.assertEqual(tensor, t)
  5637. @skip_if_lt_x_gpu(2)
  5638. @sandcastle_skip_if(
  5639. BACKEND not in DistTestCases.backend_feature["ddp"],
  5640. f"The {BACKEND} backend does not support DistributedDataParallel"
  5641. )
  5642. def test_ddp_sync_module_states(self):
  5643. # Test that after calling _sync_module_states, models across ranks
  5644. # are the same and are equal to the model on the input rank.
  5645. dim = 2
  5646. rank = self.rank
  5647. rank_to_broadcast = 1
  5648. # Seed to ensure that ranks are initialized with different initial models.
  5649. torch.manual_seed(rank)
  5650. model = nn.Linear(dim, dim, bias=False)
  5651. net = torch.nn.parallel.DistributedDataParallel(
  5652. model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
  5653. )
  5654. new_model = nn.Linear(dim, dim, bias=False).cuda(rank)
  5655. net.module = copy.deepcopy(new_model)
  5656. # Assert params are different
  5657. net_module_states = list(net.module.state_dict().values())
  5658. for t in net_module_states:
  5659. tensor_list = [
  5660. torch.zeros_like(t) for _ in range(dist.get_world_size())
  5661. ]
  5662. dist.all_gather(tensor_list, t)
  5663. for i, tensor in enumerate(tensor_list):
  5664. if i == rank:
  5665. self.assertEqual(t, tensor)
  5666. else:
  5667. # tensor from another rank should be different.
  5668. self.assertNotEqual(t, tensor)
  5669. _sync_module_states(
  5670. module=net.module,
  5671. process_group=net.process_group,
  5672. broadcast_bucket_size=net.broadcast_bucket_size,
  5673. src=rank_to_broadcast,
  5674. params_and_buffers_to_ignore=net.parameters_to_ignore
  5675. )
  5676. # Now all model params should be the same.
  5677. self.validate_net_equivalence(net)
  5678. # Since the network params were broadcast from rank_to_broadcast, validate that
  5679. # they are the same as new_model on rank_to_broadcast.
  5680. if rank == rank_to_broadcast:
  5681. expected_states = new_model.state_dict().values()
  5682. for t, expected in zip(net_module_states, expected_states):
  5683. self.assertEqual(t, expected)
  5684. @skip_if_lt_x_gpu(2)
  5685. @sandcastle_skip_if(
  5686. BACKEND not in DistTestCases.backend_feature["ddp"],
  5687. f"The {BACKEND} backend does not support DistributedDataParallel"
  5688. )
  5689. def test_ddp_grad_div_uneven_inputs(self):
  5690. # Test gradient division during training with join() API. If
  5691. # divide_by_initial_world_size=False, we scale by the effective world
  5692. # size when allreducing grads.
  5693. dim = 5
  5694. batch = 1
  5695. grad_scale = 50
  5696. rank = self.rank
  5697. model = nn.Linear(dim, dim, bias=False)
  5698. inp = torch.ones(batch, dim, device=self.rank) * grad_scale
  5699. net = torch.nn.parallel.DistributedDataParallel(
  5700. model.cuda(rank), device_ids=[self.rank], bucket_cap_mb=1
  5701. )
  5702. n_iters = 3
  5703. if self.rank > 0:
  5704. n_iters += 2
  5705. with net.join(divide_by_initial_world_size=False):
  5706. for _ in range(n_iters):
  5707. loss = net(inp).sum()
  5708. loss.backward()
  5709. # The grad is always expected_grad, since we divide by the number
  5710. # of currently active processes and inactive processes contribute
  5711. # zero gradient. If we kept dividing by static initial world
  5712. # size as processes leave, the grad would be smaller.
  5713. expected_grad = torch.ones(dim, dim, device=self.rank) * grad_scale
  5714. param = list(net.parameters())[0]
  5715. self.assertEqual(expected_grad, param.grad)
  5716. # Avoid accumulating grads so that it's the same every iteration
  5717. net.zero_grad()
  5718. torch.cuda.synchronize(device=self.rank)
  5719. # If divide_by_initial_world_size=True (default), we always scale grads
  5720. # by the initial world_size.
  5721. with net.join(divide_by_initial_world_size=True):
  5722. for i in range(n_iters):
  5723. loss = net(inp).sum()
  5724. loss.backward()
  5725. effective_ws = dist.get_world_size()
  5726. if i >= 3:
  5727. effective_ws -= 1
  5728. expected_grad = (
  5729. torch.ones(dim, dim, device=self.rank)
  5730. * grad_scale
  5731. * effective_ws
  5732. ) / dist.get_world_size()
  5733. param = list(net.parameters())[0]
  5734. self.assertEqual(expected_grad, param.grad)
  5735. # Avoid accumulating grad so that it's the same every iteration.
  5736. net.zero_grad()
  5737. torch.cuda.synchronize(device=self.rank)
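# Worked arithmetic behind the assertions above: with inp = grad_scale *
# ones(batch, dim) and loss = net(inp).sum(), every entry of the local weight
# gradient equals grad_scale. Under join(divide_by_initial_world_size=False)
# the allreduced gradient is
#     (num_active_ranks * grad_scale) / num_active_ranks = grad_scale,
# no matter how many ranks have already exhausted their inputs. With the
# default divide_by_initial_world_size=True it becomes
#     (num_active_ranks * grad_scale) / initial_world_size,
# which is what the second loop checks once ranks start joining early.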
  5738. def _test_ddp_profiling(self, profiler_ctx):
  5739. batch = 3
  5740. dim = 10
  5741. num_iters = 6
  5742. torch.cuda.set_device(self.rank)
  5743. model = nn.Linear(dim, dim, bias=False)
  5744. inp = torch.rand(batch, dim, device=self.rank)
  5745. net = torch.nn.parallel.DistributedDataParallel(
  5746. model.cuda(self.rank),
  5747. device_ids=[self.rank],
  5748. )
  5749. profiler_ctx_copy = copy.deepcopy(profiler_ctx)
  5750. with profiler_ctx as prof:
  5751. for i in range(num_iters):
  5752. loss = net(inp).sum()
  5753. loss.backward()
  5754. all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
  5755. events = get_profiling_event(all_reduce_event_name, prof)
  5756. event_count = sum(e.count for e in events)
  5757. self.assertEqual(event_count, num_iters)
  5758. for event in events:
  5759. self.assertTrue(event.is_async)
  5760. self.assertEqual(event.name, all_reduce_event_name)
  5761. broadcast_event_name = f"{dist.get_backend()}:broadcast"
  5762. broadcast_events = get_profiling_event(broadcast_event_name, prof)
  5763. event_count = sum(e.count for e in broadcast_events)
  5764. # Broadcast is called during rebuild_buckets
  5765. self.assertGreaterEqual(event_count, 1)
  5766. for event in broadcast_events:
  5767. self.assertEqual(event.name, broadcast_event_name)
5768. # Run DDP without profiling for a few iterations, then enable the profiler
5769. # for a single pass, and ensure it is recorded. This tests that the
  5770. # thread local state is correctly updated.
  5771. net = torch.nn.parallel.DistributedDataParallel(
  5772. model.cuda(self.rank),
  5773. device_ids=[self.rank],
  5774. find_unused_parameters=True,
  5775. )
  5776. for i in range(3):
  5777. loss = net(inp).sum()
  5778. loss.backward()
  5779. # Now enable the profiler.
  5780. with profiler_ctx_copy as prof:
  5781. loss = net(inp).sum()
  5782. loss.backward()
  5783. events = get_profiling_event(all_reduce_event_name, prof)
  5784. self.assertGreaterEqual(len(events), 1)
  5785. self.assertGreaterEqual(events[0].count, 1)
  5786. self.assertEqual(events[0].name, all_reduce_event_name)
  5787. for event in events:
  5788. self.assertTrue(event.is_async)
  5789. # Ensure searching unused parameters was profiled
  5790. events = get_profiling_event("search_unused_parameters", prof)
  5791. self.assertEqual(len(events), 1)
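# A minimal sketch of profiling DDP communication outside of this helper,
# assuming CUDA is available and the process group is initialized; event names
# follow the "<backend>:<collective>" convention asserted above:
def _example_profile_ddp(net, inp, steps=3):
    import torch

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as prof:
        for _ in range(steps):
            net(inp).sum().backward()
    # Aggregate the recorded events (e.g. "nccl:all_reduce") for inspection.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))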
  5792. @require_backend(DistTestCases.backend_feature["gpu"])
  5793. @require_backends_available(DistTestCases.backend_feature["gpu"])
  5794. @skip_if_lt_x_gpu(2)
  5795. def test_ddp_profiling_autograd_profiler(self):
  5796. autograd_profiler_ctx = torch.autograd.profiler.profile()
  5797. return self._test_ddp_profiling(profiler_ctx=autograd_profiler_ctx)
  5798. @require_backend(DistTestCases.backend_feature["gpu"])
  5799. @require_backends_available(DistTestCases.backend_feature["gpu"])
  5800. @skip_if_lt_x_gpu(2)
  5801. @sandcastle_skip_if(IS_FBCODE, "Kineto in fbcode code causes hang")
  5802. @sandcastle_skip_if(
  5803. IS_MACOS or IS_WINDOWS,
  5804. "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
  5805. )
  5806. def test_ddp_profiling_torch_profiler(self):
  5807. cpu_act = torch.profiler.ProfilerActivity.CPU
  5808. cuda_act = torch.profiler.ProfilerActivity.CUDA
  5809. torch_profiler_ctx = torch.profiler.profile(activities=[cpu_act, cuda_act])
  5810. self._test_ddp_profiling(profiler_ctx=torch_profiler_ctx)
  5811. @skip_if_lt_x_gpu(2)
  5812. @sandcastle_skip_if(
  5813. BACKEND not in DistTestCases.backend_feature["ddp"],
  5814. f"The {BACKEND} backend does not support DistributedDataParallel"
  5815. )
  5816. def test_ddp_join_model_equivalence(self):
  5817. # Verifies equivalence with model training locally and with DDP under
  5818. # the join context manager.
  5819. batch = 3
  5820. dim = 10
  5821. learning_rate = 0.03
  5822. model = nn.Linear(dim, dim, bias=False)
  5823. inp = torch.rand(batch, dim, device=self.rank)
  5824. local_model = copy.deepcopy(model)
  5825. local_model = local_model.cuda(self.rank)
  5826. rank_to_iter_mapping = {
  5827. rank: 2 * (rank + 1) for rank in range(dist.get_world_size())
  5828. }
  5829. # run local model
  5830. local_iters = sum(rank_to_iter_mapping.values())
  5831. local_optim = torch.optim.SGD(local_model.parameters(), lr=learning_rate)
  5832. for _ in range(local_iters):
  5833. local_optim.zero_grad()
  5834. out = local_model(inp)
  5835. loss = out.sum()
  5836. loss.backward()
  5837. local_optim.step()
  5838. # run DDP model with join API
  5839. num_iters = rank_to_iter_mapping[self.rank]
  5840. net = torch.nn.parallel.DistributedDataParallel(
  5841. model.cuda(self.rank), device_ids=[self.rank]
  5842. )
  5843. ddp_optim = torch.optim.SGD(
  5844. model.parameters(), lr=learning_rate * dist.get_world_size()
  5845. )
  5846. with net.join():
  5847. for i in range(num_iters):
  5848. ddp_optim.zero_grad()
  5849. out = net(inp)
  5850. loss = out.sum()
  5851. loss.backward()
  5852. torch.cuda.synchronize(device=self.rank)
  5853. ddp_optim.step()
  5854. # Validate model state dicts are equal
  5855. for (_, local_tensor), (_, dist_tensor) in zip(
  5856. local_model.state_dict().items(), net.module.state_dict().items()
  5857. ):
  5858. self.assertEqual(local_tensor, dist_tensor)
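# A minimal sketch of the uneven-inputs pattern exercised above and in the
# tests below: each rank may have a different number of batches, and
# net.join() keeps collective communication consistent while ranks finish at
# different times (assumes a DDP-wrapped `net`, a per-rank iterable `batches`,
# and an `optimizer`, all placeholders):
def _example_uneven_inputs(net, batches, optimizer):
    with net.join():
        for inp in batches:  # `batches` may have a different length per rank
            optimizer.zero_grad()
            net(inp).sum().backward()
            optimizer.step()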
  5859. def _run_uneven_inputs_test(
  5860. self,
  5861. test_case,
  5862. iteration_mapping,
  5863. find_unused_params,
  5864. ):
  5865. model = test_case.model
  5866. inp = test_case.inp
  5867. rank = self.rank
  5868. sync_interval = test_case.sync_interval
  5869. torch.cuda.set_device(rank)
5870. # Ensure all outstanding GPU work is complete so this test runs independently.
  5871. dist.barrier()
  5872. # Bucket_cap_mb is intentionally low to test allreduce scheduling when
  5873. # there are many buckets.
  5874. net = torch.nn.parallel.DistributedDataParallel(
  5875. model.cuda(rank),
  5876. device_ids=[rank],
  5877. bucket_cap_mb=1,
  5878. find_unused_parameters=find_unused_params,
  5879. )
  5880. # Register hook if specified
  5881. if test_case.hook is not None:
  5882. net.register_comm_hook(test_case.state, test_case.hook)
  5883. print(f"registered hook {test_case.hook}")
  5884. # Determine num iters for this rank via the passed in mapping.
  5885. num_iters = iteration_mapping[rank]
5886. # If we throw when the earliest rank terminates, we should ensure
  5887. # that we iterate for that minimum number of times.
  5888. num_iters_tensor = torch.tensor(
  5889. [num_iters], device=torch.cuda.current_device()
  5890. )
  5891. dist.all_reduce(num_iters_tensor, op=dist.ReduceOp.MIN)
  5892. min_num_iters = num_iters_tensor.item()
  5893. total_iters = 0
  5894. if test_case.throw_on_early_termination:
  5895. if min_num_iters == num_iters:
  5896. # Early termination rank(s)
  5897. exception_ctx = self.assertRaisesRegex(
  5898. RuntimeError, f"Rank {self.rank} exhausted all inputs"
  5899. )
  5900. else:
  5901. # Non early termination rank
  5902. exception_ctx = self.assertRaisesRegex(
  5903. RuntimeError,
  5904. "Detected at least one rank that exhausted inputs.",
  5905. )
  5906. else:
  5907. exception_ctx = suppress()
  5908. with exception_ctx:
  5909. with net.join(
  5910. throw_on_early_termination=test_case.throw_on_early_termination
  5911. ):
  5912. for i in range(num_iters):
5913. # Use net.no_sync() to disable grad synchronization except on every
5914. # sync_interval-th iteration.
  5915. if i % sync_interval != 0:
  5916. context = net.no_sync()
  5917. else:
  5918. context = suppress()
  5919. with context:
  5920. if isinstance(inp, tuple):
  5921. loss = net(*inp).sum()
  5922. else:
  5923. loss = net(inp).sum()
  5924. loss.backward()
  5925. self._model_step(net)
  5926. # Ensure completion of GPU kernels (including allreduce). If the
  5927. # join API is not properly implemented, then this should hang
  5928. # since the allreduce will hang.
  5929. torch.cuda.synchronize(device=rank)
  5930. total_iters += 1
  5931. if test_case.throw_on_early_termination:
  5932. # Ensure we iterated min_num_iters times.
  5933. self.assertEqual(total_iters, min_num_iters)
  5934. else:
  5935. # Ensure we iterated at least min_num_iters times.
  5936. self.assertGreaterEqual(total_iters, min_num_iters)
  5937. # Ensure completion of all GPU kernels.
  5938. torch.cuda.synchronize(device=rank)
  5939. # When throwing on early rank termination, we do not
  5940. # broadcast model state from an authoritative rank. All models
  5941. # should already be in sync.
  5942. if not test_case.throw_on_early_termination:
  5943. self.assertTrue(net._authoritative_rank)
  5944. # All ranks should have agreed on the same authoritative_rank!
  5945. final_rank_tensor = torch.tensor(
  5946. [net._authoritative_rank], device=self.rank
  5947. )
  5948. tensor_list = [
  5949. torch.zeros_like(final_rank_tensor)
  5950. for _ in range(dist.get_world_size())
  5951. ]
  5952. dist.all_gather(tensor_list, final_rank_tensor)
  5953. max_rank = dist.get_world_size() - 1
  5954. self.assertSetEqual(
  5955. {max_rank}, {tensor.item() for tensor in tensor_list}
  5956. )
  5957. # Ensure that all models are the same across ranks after all have joined.
  5958. self.validate_net_equivalence(net)
  5959. # Ensure that running with DDP uneven inputs was logged.
  5960. ddp_logging_data = net._get_ddp_logging_data()
  5961. self.assertTrue(ddp_logging_data.get("join_uneven_inputs"))
  5962. dist.barrier()
  5963. @skip_if_lt_x_gpu(2)
  5964. @sandcastle_skip_if(
  5965. BACKEND not in DistTestCases.backend_feature["ddp"],
  5966. f"The {BACKEND} backend does not support DistributedDataParallel"
  5967. )
  5968. def test_ddp_uneven_inputs_stop_iteration_sync_bn(self):
  5969. # Tests that uneven inputs join handler correctly throws StopIteration
  5970. # for models with SyncBN or general collective comm when
  5971. # throw_on_early_termination=True.
  5972. class ModelWithComm(torch.nn.Module):
  5973. def __init__(self):
  5974. super().__init__()
  5975. self.lin = nn.Linear(2, 40, bias=False)
  5976. def forward(self, x):
  5977. x = self.lin(x)
  5978. dist.all_reduce(x)
  5979. return x
  5980. torch.cuda.set_device(self.rank)
  5981. model_bn = BN_NET
  5982. model_bn = nn.SyncBatchNorm.convert_sync_batchnorm(
  5983. copy.deepcopy(model_bn)
  5984. ).cuda(self.rank)
  5985. comm_model = ModelWithComm().cuda(self.rank)
  5986. model_input = torch.randn(10, 2).cuda(torch.cuda.current_device())
  5987. for model in [model_bn, comm_model]:
  5988. model = torch.nn.parallel.DistributedDataParallel(
  5989. model,
  5990. device_ids=[self.rank],
  5991. )
  5992. min_num_iters = 5
  5993. if self.rank != 0:
  5994. # Early termination rank(s)
  5995. num_iters = min_num_iters
  5996. exception_ctx = self.assertRaisesRegex(
  5997. RuntimeError, f"Rank {self.rank} exhausted all inputs"
  5998. )
  5999. else:
  6000. # Non early termination rank
  6001. num_iters = min_num_iters * 2
  6002. exception_ctx = self.assertRaisesRegex(
  6003. RuntimeError,
  6004. "Detected at least one rank that exhausted inputs.",
  6005. )
  6006. n = 0
  6007. with exception_ctx:
  6008. with model.join(throw_on_early_termination=True):
  6009. for i in range(num_iters):
  6010. loss = model(model_input).sum()
  6011. loss.backward()
  6012. self._model_step(model)
  6013. n += 1
  6014. self.assertEqual(n, min_num_iters)
  6015. # Verify model equivalence
  6016. self.validate_net_equivalence(model)
  6017. @skip_if_lt_x_gpu(2)
  6018. @sandcastle_skip_if(
  6019. BACKEND not in DistTestCases.backend_feature["ddp"],
  6020. f"The {BACKEND} backend does not support DistributedDataParallel"
  6021. )
  6022. def test_ddp_uneven_inputs(self):
  6023. dim = 1000
  6024. batch = 1
  6025. # Create a variety of models to run uneven input tests on.
  6026. large_model = nn.Sequential(
  6027. nn.Conv2d(1, 20, 5),
  6028. nn.ReLU(),
  6029. nn.Conv2d(20, 32, 5),
  6030. nn.ReLU(),
  6031. nn.Conv2d(32, 256, 5),
  6032. nn.ReLU(),
  6033. )
  6034. small_model = nn.Linear(dim, dim, bias=False)
  6035. bn_net = BatchNormNet()
  6036. class UnusedParamModule(nn.Module):
  6037. def __init__(self, unused_params_rank):
  6038. super().__init__()
  6039. self.t0 = Task()
  6040. self.t1 = Task()
  6041. self.unused_params_rank = unused_params_rank
  6042. def task_parameters(self):
  6043. return (self.t0.p, self.t1.p)
  6044. def forward(self, x, rank):
  6045. return (
  6046. self.t1(self.t0(x))
  6047. if rank != self.unused_params_rank
  6048. else self.t1(x)
  6049. )
  6050. unjoined_rank_with_unused_params_model = UnusedParamModule(1)
  6051. joined_rank_with_unused_params_model = UnusedParamModule(0)
  6052. rank = self.rank
  6053. models_to_test = [
  6054. # Network with batchnorm
  6055. DDPUnevenTestInput(
  6056. name="batch_norm_net",
  6057. model=bn_net,
  6058. inp=torch.ones(batch, 2, device=rank),
  6059. sync_interval=1,
  6060. ),
  6061. DDPUnevenTestInput(
  6062. name="large_conv_model",
  6063. model=large_model,
  6064. inp=torch.ones(batch, batch, dim, dim, device=rank),
  6065. sync_interval=1,
  6066. ),
  6067. DDPUnevenTestInput(
  6068. name="small_model",
  6069. model=small_model,
  6070. inp=torch.ones(batch, dim, device=rank),
  6071. sync_interval=1,
  6072. ),
  6073. # Unused parameter test where rank that does not join early has unused params
  6074. DDPUnevenTestInput(
  6075. name="unjoined_rank_with_unused_params_model",
  6076. model=unjoined_rank_with_unused_params_model,
  6077. inp=(torch.ones(batch, 2, device=rank), rank),
  6078. sync_interval=1,
  6079. ),
  6080. # Unused parameter test where rank that does join early has unused params
  6081. DDPUnevenTestInput(
  6082. name="joined_rank_with_unused_params_model",
  6083. model=joined_rank_with_unused_params_model,
  6084. inp=(torch.ones(batch, 2, device=rank), rank),
  6085. sync_interval=1,
  6086. ),
  6087. ]
  6088. # Test models that have hook installed.
  6089. models_with_hook = [
  6090. DDPUnevenTestInput(
  6091. name="small_model_allreduce_hook",
  6092. model=small_model,
  6093. hook=default.allreduce_hook,
  6094. state=None,
  6095. inp=torch.ones(batch, dim, device=rank),
  6096. sync_interval=1,
  6097. ),
  6098. DDPUnevenTestInput(
  6099. name="small_model_power_sgd_hook",
  6100. model=small_model,
  6101. hook=powerSGD.powerSGD_hook,
  6102. state=powerSGD.PowerSGDState(
  6103. process_group=None,
  6104. matrix_approximation_rank=1,
  6105. # Config so that powerSGD runs immediately instead of
  6106. # allreduce.
  6107. start_powerSGD_iter=1,
  6108. warm_start=False,
  6109. use_error_feedback=False,
  6110. ),
  6111. inp=torch.ones(batch, dim, device=rank),
  6112. sync_interval=1,
  6113. ),
  6114. ]
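# The harness is assumed to register each hook on the wrapped model,
# roughly via `ddp_model.register_comm_hook(state, hook)`, so that comm
# hooks are exercised under uneven inputs as well.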
  6115. models_to_test.extend(models_with_hook)
  6116. # Add resnet model if we have torchvision installed.
  6117. if HAS_TORCHVISION:
  6118. resnet_model = torchvision.models.resnet50()
  6119. models_to_test.append(
  6120. DDPUnevenTestInput(
  6121. name="resnet_model",
  6122. model=resnet_model,
  6123. inp=torch.ones(1, 3, 1000, 1000),
  6124. sync_interval=1,
  6125. )
  6126. )
  6127. # Test with no_sync every 2, 3, 4, ... iterations.
  6128. models_with_sync = []
  6129. for i, test_input in enumerate(models_to_test):
  6130. models_with_sync.append(
  6131. DDPUnevenTestInput(
  6132. name=test_input.name,
  6133. model=test_input.model,
  6134. inp=test_input.inp,
  6135. sync_interval=i + 2,
  6136. )
  6137. )
  6138. throw_on_early_term_tests = []
  6139. for test_input in models_to_test:
  6140. throw_on_early_term_tests.append(
  6141. DDPUnevenTestInput(
  6142. name=test_input.name,
  6143. model=test_input.model,
  6144. inp=test_input.inp,
  6145. sync_interval=test_input.sync_interval,
  6146. throw_on_early_termination=True,
  6147. )
  6148. )
  6149. models_to_test.extend(models_with_sync)
  6150. models_to_test.extend(throw_on_early_term_tests)
# Zero-iteration tests cover the case where one process does not train the
# model at all, so we must shadow the broadcast calls made when rebuilding
# buckets.
  6153. baseline_num_iters = [0, 5]
  6154. iteration_offsets = [2, 3, 10]
  6155. num_uneven_ranks = [1]
  6156. if dist.get_world_size() > 2:
  6157. num_uneven_ranks.append(2)
  6158. iteration_mappings = []
  6159. # Generate rank : num_iters mappings for various uneven input scenarios.
  6160. # This includes cases where rank 0 joins early and all other ranks join
  6161. # later, and scenarios where multiple ranks join early, but at different
  6162. # iterations, and later ranks join later.
  6163. for num_early_join_ranks in num_uneven_ranks:
  6164. for baseline_iter in baseline_num_iters:
  6165. for offset in iteration_offsets:
  6166. mapping = {
  6167. rank: baseline_iter
  6168. for rank in range(0, num_early_join_ranks)
  6169. }
# If num_early_join_ranks > 1, early-joining ranks > 0 iterate offset // 2
# more times than rank 0, to test nodes depleting their inputs at different
# times.
  6173. if num_early_join_ranks > 1:
  6174. for rank in mapping.keys():
  6175. if rank > 0:
  6176. mapping[rank] += offset // 2
  6177. mapping.update(
  6178. {
  6179. rank: baseline_iter + offset
  6180. for rank in range(
  6181. num_early_join_ranks, dist.get_world_size()
  6182. )
  6183. }
  6184. )
  6185. iteration_mappings.append(mapping)
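# Example: world_size=2, num_early_join_ranks=1, baseline_iter=0, offset=2
# yields {0: 0, 1: 2}, i.e. rank 0 joins immediately while rank 1 trains for
# two iterations.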
  6186. for (test_case, iteration_mapping) in itertools.product(
  6187. models_to_test, iteration_mappings
  6188. ):
  6189. if self.rank == 0:
print(
f"Running test: {test_case.name} sync interval "
f"{test_case.sync_interval} with iteration mapping "
f"{iteration_mapping}"
)
  6195. self._run_uneven_inputs_test(
  6196. test_case,
  6197. iteration_mapping,
  6198. find_unused_params=("unused_params_model" in test_case.name),
  6199. )
  6200. @skip_if_lt_x_gpu(2)
  6201. @sandcastle_skip_if(
  6202. BACKEND not in DistTestCases.backend_feature["ddp"],
  6203. f"The {BACKEND} backend does not support DistributedDataParallel"
  6204. )
  6205. def test_ddp_uneven_input_join_disable(self):
# Tests that DDP works as expected with even inputs when net.join() is
# entered with enable=False.
  6208. torch.manual_seed(self.rank)
  6209. net = torch.nn.parallel.DistributedDataParallel(
  6210. torch.nn.Linear(1, 1).cuda(self.rank), device_ids=[self.rank]
  6211. )
  6212. inp = torch.ones(1) * self.rank
  6213. n_iters = 5
  6214. world_size = dist.get_world_size()
  6215. with net.join(enable=False):
  6216. for _ in range(n_iters):
  6217. # Clear grads
  6218. grad = net.module.weight.grad
  6219. if grad is not None:
  6220. grad.requires_grad_(False)
  6221. grad.zero_()
  6222. out = net(inp)
  6223. loss = out.sum()
  6224. loss.backward()
  6225. # Validate gradients to ensure that we divide by the correct
  6226. # world_size when join mode is disabled.
  6227. expected_grad = sum(i for i in range(world_size)) / world_size
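# For example, with world_size=2 the local weight grads are 0 and 1 (the
# grad w.r.t. the weight equals the input, ones(1) * rank), so the averaged
# grad is (0 + 1) / 2 = 0.5 on both ranks.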
  6228. self.assertEqual(net.module.weight.grad.item(), expected_grad)
  6229. join_config = net._join_config
  6230. self.assertFalse(join_config.enable)
  6231. self.validate_net_equivalence(net)
  6232. @skip_if_lt_x_gpu(2)
  6233. @sandcastle_skip_if(
  6234. BACKEND not in DistTestCases.backend_feature["ddp"],
  6235. f"The {BACKEND} backend does not support DistributedDataParallel"
  6236. )
  6237. def test_ddp_uneven_input_exception(self):
  6238. # Tests that exceptions during training are correctly propagated by the
  6239. # context manager.
  6240. error_str = "Intentional error"
  6241. class ExceptionModule(nn.Module):
  6242. def __init__(self):
  6243. super().__init__()
  6244. self.param = nn.Parameter(torch.ones(1, requires_grad=True))
  6245. def forward(self, _):
  6246. raise ValueError(error_str)
  6247. exception_module = ExceptionModule()
  6248. net = torch.nn.parallel.DistributedDataParallel(
  6249. exception_module.cuda(self.rank), device_ids=[self.rank]
  6250. )
  6251. inp = torch.ones(1)
  6252. with self.assertRaisesRegex(ValueError, error_str):
  6253. with net.join():
  6254. out = net(inp)
  6255. loss = out.sum()
  6256. loss.backward()
  6257. def _test_broadcast_object_list(self, group=None):
  6258. gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
  6259. # Only set device for NCCL backend since it must use GPUs.
  6260. # Case where rank != GPU device.
  6261. next_rank = (self.rank + 1) % int(self.world_size)
  6262. backend = os.environ["BACKEND"]
  6263. if backend == "nccl":
  6264. torch.cuda.set_device(next_rank)
  6265. src_rank = 0
  6266. # If GPU test, add object with GPU tensor
  6267. if backend == "nccl":
  6268. gather_objects.append(Foo(torch.randn(3, 3, device=0)))
  6269. if IS_FBCODE:
  6270. # Create Tensor with > 2^31 Bytes storage requirements
  6271. # Only on FBCODE as testing OOMs in OSS
  6272. gather_objects.append(Foo(torch.randn(3, 178956971)))
  6273. objects = (
  6274. gather_objects
  6275. if self.rank == src_rank
  6276. else [None for _ in gather_objects]
  6277. )
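# broadcast_object_list works in place: non-src ranks pass a same-length
# list of placeholders (None here), and after the call every slot holds the
# object broadcast from src.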
  6278. # Single object test with device specified. Backend="gloo", device=cpu
  6279. if backend != "nccl":
  6280. single_obj_list = [objects[0]]
  6281. if self.rank != src_rank:
  6282. self.assertNotEqual(
  6283. single_obj_list[0], gather_objects[0]
  6284. )
  6285. dist.broadcast_object_list(
  6286. single_obj_list, src=0, group=group, device=torch.device("cpu")
  6287. )
  6288. self.assertEqual(single_obj_list[0], gather_objects[0])
# Single object test with device specified. Backend="gloo", device=current_device+1
# The test is gated on the GPU count matching the world size, to avoid the
# case where the backend is gloo but multiple GPU devices are not available.
  6292. if backend != "nccl" and torch.cuda.device_count() == int(self.world_size):
  6293. single_obj_list = [objects[0]]
  6294. if self.rank != src_rank:
  6295. self.assertNotEqual(
  6296. single_obj_list[0], gather_objects[0]
  6297. )
  6298. dist.broadcast_object_list(
  6299. single_obj_list, src=0, group=group, device=torch.device(next_rank)
  6300. )
  6301. self.assertEqual(single_obj_list[0], gather_objects[0])
  6302. # Single object test with device specified. Backend="nccl", device=current_device+1
  6303. if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
  6304. single_obj_list = [objects[0]]
  6305. if self.rank != src_rank:
  6306. self.assertNotEqual(
  6307. single_obj_list[0], gather_objects[0]
  6308. )
  6309. dist.broadcast_object_list(
  6310. single_obj_list, src=0, group=group, device=torch.device(next_rank)
  6311. )
  6312. self.assertEqual(single_obj_list[0], gather_objects[0])
  6313. # Single object test: backward compatibility with device unspecified
  6314. single_obj_list = [objects[0]]
  6315. if self.rank != src_rank:
  6316. self.assertNotEqual(single_obj_list[0], gather_objects[0])
  6317. dist.broadcast_object_list(single_obj_list, src=0, group=group)
  6318. self.assertEqual(single_obj_list[0], gather_objects[0])
  6319. # Multiple input objects test
  6320. if self.rank != src_rank:
  6321. self.assertNotEqual(objects, gather_objects)
  6322. dist.broadcast_object_list(objects, src=0, group=group)
  6323. self.assertEqual(objects, gather_objects)
  6324. @require_backend(DistTestCases.backend_feature["gpu"])
  6325. @require_n_gpus_for_nccl_backend(
  6326. int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
  6327. )
  6328. @with_dist_debug_levels(levels=["DETAIL"])
  6329. def test_broadcast_object_list(self):
  6330. return self._test_broadcast_object_list()
  6331. @require_backend(DistTestCases.backend_feature["gpu"])
  6332. @require_n_gpus_for_nccl_backend(
  6333. int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
  6334. )
  6335. @with_dist_debug_levels(levels=["DETAIL"])
  6336. def _test_broadcast_object_list_subgroup(self):
  6337. default = _get_default_group()
  6338. backend = dist.get_backend(default)
  6339. subgroup = dist.new_group(backend=backend)
  6340. return self._test_broadcast_object_list(subgroup)
  6341. def _test_ddp_ignore_params_arg(self, static_graph=False):
  6342. class TestModel(nn.Module):
  6343. def __init__(self, rank):
  6344. self.rank = rank
  6345. super().__init__()
  6346. self.fc1 = nn.Linear(1, 1, bias=False)
  6347. # Proxy that will be materialized to another architecture later.
  6348. # (after wrapping model with DDP)
  6349. if self.rank == 0:
  6350. self.fc2 = nn.Linear(1, 10, bias=False)
  6351. else:
  6352. self.fc2 = nn.Linear(10, 10, bias=False)
  6353. def forward(self, x):
  6354. x = self.fc1(x)
  6355. x = self.fc2(x)
  6356. return x
  6357. device_id = self.rank
# Ensure the test works for both find_unused_parameters and broadcast_buffers settings.
  6359. for (find_unused, broadcast_buffers) in itertools.product(
  6360. [False, True], [False, True]
  6361. ):
  6362. model = TestModel(self.rank).float().to(device_id)
  6363. # Note that the model can have different shape buffers if we pass
  6364. # them in to be ignored as well.
  6365. model.fc2.register_buffer(
  6366. "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank)
  6367. )
  6368. proxy_params = list(model.fc2.parameters())
  6369. proxy_buffers = list(model.fc2.buffers())
  6370. model_fc2_name = [
  6371. module_name
  6372. for module_name, module in model.named_modules()
  6373. if module is model.fc2
  6374. ][0]
  6375. proxy_param_names = [
  6376. f"{model_fc2_name}.{param_name}"
  6377. for param_name, _ in model.fc2.named_parameters()
  6378. ]
  6379. proxy_buffer_names = [
  6380. f"{model_fc2_name}.{buf_name}"
  6381. for buf_name, _ in model.fc2.named_buffers()
  6382. ]
  6383. # Specify that we should ignore proxy_params since it will be
  6384. # materialized later.
  6385. torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
  6386. model, proxy_param_names + proxy_buffer_names
  6387. )
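# The ignore list uses fully qualified names (e.g. "fc2.weight",
# "fc2.ignore_buffer"); DDP will neither broadcast nor reduce them, which is
# what allows fc2 to be re-materialized after wrapping.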
  6388. ddp = torch.nn.parallel.DistributedDataParallel(
  6389. model,
  6390. device_ids=[device_id],
  6391. find_unused_parameters=find_unused,
  6392. broadcast_buffers=broadcast_buffers,
  6393. static_graph=static_graph,
  6394. )
  6395. # Materialize new params. These are not registered in DDP and thus
  6396. # don't have autograd hooks installed on them.
  6397. ddp.module.fc2 = nn.Linear(1, 1, bias=False).to(device_id)
  6398. # Rebuild replicated_module to pick up the changes.
  6399. ddp._build_replicated_tensor_module()
  6400. # local model with the new materialized parameters.
  6401. local_model = copy.deepcopy(ddp.module).cuda(self.rank)
  6402. inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1)
  6403. for i in range(6):
  6404. ddp(inp).sum().backward()
  6405. local_model(inp).sum().backward()
  6406. # materialized param grad is not touched by DDP, so its grad should
  6407. # be the same as if running locally.
  6408. for materialized_param, local_param in zip(
  6409. ddp.module.fc2.parameters(), local_model.fc2.parameters()
  6410. ):
  6411. self.assertEqual(materialized_param.grad, local_param.grad)
  6412. # fc1 parameter grad should still be different, due to allreduce.
  6413. for synced_param, local_param in zip(
  6414. ddp.module.fc1.parameters(), local_model.fc1.parameters()
  6415. ):
  6416. self.assertFalse(synced_param.grad == local_param.grad)
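# fc1 still participates in allreduce and each rank feeds a different input
# (rank + 1), so the averaged grad necessarily differs from the purely local
# grad.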
  6417. # Proxy module grad should not be touched
  6418. for proxy_param in proxy_params:
  6419. self.assertTrue(proxy_param.grad is None)
  6420. # Synchronize since we run multiple iterations of this test, to
  6421. # isolate failure hangs.
  6422. torch.cuda.synchronize(device=self.rank)
  6423. @require_backend(DistTestCases.backend_feature["gpu"])
  6424. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6425. @skip_if_lt_x_gpu(2)
  6426. def test_ddp_ignore_params_arg(self):
  6427. self._test_ddp_ignore_params_arg(static_graph=False)
  6428. self._test_ddp_ignore_params_arg(static_graph=True)
  6429. @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
  6430. @require_backend(DistTestCases.backend_feature["gpu"])
  6431. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6432. @skip_if_lt_x_gpu(2)
  6433. def test_ddp_unused_params_rebuild_buckets_exception(self):
  6434. class ToyModel(nn.Module):
  6435. def __init__(self):
  6436. super().__init__()
  6437. self.net1 = nn.Linear(10, 10, bias=False)
  6438. self.net2 = nn.Linear(10, 10, bias=False)
  6439. def forward(self, x):
  6440. return self.net1(x)
  6441. ddp = torch.nn.parallel.DistributedDataParallel(
  6442. ToyModel().cuda(self.rank), device_ids=[self.rank]
  6443. )
  6444. for i in range(2):
  6445. inp = torch.rand(1, 10)
  6446. if i > 0:
  6447. # On 2nd iteration, this will fail during rebuild_buckets,
  6448. # but we should report an error regarding unused parameters
  6449. # since that is the underlying root cause.
  6450. try:
  6451. ddp(inp).sum().backward()
  6452. except RuntimeError as e:
  6453. msg = str(e)
  6454. verify_ddp_error_logged(ddp, msg)
  6455. expected_strs = [
  6456. ddp_prev_reduction_unfinished_str,
  6457. ddp_recommend_find_unused_params_str,
  6458. ddp_outputs_not_used_in_loss_str,
  6459. ]
  6460. # In debug mode, should show parameters that weren't reduced.
  6461. # Without debug mode, should show suggestion to use debug mode.
  6462. if dist.get_debug_level() == dist.DebugLevel.OFF:
  6463. expected_strs.append(ddp_suggest_debug_mode_str)
  6464. else:
  6465. unreduced_params = ", ".join(["net2.weight"])
  6466. expected_strs.append(
  6467. f"did not receive grad for rank {self.rank}: {unreduced_params}"
  6468. )
  6469. for s in expected_strs:
  6470. self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
  6471. self.assertFalse(ddp_find_unused_params_enabled_str in msg)
  6472. else:
self.fail("DDP unused parameters error not raised.")
  6476. else:
  6477. ddp(inp).sum().backward()
  6478. dist.barrier()
  6479. @require_backend(DistTestCases.backend_feature["gpu"])
  6480. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6481. @skip_if_lt_x_gpu(2)
  6482. def test_ddp_shared_grad_acc_unused_params(self):
  6483. # When find_unused_parameters=True, ensure we mark unused parameters
  6484. # even if they share gradient accumulators.
  6485. class ToyModel(nn.Module):
  6486. def __init__(self):
  6487. super().__init__()
  6488. # net1, bias, and net1.bias are all unused params.
  6489. self.net1 = nn.Linear(10, 5, bias=False)
  6490. self.bias = nn.Parameter(torch.zeros(5))
  6491. # net1.bias and self.bias are names for the same underlying
  6492. # parameter, so they share the same grad acc. This caused
  6493. # the bug reported in https://github.com/pytorch/pytorch/issues/41324.
  6494. self.net1.bias = self.bias
  6495. self.net2 = nn.Linear(10, 5)
  6496. def forward(self, x):
  6497. return self.net2(x).sum()
  6498. torch.cuda.set_device(self.rank)
  6499. model = ToyModel().to(torch.cuda.current_device())
  6500. for static in [True, False]:
  6501. ddp_model = torch.nn.parallel.DistributedDataParallel(
  6502. copy.deepcopy(model),
  6503. device_ids=[self.rank],
  6504. find_unused_parameters=True,
  6505. static_graph=static,
  6506. )
  6507. inp = torch.randn(20, 10, device=self.rank)
  6508. for i in range(6):
  6509. loss = ddp_model(inp)
  6510. # To test https://github.com/pytorch/pytorch/issues/61982
  6511. loss /= 10
  6512. loss.backward()
  6513. @require_backend(DistTestCases.backend_feature["gpu"])
  6514. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6515. @skip_if_lt_x_gpu(2)
  6516. def test_ddp_device(self):
  6517. m = nn.Linear(10, 10).to(self.rank)
  6518. expected_len = 2
  6519. class TensorWrapper:
  6520. __slots__ = ["t", "moved_to_gpu"]
  6521. def __init__(self, t):
  6522. self.t = t
  6523. self.moved_to_gpu = False
  6524. # Handlers for specific types of validation we want to do based on
  6525. # the input type.
  6526. def tuple_and_list_validator(x):
self.assertEqual(len(x), expected_len)
  6528. self.assertEqual(1, len({t.device for t in x}))
  6529. self.assertEqual(x[0].device.index, self.rank)
  6530. return x[0] + x[1]
  6531. def namedtuple_validator(x):
  6532. self.assertEqual(x._fields, EXPECTED_FIELDS)
  6533. self.assertEqual(x.a.device.index, x.b.device.index)
  6534. self.assertEqual(x.a.device.index, self.rank)
  6535. return x.a + x.b
  6536. def custom_type_validator(x):
  6537. self.assertTrue(x.moved_to_gpu or (str(x.t.device) == "cpu"))
  6538. x.t = x.t.to(self.rank)
  6539. x.moved_to_gpu = True
  6540. return x.t
  6541. def dict_validator(x):
  6542. self.assertTrue(EXPECTED_FIELDS[0] in x.keys())
  6543. self.assertTrue(EXPECTED_FIELDS[1] in x.keys())
  6544. self.assertEqual(1, len({t.device for t in x.values()}))
  6545. self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank)
  6546. return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]]
  6547. validators = {
  6548. TensorWrapper: custom_type_validator,
  6549. tuple: tuple_and_list_validator,
  6550. list: tuple_and_list_validator,
  6551. TestNamedTupleInput_0: namedtuple_validator,
  6552. TestNamedTupleInput_1: namedtuple_validator,
  6553. dict: dict_validator,
  6554. }
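# The forward pass below dispatches on the declared input type; e.g. for a
# tuple input, validators[tuple](x) asserts that both tensors were moved to
# this rank's device before combining them.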
  6555. class ToyModel(torch.nn.Module):
  6556. def __init__(_self): # noqa: B902
  6557. super().__init__()
  6558. _self.lin = nn.Linear(10, 10, bias=False)
  6559. def forward(_self, x, expected_type): # noqa: B902
# Similar to scatter, the recursive .to() in the single-device case does
# not move tensors that are wrapped in a custom type.
  6562. self.assertTrue(isinstance(x, expected_type))
  6563. fwd_tensor = validators[expected_type](x)
  6564. return _self.lin(fwd_tensor)
  6565. model = torch.nn.parallel.DistributedDataParallel(
  6566. ToyModel().to(self.rank), device_ids=[self.rank]
  6567. )
  6568. def train_iter(inp, input_type):
  6569. for _ in range(4):
  6570. out = model(inp, input_type)
  6571. out.sum().backward()
  6572. # CPU tuple input, should be moved to the proper device before call
  6573. # to forward.
  6574. inp = tuple(torch.randn(10, 10) for _ in range(expected_len))
  6575. train_iter(inp, tuple)
  6576. # List CPU input, should be moved to proper device before call to
  6577. # forward.
  6578. inp = [torch.randn(10, 10) for _ in range(expected_len)]
  6579. train_iter(inp, list)
  6580. # Custom type containing tensor. The type is maintained, but the
  6581. # device is not propagated (which is what happens with scatter too)
  6582. inp = TensorWrapper(torch.randn(10, 10))
  6583. train_iter(inp, TensorWrapper)
  6584. # NamedTuple input. The type should be maintained and tensor inputs
  6585. # should be moved to the correct device as in scatter.
  6586. batch = 5
  6587. dim = 10
  6588. a = torch.rand(batch, dim)
  6589. b = torch.rand(batch, dim)
  6590. inp = TestNamedTupleInput_0(a, b)
  6591. train_iter(inp, type(inp))
  6592. inp = TestNamedTupleInput_1(a, b)
  6593. train_iter(inp, type(inp))
  6594. # dictionary input.
  6595. inp = {
  6596. EXPECTED_FIELDS[0]: a,
  6597. EXPECTED_FIELDS[1]: b,
  6598. }
  6599. train_iter(inp, type(inp))
  6600. @require_backend(DistTestCases.backend_feature["gpu"])
  6601. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6602. @skip_if_lt_x_gpu(2)
  6603. def test_ddp_namedtuple(self):
  6604. batch = 5
  6605. dim = 10
  6606. a = torch.rand(batch, dim, device=self.rank)
  6607. b = torch.rand(batch, dim, device=self.rank)
  6608. class NamedTupleModule(torch.nn.Module):
  6609. def __init__(_self): # noqa: B902
  6610. super().__init__()
  6611. _self.lin = nn.Linear(10, 1)
  6612. def forward(_self, input, expected_type): # noqa: B902
  6613. # Without NamedTuple support, this would be of type tuple.
  6614. self.assertTrue(
  6615. isinstance(input, expected_type),
  6616. f"Expected type {expected_type} but got {type(input)}",
  6617. )
  6618. self.assertEqual(input._fields, EXPECTED_FIELDS)
  6619. self.assertEqual(a, input.a)
  6620. self.assertEqual(b, input.b)
  6621. return _self.lin(torch.mul(input.a, input.b))
  6622. model = torch.nn.parallel.DistributedDataParallel(
  6623. NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
  6624. )
  6625. inp = TestNamedTupleInput_0(a, b)
  6626. # The following would fail if DDP does not propagate NamedTuples correctly.
  6627. model(inp, type(inp))
  6628. inp = TestNamedTupleInput_1(a, b)
  6629. model(inp, type(inp))
  6630. @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
  6631. @require_backend(DistTestCases.backend_feature["gpu"])
  6632. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6633. @skip_if_lt_x_gpu(2)
  6634. def test_ddp_control_flow_same_across_ranks(self):
  6635. # Control flow that is the same across ranks.
  6636. batch = 20
  6637. dim = 10
  6638. world_size = dist.get_world_size()
  6639. torch.cuda.set_device(self.rank)
  6640. model = torch.nn.parallel.DistributedDataParallel(
  6641. ControlFlowToyModel().cuda(self.rank),
  6642. device_ids=[self.rank],
  6643. find_unused_parameters=True,
  6644. )
  6645. random_input = torch.randn(batch, dim, device=self.rank)
  6646. ones_input = torch.ones(batch, dim, device=self.rank)
  6647. for i in range(6):
  6648. if i % 2 == 0:
  6649. out = model(random_input)
  6650. else:
  6651. out = model(ones_input)
  6652. loss = out.sum()
  6653. loss.backward()
  6654. # On even iterations, 2nd param goes unused, on odd iterations,
  6655. # it is used.
  6656. local_used_map = model.reducer._get_local_used_map()
  6657. if i % 2 == 0:
  6658. expected = torch.tensor(
  6659. [world_size, 0], device=self.rank, dtype=torch.int32
  6660. )
  6661. else:
  6662. expected = torch.tensor(
  6663. [world_size, world_size], device=self.rank, dtype=torch.int32
  6664. )
  6665. # Validate parameter usage.
  6666. variable_usage_tensor = local_used_map
  6667. self.assertEqual(variable_usage_tensor, expected)
  6668. # Validate appropriate error message when DDP is used with
  6669. # find_unused_parameters=False.
  6670. model = torch.nn.parallel.DistributedDataParallel(
  6671. ControlFlowToyModel().cuda(self.rank),
  6672. device_ids=[self.rank],
  6673. find_unused_parameters=False,
  6674. )
  6675. for i in range(2):
  6676. if i == 0:
  6677. loss = model(random_input).sum()
  6678. loss.backward()
  6679. else:
  6680. try:
  6681. loss = model(random_input).sum()
  6682. loss.backward()
  6683. except RuntimeError as e:
  6684. msg = str(e)
  6685. verify_ddp_error_logged(model, msg)
  6686. # 2nd linear layer is unused
  6687. unused_param_index = 1
  6688. expected_strs = [
  6689. ddp_prev_reduction_unfinished_str,
  6690. ddp_recommend_find_unused_params_str,
  6691. ddp_outputs_not_used_in_loss_str,
  6692. f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
  6693. ]
  6694. # In debug mode, should show parameters that weren't reduced.
  6695. # Without debug mode, should show suggestion to use debug mode.
  6696. if dist.get_debug_level() == dist.DebugLevel.OFF:
  6697. expected_strs.append(ddp_suggest_debug_mode_str)
  6698. else:
  6699. unreduced_params = ", ".join(["lin2.weight"])
  6700. expected_strs.append(
  6701. f"did not receive grad for rank {self.rank}: {unreduced_params}"
  6702. )
  6703. for s in expected_strs:
  6704. self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
  6705. self.assertFalse(ddp_find_unused_params_enabled_str in msg)
  6706. else:
self.fail("DDP error not raised")
  6708. dist.barrier()
  6709. @require_backend(DistTestCases.backend_feature["gpu"])
  6710. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6711. @skip_if_lt_x_gpu(2)
  6712. def test_invalid_static_graph(self):
  6713. world_size = dist.get_world_size()
  6714. torch.cuda.set_device(self.rank)
  6715. model = torch.nn.parallel.DistributedDataParallel(
  6716. ControlFlowToyModel().cuda(self.rank),
  6717. device_ids=[self.rank],
  6718. static_graph=True,
  6719. )
  6720. random_input = torch.randn(20, 10, device=self.rank)
  6721. ones_input = torch.ones(20, 10, device=self.rank)
# A parameter that was unused in the first iteration gets used in the
# second iteration.
  6724. expected_err = "Your training graph has changed in this iteration"
  6725. with self.assertRaisesRegex(RuntimeError, expected_err):
  6726. for i in range(2):
  6727. if i % 2 == 0:
  6728. out = model(random_input)
  6729. else:
  6730. out = model(ones_input)
  6731. loss = out.sum()
  6732. loss.backward()
  6733. verify_ddp_error_logged(model, expected_err)
# A parameter that was used in the first iteration goes unused in the
# second iteration.
  6736. with self.assertRaisesRegex(
  6737. RuntimeError,
  6738. "Expected to have finished reduction in the prior iteration "
  6739. "before starting a new one. This error indicates that your "
  6740. "training graph has changed in this iteration, "
  6741. "e.g., one parameter is used in first iteration, "
  6742. "but then got unused in the second iteration. "
  6743. "this is not compatible with static_graph set to True.\n"
  6744. "Parameter indices which did not receive grad for"
  6745. ):
  6746. for i in range(2):
  6747. if i % 2 != 0:
  6748. out = model(random_input)
  6749. else:
  6750. out = model(ones_input)
  6751. loss = out.sum()
  6752. loss.backward()
  6753. verify_ddp_error_logged(model, "Expected to have finished reduction")
  6754. @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
  6755. @require_backend(DistTestCases.backend_feature["gpu"])
  6756. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6757. @skip_if_lt_x_gpu(2)
  6758. def test_ddp_control_flow_different_across_ranks(self):
  6759. # Control flow that is different across ranks.
  6760. batch = 20
  6761. dim = 10
  6762. class ToyModel(nn.Module):
  6763. def __init__(self, rank):
  6764. super().__init__()
  6765. self.lin1 = nn.Linear(10, 10, bias=False)
  6766. self.lin2 = nn.Linear(10, 10, bias=False)
  6767. self.rank = rank
  6768. def forward(self, x):
  6769. # Control-flow that is rank and input dependent for the
  6770. # model.
  6771. use_second_layer = (
  6772. torch.equal(x, torch.ones(batch, dim, device=x.device))
  6773. and self.rank == 1
  6774. )
  6775. if use_second_layer:
  6776. return self.lin2(F.relu(self.lin1(x)))
  6777. else:
  6778. return F.relu(self.lin1(x))
  6779. world_size = dist.get_world_size()
  6780. torch.cuda.set_device(self.rank)
  6781. model = torch.nn.parallel.DistributedDataParallel(
  6782. ToyModel(self.rank).cuda(self.rank),
  6783. device_ids=[self.rank],
  6784. find_unused_parameters=True,
  6785. )
  6786. random_input = torch.randn(batch, dim, device=self.rank)
  6787. ones_input = torch.ones(batch, dim, device=self.rank)
  6788. for i in range(6):
  6789. if i % 2 == 0:
  6790. out = model(random_input)
  6791. else:
  6792. out = model(ones_input)
  6793. loss = out.sum()
  6794. loss.backward()
  6795. # On even iterations, 2nd param goes unused, on odd iterations,
  6796. # it is used only on rank 1.
  6797. local_used_map = model.reducer._get_local_used_map()
  6798. if i % 2 == 0:
  6799. expected = torch.tensor(
  6800. [world_size, 0], device=self.rank, dtype=torch.int32
  6801. )
  6802. else:
  6803. expected = torch.tensor(
  6804. [world_size, 1], device=self.rank, dtype=torch.int32
  6805. )
  6806. variable_usage_tensor = local_used_map
  6807. # Validate parameter usage. On odd iterations, 2nd param is only
  6808. # used on rank 1.
  6809. self.assertEqual(variable_usage_tensor, expected)
  6810. # Validate appropriate error message when DDP is used with
  6811. # find_unused_parameters=False.
  6812. model = torch.nn.parallel.DistributedDataParallel(
  6813. ToyModel(self.rank).cuda(self.rank),
  6814. device_ids=[self.rank],
  6815. find_unused_parameters=False,
  6816. )
  6817. for i in range(2):
  6818. if i == 0:
  6819. loss = model(random_input).sum()
  6820. loss.backward()
  6821. else:
  6822. try:
  6823. loss = model(random_input).sum()
  6824. loss.backward()
  6825. except RuntimeError as e:
  6826. msg = str(e)
  6827. verify_ddp_error_logged(model, msg)
  6828. unused_param_index = 1
  6829. expected_strs = [
  6830. ddp_prev_reduction_unfinished_str,
  6831. ddp_recommend_find_unused_params_str,
  6832. ddp_outputs_not_used_in_loss_str,
  6833. f"Parameter indices which did not receive grad for rank {self.rank}: {unused_param_index}",
  6834. ]
  6835. # In debug mode, should show parameters that weren't reduced.
  6836. # Without debug mode, should show suggestion to use debug mode.
  6837. if dist.get_debug_level() == dist.DebugLevel.OFF:
  6838. expected_strs.append(ddp_suggest_debug_mode_str)
  6839. else:
  6840. unreduced_params = ", ".join(["lin2.weight"])
  6841. expected_strs.append(
  6842. f"did not receive grad for rank {self.rank}: {unreduced_params}"
  6843. )
  6844. for s in expected_strs:
  6845. self.assertTrue(s in msg, f"Expected {s} to be in {msg}")
  6846. self.assertFalse(ddp_find_unused_params_enabled_str in msg)
  6847. else:
self.fail("DDP error not raised")
  6849. dist.barrier()
  6850. @require_backend({"gloo"})
  6851. def test_scatter_object_list(self):
  6852. src_rank = 0
  6853. scatter_list = (
  6854. COLLECTIVES_OBJECT_TEST_LIST
  6855. if self.rank == src_rank
  6856. else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
  6857. )
  6858. world_size = dist.get_world_size()
  6859. scatter_list = scatter_list[:world_size]
  6860. i = 0
  6861. while len(scatter_list) < world_size:
  6862. scatter_list.append(scatter_list[i])
  6863. i += 1
  6864. output_obj_list = [None]
  6865. dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank)
  6866. self.assertEqual(
  6867. output_obj_list[0],
  6868. COLLECTIVES_OBJECT_TEST_LIST[
  6869. self.rank % len(COLLECTIVES_OBJECT_TEST_LIST)
  6870. ],
  6871. )
  6872. # Ensure errors are raised upon incorrect arguments.
  6873. with self.assertRaisesRegex(
  6874. RuntimeError,
  6875. "Expected argument scatter_object_output_list to be a list of size at least 1.",
  6876. ):
  6877. dist.scatter_object_list([], scatter_list, src=src_rank)
  6878. def _generate_sparse_tensors_for_bucket_assignment_test(self):
  6879. tensors = [
  6880. torch.empty([50], dtype=torch.float),
  6881. torch.empty([25], dtype=torch.double),
  6882. torch.empty([50], dtype=torch.float),
  6883. torch.empty([25], dtype=torch.double),
  6884. torch.empty([50], dtype=torch.float),
  6885. torch.empty([25], dtype=torch.double),
  6886. ]
  6887. tensors_sparse = [t.to_sparse() for t in tensors]
  6888. return tensors_sparse
  6889. def _test_compute_bucket_assignment_by_size(self, use_logger):
  6890. group_gloo = dist.new_group(
  6891. timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
  6892. )
  6893. # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
  6894. # determinism.
  6895. os.environ["NCCL_BLOCKING_WAIT"] = "1"
  6896. group_to_use = dist.new_group(
  6897. backend=dist.get_backend(), timeout=timedelta(seconds=5)
  6898. )
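# With NCCL_BLOCKING_WAIT set, collectives on group_to_use roughly block the
# caller and surface the 5 s timeout as a RuntimeError, rather than relying
# on async error handling to tear the process down.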
  6899. torch.cuda.set_device(self.rank)
  6900. # Create a valid model. The constructor initializes the logger that we use later.
  6901. # We never actually use the rest of the model - we only need its logger.
  6902. net = EmbeddingNetDifferentParams(0)
  6903. net = torch.nn.parallel.DistributedDataParallel(
  6904. net.to(self.rank),
  6905. device_ids=[self.rank],
  6906. process_group=group_to_use,
  6907. )
# If we don't pass a logger, we can only check that an exception was thrown.
  6909. expected_err = "No support for sparse tensors."
  6910. with self.assertRaisesRegex(RuntimeError, expected_err):
  6911. tensors_sparse = self._generate_sparse_tensors_for_bucket_assignment_test()
  6912. if use_logger:
  6913. result = dist._compute_bucket_assignment_by_size(
  6914. tensors_sparse,
  6915. [400],
  6916. logger=net.logger)
  6917. else:
  6918. result = dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
  6919. if use_logger:
  6920. verify_ddp_error_logged(net, expected_err)
  6921. # Perform gloo-based barrier to ensure one rank doesn't exit test
  6922. # early which causes failure with Barrier.sync.
  6923. dist.barrier(group_gloo)
  6924. @require_backend(DistTestCases.backend_feature["gpu"])
  6925. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6926. @skip_if_lt_x_gpu(2)
  6927. def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
  6928. self._test_compute_bucket_assignment_by_size(use_logger=False)
  6929. @require_backend(DistTestCases.backend_feature["gpu"])
  6930. @require_backends_available(DistTestCases.backend_feature["gpu"])
  6931. @skip_if_lt_x_gpu(2)
  6932. def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
  6933. self._test_compute_bucket_assignment_by_size(use_logger=True)
  6934. def _determine_expected_error_verify_model_across_rank(
  6935. self,
  6936. group_to_use,
  6937. diff_num_params=False
  6938. ):
  6939. # When running with NCCL backend, we don't expect an error on rank 0,
  6940. # rather, it will be taken down by NCCL_ASYNC_ERROR_HANDLING. When
  6941. # running with Gloo or with debug mode wrapper, we expect the error
  6942. # to be caught inline.
# All ranks report the same error when the number of parameters mismatches,
# since the implementation uses an allgather.
  6945. if diff_num_params:
  6946. expected_err = "DDP expects same model across all ranks"
  6947. ctx = self.assertRaisesRegex(RuntimeError, expected_err)
  6948. return ctx, expected_err
  6949. is_detail_dbg_mode = (
  6950. dist.get_debug_level() == dist.DebugLevel.DETAIL
  6951. )
  6952. if self.rank == 0:
  6953. if dist.get_backend(group_to_use) == dist.Backend.NCCL and not is_detail_dbg_mode:
  6954. expected_err = "Caught collective operation timeout"
  6955. ctx = self.assertRaisesRegex(RuntimeError, expected_err)
  6956. else:
  6957. expected_err = None
  6958. ctx = self.assertRaises(RuntimeError)
  6959. else:
  6960. expected_err = "appears not to match"
  6961. ctx = self.assertRaisesRegex(RuntimeError, expected_err)
  6962. return ctx, expected_err
  6963. def _test_verify_model_across_rank(self, use_logger):
  6964. group_gloo = dist.new_group(
  6965. timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
  6966. )
  6967. # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
  6968. # determinism.
  6969. os.environ["NCCL_BLOCKING_WAIT"] = "1"
  6970. group_to_use = dist.new_group(
  6971. backend=dist.get_backend(), timeout=timedelta(seconds=5)
  6972. )
  6973. torch.cuda.set_device(self.rank)
  6974. ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
  6975. # Create a valid model. The constructor initializes the logger that we use later.
  6976. net = EmbeddingNetDifferentParams(0)
  6977. net = torch.nn.parallel.DistributedDataParallel(
  6978. net.to(self.rank),
  6979. device_ids=[self.rank],
  6980. process_group=group_to_use,
  6981. )
  6982. # Modify the model so that the number of parameters are different for each rank.
  6983. # This will cause a RuntimeError to be thrown below in _verify_param_shape_across_processes,
  6984. # so we can check if the correct error is thrown and is logged.
  6985. # We can't do this in the constructor above otherwise the logger will
  6986. # not be properly initialized.
  6987. net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
# If we pass a logger, we can verify that the error was logged.
  6989. with ctx:
  6990. if use_logger:
  6991. _verify_param_shape_across_processes(
  6992. net.process_group,
  6993. list(net.parameters()),
  6994. net.logger
  6995. )
  6996. else:
  6997. _verify_param_shape_across_processes(
  6998. net.process_group,
  6999. list(net.parameters())
  7000. )
  7001. # Should only be run by rank 0, and blocking_wait catches and
  7002. # reports exception.
  7003. dist.barrier(group_to_use)
  7004. # We don't check when self.rank != 0 because the logger doesn't log
  7005. # the error "Caught collective operation" as that is not thrown in the reducer.
  7006. if use_logger and self.rank != 0:
  7007. verify_ddp_error_logged(net, expected_err)
  7008. # Perform gloo-based barrier to ensure one rank doesn't exit test
  7009. # early which causes failure with Barrier.sync.
  7010. dist.barrier(group_gloo)
  7011. @require_backend(DistTestCases.backend_feature["gpu"])
  7012. @require_backends_available(DistTestCases.backend_feature["gpu"])
  7013. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  7014. @skip_if_lt_x_gpu(2)
  7015. def test_verify_model_across_rank_with_logger(self):
  7016. self._test_verify_model_across_rank(use_logger=True)
  7017. @require_backend(DistTestCases.backend_feature["gpu"])
  7018. @require_backends_available(DistTestCases.backend_feature["gpu"])
  7019. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  7020. @skip_if_lt_x_gpu(2)
  7021. def test_verify_model_across_rank_without_logger(self):
  7022. self._test_verify_model_across_rank(use_logger=False)
  7023. def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
  7024. with ctx:
  7025. net = torch.nn.parallel.DistributedDataParallel(
  7026. net.to(self.rank),
  7027. device_ids=[self.rank],
  7028. process_group=ddp_group
  7029. )
  7030. # Should only be run by rank 0, and blocking_wait catches and
  7031. # reports exception.
  7032. dist.barrier(ddp_group)
  7033. # can't use verify_ddp_error_logged here because net was never properly constructed
  7034. # Perform gloo-based barrier to ensure one rank doesn't exit test
  7035. # early which causes failure with Barrier.sync.
  7036. dist.barrier(group_gloo)
  7037. @require_backend(DistTestCases.backend_feature["gpu"])
  7038. @require_backends_available(DistTestCases.backend_feature["gpu"])
  7039. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  7040. @skip_if_lt_x_gpu(2)
  7041. def test_ddp_model_diff_shape_across_ranks(self):
  7042. group_gloo = dist.new_group(
  7043. timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
  7044. )
  7045. # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
  7046. # determinism.
  7047. os.environ["NCCL_BLOCKING_WAIT"] = "1"
  7048. group_to_use = dist.new_group(
  7049. backend=dist.get_backend(), timeout=timedelta(seconds=10)
  7050. )
  7051. torch.cuda.set_device(self.rank)
  7052. ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
  7053. # Creates network with different sized embedding table on different
  7054. # ranks. This should throw an error during DDP init.
  7055. net = EmbeddingNetDifferentParams(self.rank)
  7056. self._run_test_ddp_model_with_diff_params(
  7057. ctx, net, group_to_use, group_gloo
  7058. )
  7059. @require_backend(DistTestCases.backend_feature["gpu"])
  7060. @require_backends_available(DistTestCases.backend_feature["gpu"])
  7061. @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
  7062. @skip_if_lt_x_gpu(2)
  7063. def test_ddp_model_diff_num_params_across_ranks(self):
  7064. group_gloo = dist.new_group(
  7065. timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
  7066. )
  7067. # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
  7068. # determinism.
  7069. os.environ["NCCL_BLOCKING_WAIT"] = "1"
  7070. group_to_use = dist.new_group(
  7071. backend=dist.get_backend(), timeout=timedelta(seconds=10)
  7072. )
  7073. torch.cuda.set_device(self.rank)
  7074. ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
  7075. group_to_use, diff_num_params=True
  7076. )
  7077. # Creates network with diff # of param across ranks, reducer should
  7078. # recognize this and throw appropriate error.
  7079. net = EmbeddingNetDifferentParams(self.rank, diff_num_params=(self.rank == 1))
  7080. self._run_test_ddp_model_with_diff_params(
  7081. ctx, net, group_to_use, group_gloo,
  7082. )
  7083. def _test_output_unused_in_loss(self, module_cls, gradient_as_bucket_view):
  7084. model = module_cls()
  7085. local_net = copy.deepcopy(model)
net = torch.nn.parallel.DistributedDataParallel(
copy.deepcopy(model).cuda(self.rank),
device_ids=[self.rank],
find_unused_parameters=True,
# Pass through gradient_as_bucket_view so both settings are exercised.
gradient_as_bucket_view=gradient_as_bucket_view,
)
# Tests that DDP supports parameters not receiving a gradient because the
# corresponding output is unused in the loss computation. Specifically,
# checks that such grads remain unchanged and match local training.
  7095. inp = torch.randn(10, 10)
# Ensure that if a param is not used in the loss computation, its gradient
# is untouched, i.e. if it is None before, it stays None afterwards rather
# than being zeroed.
  7099. if module_cls == DictOutputModule:
  7100. a, b = local_net(inp)["predictions"]
  7101. a_dist, b_dist = net(inp)["predictions"]
  7102. else:
  7103. a, b = local_net(inp)
  7104. a_dist, b_dist = net(inp)
  7105. loss_dist = b_dist.sum()
  7106. loss_dist.backward()
  7107. # Ensure that gradient corresponding to parameter "a" was not
  7108. # touched, i.e. it is None and matches the local grad.
  7109. if module_cls == DictOutputModule:
  7110. self.assertTrue(net.module.module.a.weight.grad is None)
  7111. self.assertEqual(
  7112. net.module.module.a.weight.grad, local_net.module.a.weight.grad
  7113. )
  7114. else:
  7115. self.assertTrue(net.module.a.weight.grad is None)
  7116. self.assertEqual(net.module.a.weight.grad, local_net.a.weight.grad)
  7117. saved_a_local_grad = None
  7118. saved_a_dist_grad = None
  7119. net.zero_grad()
  7120. local_net.zero_grad()
  7121. for i in range(6):
  7122. if module_cls == DictOutputModule:
  7123. a, b = local_net(inp)["predictions"]
  7124. a_dist, b_dist = net(inp)["predictions"]
  7125. else:
  7126. a, b = local_net(inp)
  7127. a_dist, b_dist = net(inp)
  7128. if i < 2:
  7129. # Use both params in loss computation. Later, "a" will go
  7130. # unused and we check to ensure DDP supports this and
  7131. # gradients remain the same as local training.
  7132. t = a @ b
  7133. t_dist = a_dist @ b_dist
  7134. loss = t.sum()
  7135. loss_dist = t_dist.sum()
  7136. else:
  7137. # Model output "a" unused in loss.
  7138. loss = b.sum()
  7139. loss_dist = b_dist.sum()
  7140. loss.backward()
  7141. loss_dist.backward()
  7142. if i == 1:
  7143. # Save grads to compare with them in next iterations.
  7144. if module_cls == DictOutputModule:
  7145. saved_a_local_grad = local_net.module.a.weight.grad
  7146. saved_a_dist_grad = net.module.module.a.weight.grad
  7147. else:
  7148. saved_a_local_grad = local_net.a.weight.grad
  7149. saved_a_dist_grad = net.module.a.weight.grad
  7150. self.assertEqual(saved_a_local_grad, saved_a_dist_grad)
  7151. elif i >= 2:
  7152. # parameter "a" of both models should be the same and not change
  7153. if module_cls == DictOutputModule:
  7154. self.assertEqual(net.module.module.a.weight.grad, saved_a_dist_grad)
  7155. self.assertEqual(local_net.module.a.weight.grad, saved_a_local_grad)
  7156. else:
  7157. self.assertEqual(net.module.a.weight.grad, saved_a_dist_grad)
  7158. self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
  7159. # Verify grads are the same
  7160. for (local_param, dist_param) in zip(
  7161. local_net.parameters(), net.parameters()
  7162. ):
  7163. local_grad = local_param.grad
  7164. dist_grad = dist_param.grad
  7165. self.assertEqual(local_grad, dist_grad)
  7166. dist.barrier()
  7167. @sandcastle_skip_if(
  7168. BACKEND not in DistTestCases.backend_feature["ddp"],
  7169. f"The {BACKEND} backend does not support DistributedDataParallel"
  7170. )
  7171. @skip_if_lt_x_gpu(2)
  7172. def test_output_unused_in_loss_tuple_module(self):
  7173. module_cls = UnusedParamTwoLinLayerNet
  7174. for grad_as_bucket_view in [True, False]:
  7175. self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
  7176. @sandcastle_skip_if(
  7177. BACKEND not in DistTestCases.backend_feature["ddp"],
  7178. f"The {BACKEND} backend does not support DistributedDataParallel"
  7179. )
  7180. @skip_if_lt_x_gpu(2)
  7181. def test_output_unused_in_loss_dict_module(self):
  7182. module_cls = DictOutputModule
  7183. for grad_as_bucket_view in [True, False]:
  7184. self._test_output_unused_in_loss(module_cls, grad_as_bucket_view)
  7185. @sandcastle_skip_if(
  7186. BACKEND not in DistTestCases.backend_feature["ddp"],
  7187. f"The {BACKEND} backend does not support DistributedDataParallel"
  7188. )
  7189. @skip_if_lt_x_gpu(2)
  7190. def test_undefined_grad_parity_unused_parameters(self):
  7191. # TODO: enable this for general training use cases:
  7192. # https://github.com/pytorch/pytorch/issues/58511.
  7193. x = torch.ones(1, 2).to(self.rank)
  7194. net = Net().to(self.rank)
  7195. local_net = copy.deepcopy(net)
  7196. net = torch.nn.parallel.DistributedDataParallel(
  7197. net,
  7198. device_ids=[self.rank],
  7199. find_unused_parameters=True,
  7200. )
  7201. out = net(x).sum()
  7202. local_out = local_net(x).sum()
  7203. # Simulates undefined gradients.
  7204. torch._C._functions.UndefinedGrad()(out).backward()
  7205. torch._C._functions.UndefinedGrad()(local_out).backward()
  7206. for (dist_param_name, dist_param), (local_param_name, local_param) in zip(
  7207. net.named_parameters(), local_net.named_parameters()
  7208. ):
  7209. dist_grad = dist_param.grad
  7210. local_grad = local_param.grad
  7211. self.assertEqual(
  7212. dist_grad,
  7213. local_grad,
  7214. f"""DDP param {dist_param_name} with grad {dist_grad}
  7215. does not match local param {local_param_name} with grad
  7216. {local_grad}""",
  7217. )
  7218. def _test_different_graph_across_ranks(
  7219. self, find_unused_parameters=False, static_graph=False
  7220. ):
  7221. class ToyModel(nn.Module):
  7222. def __init__(self, rank):
  7223. super().__init__()
  7224. self.lin1 = nn.Linear(10, 10, bias=False)
  7225. self.lin2 = nn.Linear(10, 10, bias=False)
  7226. self.rank = rank
  7227. def forward(self, x):
  7228. if self.rank == 0:
  7229. return self.lin2(F.relu(self.lin1(x)))
  7230. else:
  7231. return F.relu(self.lin1(x))
  7232. torch.manual_seed(31415)
  7233. world_size = dist.get_world_size()
  7234. torch.cuda.set_device(self.rank)
  7235. model = ToyModel(self.rank).cuda(self.rank)
  7236. ddp_model = torch.nn.parallel.DistributedDataParallel(
  7237. model,
  7238. device_ids=[self.rank],
  7239. find_unused_parameters=find_unused_parameters,
  7240. gradient_as_bucket_view=True,
  7241. static_graph=static_graph,
  7242. )
  7243. random_input = torch.randn(20, 10, device=self.rank)
  7244. for i in range(10):
  7245. out = ddp_model(random_input)
  7246. loss = out.sum()
  7247. loss.backward()
  7248. return ddp_model
  7249. @require_backend(DistTestCases.backend_feature["gpu"])
  7250. @require_backends_available(DistTestCases.backend_feature["gpu"])
  7251. @skip_if_lt_x_gpu(2)
  7252. def test_different_graph_across_ranks(self):
  7253. base_model = self._test_different_graph_across_ranks(
  7254. find_unused_parameters=True
  7255. )
  7256. self.assertFalse(
  7257. base_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
  7258. )
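# With a dynamic graph and unused parameters DDP skips bucket rebuilding,
# whereas static_graph=True records the graph during the first iteration and
# can rebuild buckets afterwards, hence the differing flags checked here.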
  7259. static_model = self._test_different_graph_across_ranks(static_graph=True)
  7260. self.assertTrue(
  7261. static_model._get_ddp_logging_data().get("has_rebuilt_buckets", 0)
  7262. )
  7263. for i, j in zip(base_model.parameters(), static_model.parameters()):
  7264. self.assertEqual(i, j)
  7265. @require_backend({"gloo"})
  7266. @require_backends_available({"gloo"})
  7267. @sandcastle_skip_if(
  7268. IS_MACOS or IS_WINDOWS,
  7269. "MacOS uses uv transport which does not have as robust error handling as tcp transport",
  7270. )
  7271. def test_monitored_barrier_gloo(self):
  7272. tensors = [torch.ones(10) * self.rank]
  7273. # Kick off some allreduce work on all ranks
  7274. for _ in range(10):
  7275. dist.all_reduce(torch.cat(tensors))
# Run monitored barrier and ensure it passes
  7277. timeout = timedelta(seconds=2)
  7278. dist.monitored_barrier(timeout=timeout)
  7279. # Check monitored_barrier success with wait_all_ranks=True
  7280. for _ in range(10):
  7281. dist.all_reduce(torch.cat(tensors))
  7282. dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)
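# wait_all_ranks=True makes the collecting rank wait out the full timeout
# and report every rank that failed to reach the barrier, instead of failing
# fast on the first missing rank.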
  7283. # All ranks besides 1 call into barrier, rank 0 should report failure
  7284. # while others report gloo error.
  7285. failed_rank = 1
  7286. src_rank = 0
  7287. if self.rank == src_rank:
  7288. with self.assertRaisesRegex(
  7289. RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
  7290. ):
  7291. dist.monitored_barrier(timeout=timeout)
  7292. elif self.rank != failed_rank:
  7293. # Other ranks should not pass barrier since rank 0 failed.
  7294. err_regex = (
  7295. f"Rank {self.rank} successfully reached monitoredBarrier,"
  7296. f" but received errors while waiting for send/recv from rank"
  7297. f" {src_rank}"
  7298. )
  7299. with self.assertRaisesRegex(RuntimeError, err_regex):
  7300. dist.monitored_barrier(timeout=timeout)
  7301. # We need a barrier since otherwise failed_rank exits too early
  7302. # and cause a timeout.
  7303. self._barrier(timeout=30)
  7304. @require_backend({"gloo"})
  7305. @require_backends_available({"gloo"})
  7306. def test_monitored_barrier_gloo_subgroup(self):
  7307. # Tests that monitored_barrier works as expected on non-default
  7308. # process groups.
  7309. failed_rank = 1
  7310. timeout = 0.1
  7311. subgroup = dist.new_group(ranks=[0, 1])
  7312. if self.rank == failed_rank:
  7313. return
  7314. if self.rank == 0:
  7315. with self.assertRaisesRegex(
  7316. RuntimeError, f"Rank {failed_rank} failed to pass monitoredBarrier"
  7317. ):
  7318. dist.monitored_barrier(subgroup, timeout)
  7319. else:
  7320. # Other ranks call into monitored_barrier, but this should be a
  7321. # noop because they are not part of the subgroup. Verify that
  7322. # there are no errors here.
  7323. dist.monitored_barrier(subgroup, timeout)
  7324. def _test_monitored_barrier_allreduce_hang(self, wait_all_ranks):
# Tests expected behavior when a nonzero rank hangs.
  7326. nccl_pg = dist.new_group(
  7327. ranks=list(range(int(self.world_size))),
  7328. # provide sufficient timeout so communicators
  7329. # can be initialized in ctor.
  7330. timeout=timedelta(seconds=15),
  7331. backend=dist.Backend.NCCL,
  7332. )
  7333. gloo_pg = dist.new_group(
  7334. ranks=list(range(int(self.world_size))),
  7335. backend=dist.Backend.GLOO,
  7336. )
  7337. tensors = [torch.ones(10, device=self.rank) * self.rank]
  7338. # Let all ranks call allreduce first to set up communicators etc.
  7339. # Directly simulating error here will run into store issue described
  7340. # in https://github.com/pytorch/pytorch/issues/54524.
  7341. nccl_pg.allreduce(tensors).wait(timedelta(seconds=5))
  7342. # All ranks besides 0 call into allreduce. This is to simulate a
  7343. # desync across the world, where some ranks call into
  7344. # monitored_barrier() and others are stuck in collective comm. In
  7345. # practice, we don't need NCCL_BLOCKING_WAIT, but we use it in this
  7346. # test to ensure it exits cleanly.
  7347. if self.rank != 0:
  7348. # Can get different errors here depending on whether gloo-based
  7349. # wrapper PG is enabled or not, since with wrapper pg, it will
  7350. # fail in a collective synchronization check and not actually
  7351. # call into the nccl pg.
  7352. if dist.get_debug_level() == dist.DebugLevel.DETAIL:
  7353. err_regex = "Timed out waiting"
  7354. else:
  7355. err_regex = "Caught collective operation timeout"
  7356. with self.assertRaisesRegex(RuntimeError, err_regex):
  7357. nccl_pg.allreduce(tensors).wait(timedelta(seconds=0.1))
  7358. else:
  7359. # Rank 0 should report first (in order) timed out rank or all ranks
  7360. # depending on wait_all_ranks flag passed into monitored_barrier.
  7361. if wait_all_ranks:
  7362. rank_str = ", ".join(
  7363. [str(i) for i in range(1, int(self.world_size))]
  7364. )
  7365. err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
  7366. else:
  7367. expected_first_fail_rank = 1
  7368. err_regex = f"Rank {expected_first_fail_rank} failed to pass monitoredBarrier"
  7369. monitored_barrier_timeout_seconds = timedelta(seconds=0.1)
  7370. with self.assertRaisesRegex(RuntimeError, err_regex):
  7371. gloo_pg.monitored_barrier(
  7372. monitored_barrier_timeout_seconds, wait_all_ranks=wait_all_ranks
  7373. )
  7374. self._barrier(timeout=30)

        @with_nccl_blocking_wait
        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_monitored_barrier_allreduce_hang(self):
            # Tests expected behavior when a nonzero rank hangs and we want to
            # report the first timed-out rank.
            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=False)

        @with_nccl_blocking_wait
        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_monitored_barrier_allreduce_hang_wait_all_ranks(self):
            # Tests expected behavior when a nonzero rank hangs and we want to
            # report all timed-out ranks.
            self._test_monitored_barrier_allreduce_hang(wait_all_ranks=True)

        @require_backend({"gloo"})
        @require_backends_available({"gloo"})
        def test_monitored_barrier_gloo_rank_0_timeout(self):
            # Tests error when rank 0 exhausts its given timeout.
            process_group = dist.new_group(
                ranks=list(range(int(self.world_size)))
            )
            timeout = timedelta(seconds=0)
            if self.rank == 0:
                with self.assertRaisesRegex(
                    RuntimeError, f"Rank {self.rank} timed out in monitoredBarrier"
                ):
                    process_group.monitored_barrier(timeout)

        @require_backend({"gloo"})
        @require_backends_available({"gloo"})
        @skip_if_small_worldsize
        @sandcastle_skip_if(
            IS_MACOS or IS_WINDOWS,
            "MacOS uses uv transport which does not have as robust error handling as tcp transport",
        )
        def test_monitored_barrier_failure_order(self):
            # Ensure that the first (in sorted order) rank is reported when
            # multiple ranks fail to pass the monitored_barrier.
            # TODO(#54879): Provide ability to wait and report all failed ranks
            expected_first_failed_rank = 2
            timeout = timedelta(seconds=2)
            src_rank = 0
            if self.rank == src_rank:
                with self.assertRaisesRegex(
                    RuntimeError, f"Rank {expected_first_failed_rank}"
                ):
                    dist.monitored_barrier(timeout=timeout)
            elif self.rank == 1:
                err_regex = (
                    f"Rank {self.rank} successfully reached monitoredBarrier,"
                    f" but received errors while waiting for send/recv from rank"
                    f" {src_rank}"
                )
                with self.assertRaisesRegex(RuntimeError, err_regex):
                    dist.monitored_barrier(timeout=timeout)

        @require_backend({"gloo"})
        @require_backends_available({"gloo"})
        @skip_if_small_worldsize
        def test_monitored_barrier_wait_all_ranks(self):
            # Tests the simple case where > 1 rank does not call into monitored
            # barrier and verifies all ranks are reported by rank 0.
            if self.rank == 0:
                timeout = timedelta(seconds=0.1)
                rank_str = ", ".join([str(i) for i in range(1, int(self.world_size))])
                err_regex = f"Ranks {rank_str} failed to pass monitoredBarrier"
                with self.assertRaisesRegex(RuntimeError, err_regex):
                    dist.monitored_barrier(timeout=timeout, wait_all_ranks=True)

        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @with_dist_debug_levels(levels=["INFO"])
        @skip_if_lt_x_gpu(2)
        def test_ddp_build_debug_param_to_name_mapping(self):
            model = TwoLinLayerNet()
            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank),
                device_ids=[self.rank],
            )
            expected_mapping = {0: "a.weight", 1: "b.weight"}
            net_params, _ = net._build_params_for_reducer()
            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
            self.assertDictEqual(expected_mapping, param_to_name_mapping)

            # Test when DDP is used with ignored parameters.
            model = TwoLinLayerNet()
            # Parameters to ignore are in the format {module_name}.{param_name}
            params_to_ignore = ["a.weight"]
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                model, params_to_ignore
            )
            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank),
                device_ids=[self.rank],
            )
            expected_mapping = {0: "b.weight"}
            net_params, _ = net._build_params_for_reducer()
            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
            self.assertDictEqual(expected_mapping, param_to_name_mapping)

            # Test errors are raised when DDP and module parameters mismatch.
            # This generally indicates a bug with DDP and is not expected to
            # happen in user applications.
            model = TwoLinLayerNet()
            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank),
                device_ids=[self.rank],
            )
            net_params, _ = net._build_params_for_reducer()
            if self.rank == 0:
                print(type(net_params[0]))

            net_params.extend(
                [
                    torch.nn.Parameter(torch.ones(1)),
                    torch.nn.Parameter(torch.ones(1)),
                ]
            )
            with self.assertRaisesRegex(ValueError, "Expected param to name mapping"):
                net._build_debug_param_to_name_mapping(net_params)

            net_params = net_params[:-3]
            with self.assertRaisesRegex(ValueError, "Param with name"):
                net._build_debug_param_to_name_mapping(net_params)

            net_params.extend(
                [
                    torch.nn.Parameter(torch.ones(1)),
                    torch.nn.Parameter(torch.ones(1)),
                ]
            )
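
        # Illustrative sketch (not exercised by the test suite): marking
        # parameters for DDP to ignore before wrapping, as the test above does.
        # The helper name is hypothetical and the private API is used only for
        # illustration.
        def _example_ddp_ignore_params(self):
            model = TwoLinLayerNet()
            # Fully qualified names use the {module_name}.{param_name} format.
            torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                model, ["a.weight"]
            )
            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank), device_ids=[self.rank]
            )
            return net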

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        @with_dist_debug_levels(levels=["INFO"])
        @skip_if_lt_x_gpu(2)
        def test_ddp_build_debug_param_to_name_mapping_requires_grad(self):
            class Net(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin = nn.Linear(10, 10)
                    # The bias is not tracked by DDP and should not show up in
                    # the param-to-name mapping.
                    self.lin.bias.requires_grad_(False)

                def forward(self, x):
                    return self.lin(x)

            model = Net()
            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank), device_ids=[self.rank]
            )
            expected_mapping = {
                0: "lin.weight",
            }
            net_params, _ = net._build_params_for_reducer()
            param_to_name_mapping = net._build_debug_param_to_name_mapping(net_params)
            self.assertEqual(param_to_name_mapping, expected_mapping)
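
        # Illustrative sketch (not exercised by the test suite): a parameter
        # frozen with requires_grad_(False) before DDP construction is not
        # managed by the reducer, which is why it is absent from the mapping
        # above. The helper name is hypothetical.
        def _example_ddp_with_frozen_bias(self):
            model = nn.Linear(10, 10)
            model.bias.requires_grad_(False)
            ddp = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank), device_ids=[self.rank]
            )
            return ddp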

        def _test_ddp_multiple_nested_unused_params_error(self, ignore_sparse):
            debug_mode_off = dist.get_debug_level() == dist.DebugLevel.OFF

            class SubModule(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.embedding_net = EmbeddingNetDifferentParams(0)
                    self.lin = TwoLinLayerNet()
                    self.bn = BatchNormNet()
                    self.lin_layer = nn.Linear(4, 10, bias=False)

                def forward(self, x):
                    x = self.bn(x)
                    x = self.lin_layer(x)
                    x = self.lin.a(x)  # self.lin.b param unused
                    # EmbeddingNetDifferentParams entirely unused: self.embedding_net.embedding and
                    # self.embedding_net.lin unused.
                    return x

            class MyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.sub_module = SubModule()

                def forward(self, x):
                    return self.sub_module(x)

            model = MyModel()
            sparse_embedding_fqns = []
            if ignore_sparse:
                for module_name, module in model.named_modules():
                    if module == model.sub_module.embedding_net.embedding:
                        for parameter_name, param in module.named_parameters(
                            recurse=False
                        ):
                            fqn = f"{module_name}.{parameter_name}"
                            sparse_embedding_fqns.append(fqn)

                torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
                    model, sparse_embedding_fqns
                )
                unused_modules = [
                    model.sub_module.embedding_net.lin,
                    model.sub_module.lin.b,
                ]
            else:
                unused_modules = list(model.sub_module.embedding_net.modules()) + [
                    model.sub_module.lin.b,
                ]

            expected_unused_param_fqns = []
            used_param_fqns = []  # Validate that these don't mistakenly show up.
            fqn_to_param_index = {}
            index = 0
            for module_name, module in model.named_modules():
                for parameter_name, param in module.named_parameters(recurse=False):
                    fqn = f"{module_name}.{parameter_name}"
                    fqn_to_param_index[fqn] = index
                    if fqn not in sparse_embedding_fqns:
                        index += 1
                    if module in unused_modules:
                        expected_unused_param_fqns.append(fqn)
                    else:
                        if (
                            not ignore_sparse
                            or module != model.sub_module.embedding_net.embedding
                        ):
                            used_param_fqns.append(fqn)

            net = torch.nn.parallel.DistributedDataParallel(
                model.cuda(self.rank),
                device_ids=[self.rank],
            )
            batch, dim = 10, 2
            inp = torch.ones(batch, dim)
            for i in range(2):
                if i == 0:
                    out = net(inp)
                    loss = out.sum()
                    loss.backward()
                else:
                    try:
                        out = net(inp)
                        loss = out.sum()
                        loss.backward()
                    except RuntimeError as e:
                        e = str(e)
                        unused_param_substr = e[e.find("did not receive grad") :]
                        # Validate that each unused param fully qualified name
                        # shows up in error logs. We do this instead of
                        # constructing a joined string since order of parameters
                        # can be different in Reducer. In addition, validate
                        # param indices show up as well.
                        for unused_param_fqn in expected_unused_param_fqns:
                            self.assertTrue(
                                unused_param_fqn in unused_param_substr
                                or debug_mode_off
                            )
                            self.assertTrue(
                                str(fqn_to_param_index[unused_param_fqn])
                                in unused_param_substr,
                                f"Did not find index {fqn_to_param_index[unused_param_fqn]} for {unused_param_fqn}",
                            )
                        # Validate that used param fqns don't show up in error
                        # logs.
                        for used_param_fqn in used_param_fqns:
                            self.assertFalse(used_param_fqn in unused_param_substr)
                        # Validate that ignored param fqns don't show up as unused
                        # (since DDP does not track them)
                        for sparse_param_fqn in sparse_embedding_fqns:
                            self.assertFalse(sparse_param_fqn in unused_param_substr)
                    else:
                        self.assertTrue(False, "Expected error was not raised!")

        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(2)
        def test_ddp_multiple_nested_unused_params_error(self):
            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=False)

        @with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(2)
        def test_ddp_multiple_nested_unused_params_err_ignore_params(self):
            # Tests unused parameter reporting when DDP is configured to ignore
            # certain parameters.
            self._test_ddp_multiple_nested_unused_params_error(ignore_sparse=True)

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        @skip_if_lt_x_gpu(2)
        def test_ddp_inference(self):
            # Tests that a DDP module can be run on a single node with no_grad
            # or eval setting and there is no hang.
            rank = self.rank
            torch.cuda.set_device(rank)
            model = Net().cuda()
            local_model = copy.deepcopy(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[rank],
            )
            syncbn_model = nn.SyncBatchNorm(
                2, momentum=0.99, track_running_stats=False
            ).cuda()
            local_syncbn_model = copy.deepcopy(syncbn_model)
            syncbn_model = torch.nn.parallel.DistributedDataParallel(
                syncbn_model, device_ids=[rank]
            )
            inp = torch.randn(10, 2, device=rank)
            inp_syncbn = torch.randn(10, 2, 4, 4, device=rank)
            tests = [
                (model, local_model, inp),
                (syncbn_model, local_syncbn_model, inp_syncbn),
            ]
            for test in tests:
                test_model, test_local_model, test_inp = test
                if self.rank == 0:
                    test_model.eval()
                    test_local_model.eval()
                    for _ in range(6):
                        self.assertEqual(
                            test_model(test_inp), test_local_model(test_inp)
                        )

            # Barrier since only rank 0 runs inference. Test should be
            # much faster than 30s, but this is to avoid flakiness.
            self._barrier(timeout=30)

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        @skip_if_lt_x_gpu(2)
        def test_ddp_sync_bn_training_vs_eval(self):
            rank = self.rank
            torch.cuda.set_device(rank)
            # Need to set track_running_stats=False: when track_running_stats=True,
            # bn_training is False and sync cannot occur in eval mode.
            model = nn.SyncBatchNorm(2, momentum=0.99, track_running_stats=False).cuda(
                rank
            )
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
            # Test that sync occurs in training mode.
            with torch.autograd.profiler.profile() as prof:
                for i in range(6):
                    inp = torch.randn(10, 2, 4, 4).cuda(rank)
                    out = model(inp)
                    loss = out.sum()
                    loss.backward()

            # SyncBN allgathers stats across all ranks, so verify call to
            # all_gather in profiler.
            if BACKEND == "nccl":
                all_gather_calls = get_profiling_event("_all_gather_base", prof)
            else:
                all_gather_calls = get_profiling_event("all_gather", prof)
            self.assertNotEqual([], all_gather_calls)

            # Only do inference on one rank. If SyncBN did collective stats sync,
            # this would hang/error.
            model_inference = model.module
            if self.rank == 0:
                model_inference.eval()
                with torch.autograd.profiler.profile() as prof:
                    for i in range(6):
                        inp = torch.randn(10, 2, 4, 4).cuda(rank)
                        out = model_inference(inp)
                        loss = out.sum()
                        loss.backward()

                # Ensure sync does not occur in eval() mode.
                if BACKEND == "nccl":
                    all_gather_calls = get_profiling_event("_all_gather_base", prof)
                else:
                    all_gather_calls = get_profiling_event("all_gather", prof)
                self.assertEqual([], all_gather_calls)
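
        # Illustrative sketch (not exercised by the test suite): converting an
        # existing model's BatchNorm layers to SyncBatchNorm before wrapping in
        # DDP, so that stats are all-gathered across ranks in training mode.
        # The helper name is hypothetical.
        def _example_convert_sync_batchnorm(self):
            model = nn.Sequential(nn.Conv2d(2, 2, 1), nn.BatchNorm2d(2)).cuda(self.rank)
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            ddp = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[self.rank]
            )
            return ddp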

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_python_error_logged(self):
            # Most python exceptions in DDP are raised during init before
            # reducer is constructed, so we don't have a logger in those cases.
            # However, the below is one example where a python error is thrown
            # after reducer is constructed.
            model = TwoLinLayerNet().cuda(self.rank)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
            )
            expected_err = "must be callable"
            with self.assertRaisesRegex(TypeError, expected_err):
                model.register_comm_hook({}, {})
            verify_ddp_error_logged(model, expected_err)

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_static_graph_nested_types(self):
            # Tests static graph training when outputs are not just tensors
            # but can be (nested) tuple, list, dict, etc.
            rank = self.rank
            torch.cuda.set_device(rank)

            class NestedOutputModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin = nn.Linear(100, 1, bias=False)

                def forward(self, inp, output_type):
                    if output_type == "tuple":
                        return (
                            self.lin(inp),
                            (
                                self.lin(inp),
                                self.lin(inp),
                            ),
                        )
                    elif output_type == "list":
                        return [
                            self.lin(inp),
                            [
                                self.lin(inp),
                                self.lin(inp),
                            ],
                        ]
                    elif output_type == "dict":
                        return {
                            "a": self.lin(inp),
                            "b": {
                                "c": self.lin(inp),
                            },
                        }

            def get_loss(model_output):
                loss = 0.0
                if isinstance(model_output, torch.Tensor):
                    return model_output.sum()
                elif isinstance(model_output, dict):
                    for value in model_output.values():
                        loss += get_loss(value)
                elif isinstance(model_output, (tuple, list)):
                    for x in model_output:
                        loss += get_loss(x)
                else:
                    raise ValueError(f"Unknown model output type {type(model_output)}")
                return loss

            model = NestedOutputModule().cuda(rank)
            model_static_graph = copy.deepcopy(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[rank],
            )
            model_static_graph = torch.nn.parallel.DistributedDataParallel(
                model_static_graph,
                device_ids=[rank],
                static_graph=True,
            )
            inp = torch.randn(10, 100)
            type_mapping = {
                "list": list,
                "tuple": tuple,
                "dict": dict,
            }
            for output_type in type_mapping.keys():
                for i in range(6):
                    out = model(inp, output_type=output_type)
                    loss = get_loss(out)
                    loss.backward()
                    self._model_step(model)
                    out_static = model_static_graph(inp, output_type=output_type)
                    self.assertTrue(isinstance(out_static, type_mapping[output_type]))
                    loss_static = get_loss(out_static)
                    loss_static.backward()
                    self._model_step(model_static_graph)
                    for (p, p_static) in zip(
                        model.parameters(), model_static_graph.parameters()
                    ):
                        self.assertEqual(p, p_static)
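
        # Illustrative sketch (not exercised by the test suite): enabling
        # static_graph when the set of used parameters and the autograd graph
        # do not change across iterations, as in the test above. The helper
        # name is hypothetical.
        def _example_ddp_static_graph(self):
            model = nn.Linear(10, 10).cuda(self.rank)
            ddp = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
                static_graph=True,
            )
            return ddp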

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_returns_tensor_with_no_grad(self):
            # Tests case where module returns tensor that does not require grad.
            torch.cuda.set_device(self.rank)

            class MyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.fc1 = nn.Linear(10, 10, bias=False)
                    self.fc2 = nn.Linear(10, 10, bias=False)

                def forward(self, x):
                    x = self.fc2(F.relu(self.fc1(x)))
                    y = x.clone()
                    x = x.detach()
                    assert not x.requires_grad
                    return (x, y)

            model = MyModel().to(self.rank)
            inp = torch.randn(1, 10, device=self.rank)
            for (find_unused, static_graph) in itertools.product([True, False], [True, False]):
                ddp = DistributedDataParallel(
                    model,
                    device_ids=[self.rank],
                    output_device=self.rank,
                    find_unused_parameters=find_unused,
                    static_graph=static_graph,
                )
                for i in range(6):
                    out = ddp(inp)
                    self.assertFalse(out[0].requires_grad)
                    o = (out[0] + out[1]).sum()
                    o.backward()

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_detect_ddp_is_actually_static(self):
            class ToyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.net1 = nn.Linear(10, 10, bias=False)
                    self.net2 = nn.Linear(10, 10)

                def forward(self, x, find_unused, dynamic):
                    if find_unused:
                        if dynamic:
                            return self.net2(self.net1(x))
                        else:
                            return self.net2(x)
                    else:
                        return self.net2(self.net1(x))

            # The set of unused parameters does not change across iterations.
            torch.cuda.set_device(self.rank)
            model = ToyModel().cuda()
            for find_unused in [True, False]:
                ddp = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self.rank],
                    find_unused_parameters=find_unused,
                )
                inp = torch.randn(1, 10, device="cuda")
                for _ in range(6):
                    out = ddp(inp, find_unused=find_unused, dynamic=False)
                    loss = out.sum()
                    loss.backward()
                self.assertTrue(ddp.reducer._ddp_graph_static())

            # The set of unused parameters changes dynamically.
            ddp = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
                find_unused_parameters=True,
            )
            inp = torch.randn(1, 10, device="cuda")
            for i in range(6):
                out = ddp(inp, find_unused=True, dynamic=i % 2 == 0)
                loss = out.sum()
                loss.backward()
            self.assertFalse(ddp.reducer._ddp_graph_static())

        def _test_ddp_new_tensor_in_fwd(self, static_graph):
            # Test from https://github.com/pytorch/pytorch/issues/60733
            class MyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.fc1 = nn.Linear(10, 10, bias=False)
                    self.fc2 = nn.Linear(10, 10, bias=False)
                    self.device = self.fc1.weight.device

                def __init_opt(self):
                    opt = torch.randn(1, 10, device=self.device)
                    return opt

                def forward(self, x, opt_1, opt_2, opt_nested):
                    x = F.relu(self.fc1(x))
                    x = self.fc2(x)
                    if opt_1 is None:
                        opt_1 = self.__init_opt()
                    if opt_2 is None:
                        opt_2 = self.__init_opt()
                    if opt_nested is None or not torch.is_tensor(opt_nested):
                        opt_nested = self.__init_opt()
                    # Test multiple tensors as well as newly created tensors
                    # within a struct.
                    return x, opt_1, opt_2, {"tensor": opt_nested}

            model = MyModel().to(self.rank)
            for find_unused in [True, False]:
                ddp = DistributedDataParallel(
                    model,
                    device_ids=[self.rank],
                    output_device=self.rank,
                    broadcast_buffers=False,
                    find_unused_parameters=find_unused,
                    static_graph=static_graph,
                )
                opt = [None for _ in range(3)]
                for i in range(2):
                    ddp.zero_grad()
                    x = torch.randn(1, 10, device=self.rank)
                    out, opt[0], opt[1], opt[2] = ddp(
                        x, opt_1=opt[0], opt_2=opt[1], opt_nested=opt[2]
                    )
                    for i in range(len(opt)):
                        if torch.is_tensor(opt[i]):
                            self.assertEqual(opt[i].grad_fn, None)
                        else:
                            self.assertEqual(opt[i]["tensor"].grad_fn, None)
                    out.mean().backward()

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_new_tensor_in_fwd(self):
            return self._test_ddp_new_tensor_in_fwd(static_graph=False)

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_new_tensor_in_fwd_static_graph(self):
            return self._test_ddp_new_tensor_in_fwd(static_graph=True)

        def _test_ddp_buffer_hook_allreduce(self, return_futures):
            rank = self.rank
            torch.cuda.set_device(rank)
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)

            def buffer_comm_hook(ddp, named_buffers):
                buffers = [
                    buffer for (_, buffer) in named_buffers.items()
                ]
                futs = [
                    dist.all_reduce(buffer, group=ddp.process_group, async_op=True).get_future()
                    for buffer in buffers
                ]
                if return_futures:
                    return futs
                else:
                    torch.futures.collect_all(futs).wait()

            hook_pre_fwd = torch.nn.parallel.distributed._BufferCommHookLocation.PRE_FORWARD
            hook_post_fwd = torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD
            for hook_run_location in [
                hook_pre_fwd,
                hook_post_fwd,
            ]:
                model = NetWithBuffers().cuda(rank)
                model_ddp = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=[self.rank],
                )
                model_ddp._register_buffer_comm_hook(
                    model_ddp,
                    buffer_comm_hook,
                    hook_run_location
                )
                model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
                    copy.deepcopy(model),
                    device_ids=[self.rank],
                    broadcast_buffers=False
                )
                inp = torch.randn(2, 10, device=rank)
                for i in range(2):
                    loss_hook = model_ddp(inp).sum()
                    # Simulate the buffer allreduce for the no-hook model at the
                    # point that matches the hook run location.
                    if hook_run_location == hook_pre_fwd:
                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
                        for tensor in model_no_hook_buffers:
                            dist.all_reduce(tensor)
                    loss_no_hook = model_ddp_no_hook(inp).sum()
                    if hook_run_location == hook_post_fwd:
                        model_no_hook_buffers = list(model_ddp_no_hook.module.buffers())
                        for tensor in model_no_hook_buffers:
                            dist.all_reduce(tensor)
                    torch.cuda.synchronize()

                    # If return_futures, they are only awaited on by DDP
                    # at the end of the backwards pass for maximum overlap.
                    if not return_futures:
                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
                    loss_hook.backward()
                    loss_no_hook.backward()
                    # Note that when custom hooks return futures, this
                    # comparison is not expected to work when hook run location
                    # is pre-forward pass. This is because the hook does async
                    # communication and forward pass modifies the buffer without
                    # appropriate synchronization. Therefore, if returning
                    # futures from custom buffer hooks, it is advised to set
                    # hook run location to post forward.
                    if return_futures and hook_run_location == hook_post_fwd:
                        self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
            dist.barrier()
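
        # Illustrative sketch (not exercised by the test suite): a buffer
        # communication hook that averages buffers instead of summing them,
        # registered through the same private _register_buffer_comm_hook API
        # used above. The helper and hook names are hypothetical.
        def _example_register_buffer_avg_hook(self, model_ddp):
            def buffer_avg_hook(ddp, named_buffers):
                world_size = dist.get_world_size(ddp.process_group)
                futs = []
                for _, buffer in named_buffers.items():
                    # Pre-divide so the summing allreduce yields an average.
                    buffer.div_(world_size)
                    futs.append(
                        dist.all_reduce(
                            buffer, group=ddp.process_group, async_op=True
                        ).get_future()
                    )
                # Returned futures are awaited by DDP at the end of backward;
                # per the note above, post-forward is the safer run location
                # when futures are returned.
                return futs

            model_ddp._register_buffer_comm_hook(
                model_ddp,
                buffer_avg_hook,
                torch.nn.parallel.distributed._BufferCommHookLocation.POST_FORWARD,
            )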

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_buffer_hook_allreduce_return_future(self):
            self._test_ddp_buffer_hook_allreduce(
                return_futures=True
            )

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_buffer_hook_allreduce(self):
            self._test_ddp_buffer_hook_allreduce(
                return_futures=False
            )

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_broadcast_buffer_via_hook(self):
            # Test that _distributed_broadcast_coalesced via a registered hook is
            # equivalent to DDP's default broadcast coalesced.
            rank = self.rank
            torch.cuda.set_device(rank)
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)

            def buffer_comm_hook(ddp, named_buffers):
                # named_buffers is a Dict[str, Tensor] representing a mapping
                # from buffer name to buffer.
                buffers = [
                    buffer for (_, buffer) in named_buffers.items()
                ]
                ddp._default_broadcast_coalesced(buffers)

            model = NetWithBuffers().cuda(rank)
            model_ddp = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
            )
            model_ddp._register_buffer_comm_hook(
                model_ddp,
                buffer_comm_hook
            )
            model_ddp_no_hook = torch.nn.parallel.DistributedDataParallel(
                copy.deepcopy(model),
                device_ids=[self.rank],
            )
            inp = torch.randn(2, 10, device=rank)
            for i in range(2):
                loss_hook = model_ddp(inp).sum()
                loss_no_hook = model_ddp_no_hook(inp).sum()
                self._verify_buffers_equal(model_ddp, model_ddp_no_hook)
                loss_hook.backward()
                loss_no_hook.backward()

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_ddp_broadcast_buffer(self):
            rank = self.rank
            torch.cuda.set_device(rank)
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)

            class NetWithBuffers(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.a = nn.Linear(10, 10, bias=False)
                    self.b = nn.Linear(10, 1, bias=False)
                    self.register_buffer('buffer', torch.randn(1, 2))

                def forward(self, x):
                    return self.b(self.a(x))

            model = NetWithBuffers().cuda(rank)
            model_ddp = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
            )
            inp = torch.randn(2, 10, device=rank)
            for i in range(2):
                if rank == 0:
                    model_ddp.module.buffer = model_ddp.module.buffer + 1
                loss = model_ddp(inp).sum()
                loss.backward()
            # Ensure all buffers are synchronized.
            bufs = [torch.empty_like(model_ddp.module.buffer) for _ in range(dist.get_world_size())]
            dist.all_gather(bufs, model_ddp.module.buffer)
            rank_0_buf = bufs[0]
            for buf in bufs[1:]:
                self.assertEqual(rank_0_buf, buf)

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND != "nccl" and BACKEND != "gloo",
            "Only NCCL and Gloo backends support DistributedDataParallel",
        )
        def test_sync_bn_logged(self):
            model = BN_NET
            rank = self.rank
            # Single-GPU training setup.
            model_gpu = model.cuda(rank)
            no_sync_bn = torch.nn.parallel.DistributedDataParallel(
                copy.deepcopy(model_gpu),
                device_ids=[self.rank],
            )
            ddp_logging_data = no_sync_bn._get_ddp_logging_data()
            sync_bn_logged = ddp_logging_data.get("has_sync_bn", True)
            self.assertFalse(sync_bn_logged)
            model_DDP = nn.SyncBatchNorm.convert_sync_batchnorm(model_gpu)
            model_DDP = torch.nn.parallel.DistributedDataParallel(
                model_DDP,
                device_ids=[self.rank],
            )
            ddp_logging_data = model_DDP._get_ddp_logging_data()
            sync_bn_logged = ddp_logging_data.get("has_sync_bn", False)
            self.assertTrue(sync_bn_logged)

        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel"
        )
        def test_stateless_api_with_ddp(self):
            class MockModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.l1 = torch.nn.Linear(1, 1)
                    buffer = torch.ones(1)
                    self.register_buffer('buffer', buffer)

                def forward(self, x):
                    return self.l1(x) + self.buffer

            device = self.rank
            module = MockModule().to(device)
            # Disable DDP + ReplicatedTensor since stateless looks for 'module'
            # whereas with ReplicatedTensor, we run '_replicated_tensor_module'
            # in the forward pass.
            from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor
            with _ddp_replicated_tensor(False):
                module = torch.nn.parallel.DistributedDataParallel(
                    module,
                    device_ids=[device]
                )
            x = torch.rand((1, 1)).to(device)
            weight = torch.tensor([[1.0]], device=device, requires_grad=True)
            bias = torch.tensor([0.0], device=device, requires_grad=True)
            buffer = torch.tensor([0.0], device=device)
            parameters = {'module.l1.weight': weight,
                          'module.l1.bias': bias,
                          'module.buffer': buffer}
            prev_weight = module.module.l1.weight.clone()
            prev_buffer = module.module.buffer.clone()
            res = torch.func.functional_call(module, parameters, x)
            self.assertEqual(x, res)
            # Check that the weight and buffer remain unmodified.
            cur_weight = module.module.l1.weight
            cur_buffer = module.module.buffer
            self.assertEqual(cur_weight, prev_weight)
            self.assertEqual(cur_buffer, prev_buffer)
            # Run a backward pass and check the gradients.
            res.backward()
            self.assertIsNotNone(weight.grad)
            self.assertIsNotNone(bias.grad)
            # No gradient should be computed for the module's own parameters
            # and buffers.
            self.assertIsNone(buffer.grad)
            self.assertIsNone(module.module.l1.weight.grad)
            self.assertIsNone(module.module.l1.bias.grad)
            self.assertIsNone(module.module.buffer.grad)
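
        # Illustrative sketch (not exercised by the test suite): calling a
        # DDP-wrapped module with externally supplied parameters and buffers
        # via torch.func.functional_call, as the test above does. The keys use
        # the "module." prefix that DDP adds; the helper name and the l1
        # submodule are assumptions borrowed from the MockModule above.
        def _example_functional_call_on_ddp(self, ddp_module, x):
            overrides = {
                "module.l1.weight": torch.ones_like(ddp_module.module.l1.weight),
                "module.l1.bias": torch.zeros_like(ddp_module.module.l1.bias),
            }
            # The wrapped module's own parameters are left untouched.
            return torch.func.functional_call(ddp_module, overrides, x)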

        @require_backend(DistTestCases.backend_feature["gpu"])
        @require_backends_available(DistTestCases.backend_feature["gpu"])
        @skip_if_lt_x_gpu(2)
        def test_ddp_forward_backward_hook(self):
            class DummyTestModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    torch.manual_seed(0)
                    self.fc = nn.Linear(2, 2)

                def forward(self, x):
                    return self.fc(x)

            def relu_hook(module, input):
                return nn.functional.relu(input[0])

            def gelu_hook(module, _input, output):
                return nn.functional.gelu(output)

            def celu_hook(module, _input, output):
                return (nn.functional.celu(output[0]),)

            local_model = DummyTestModel()
            ddp_model = DummyTestModel()
            local_model.fc.register_forward_pre_hook(relu_hook)
            local_model.fc.register_forward_hook(gelu_hook)
            ddp_model.fc.register_forward_pre_hook(relu_hook)
            ddp_model.fc.register_forward_hook(gelu_hook)
            local_model.fc.register_backward_hook(celu_hook)
            ddp_model.fc.register_backward_hook(celu_hook)
            ddp_model = DistributedDataParallel(
                ddp_model.to(self.rank), device_ids=[self.rank]
            )
            input_data = torch.rand(5, 2)
            output_local = local_model(input_data)
            output_ddp = ddp_model(input_data.to(self.rank))
            self.assertEqual(output_local, output_ddp)
            output_local.sum().backward()
            output_ddp.sum().backward()
            ddp_grads = [p.grad for p in ddp_model.parameters()]
            self.assertEqual(ddp_grads[0], local_model.fc.weight.grad)
            self.assertEqual(ddp_grads[1], local_model.fc.bias.grad)

        def _test_hook_pickling(self, hook, hook_state):
            torch.manual_seed(0)
            learning_rate = 0.01
            chkpt_file = tempfile.gettempdir() + "/checkpoint.pt"
            rank = self.rank

            input = torch.randn(7, 1, device=rank)
            target = torch.randn(7, 5, device=rank)
            net = torch.nn.Linear(1, 5).to(rank)
            ddp_model = DistributedDataParallel(
                copy.deepcopy(net),
                device_ids=[rank]
            )
            dummy_ddp_model = DistributedDataParallel(
                copy.deepcopy(net),
                device_ids=[rank]
            )
            optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate)
            ddp_model.register_comm_hook(hook_state, hook)
            ddp_model.train()

            for _ in range(10):
                optimizer.zero_grad()
                out = ddp_model(input)
                loss = F.mse_loss(out, target)
                loss.backward()
                optimizer.step()

            state = {
                'state_dict': ddp_model.state_dict(),
                'comm_hook': hook,
                'comm_hook_state': hook_state
            }

            if rank == 0:
                with self.assertLogs() as captured:
                    torch.save(state, chkpt_file)

                # Check that the logger has only one entry
                self.assertEqual(len(captured.records), 1)
                # Check that the logger has an expected entry
                self.assertEqual(
                    captured.records[0].getMessage(),
                    "NOTE: Process group is not serializable and excluded from a saved state."
                )

            dist.barrier()
            map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
            with self.assertLogs() as captured:
                checkpoint = torch.load(chkpt_file, map_location=map_location)

            # Check that the logger has only one entry
            self.assertEqual(len(captured.records), 1)
            # Check that the logger has an expected entry
            self.assertEqual(
                captured.records[0].getMessage(),
                "NOTE: Process group will be set to a default group (i.e. the world size).\
                If a different group is desired, please set `self.process_group` after PowerSGD state is loaded."
            )

            dummy_ddp_model.load_state_dict(checkpoint['state_dict'])
            dummy_hook = checkpoint['comm_hook']
            dummy_hook_state = checkpoint['comm_hook_state']
            dummy_optimizer = torch.optim.SGD(dummy_ddp_model.parameters(), lr=learning_rate)

            # Check that the loaded function is correct
            self.assertEqual(dummy_hook.__qualname__, hook.__qualname__)

            # Check that all slots' keys were restored correctly
            self.assertEqual(hook_state.__slots__, dummy_hook_state.__slots__)

            # Check that all slots' attributes are restored correctly,
            # excluding ``process_group`` and ``rng``.
            for entry in dummy_hook_state.__slots__:
                if entry != "process_group" and entry != "rng":
                    self.assertEqual(getattr(dummy_hook_state, entry), getattr(hook_state, entry))

            # Check that ``process_group`` was set to default
            self.assertEqual(dummy_hook_state.process_group, _get_default_group())

            # Check that the random state was restored properly:
            # ``np.random.RandomState.get_state`` returns a tuple with entries:
            # ``bit_generator`` - str,
            # ``state.key`` - ndarray dtype[uint32],
            # ``state.pos`` - int,
            # ``has_gauss`` - int,
            # ``gauss`` - float
            # (refer to https://github.com/numpy/numpy/blob/266aad7478bc7fbcc55eea7f942a0d373b838396/numpy/random/mtrand.pyi)
            # To make sure the random state was restored properly, all entries should equal the original.
            for entry1, entry2 in zip(hook_state.rng.get_state(), dummy_hook_state.rng.get_state()):
                np.testing.assert_array_equal(entry1, entry2)

            dummy_ddp_model.register_comm_hook(dummy_hook_state, dummy_hook)
            dummy_ddp_model.train()

            for _ in range(10):
                optimizer.zero_grad()
                dummy_optimizer.zero_grad()
                out_origin = ddp_model(input)
                out_dummy = dummy_ddp_model(input)
                loss_origin = F.mse_loss(out_origin, target)
                loss_dummy = F.mse_loss(out_dummy, target)
                loss_origin.backward()
                loss_dummy.backward()
                optimizer.step()
                dummy_optimizer.step()

            # Check that gradients after 10 epochs are the same
            for orig_param, dummy_param in zip(ddp_model.parameters(), dummy_ddp_model.parameters()):
                self.assertEqual(orig_param.grad, dummy_param.grad)

            dist.barrier()
            if rank == 0:
                os.remove(chkpt_file)
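
        # Illustrative sketch (not exercised by the test suite): registering the
        # PowerSGD communication hook that the hook-pickling test above
        # round-trips through a checkpoint. The helper name is hypothetical.
        def _example_register_powersgd_hook(self, ddp_model):
            state = powerSGD.PowerSGDState(
                process_group=None,  # None means the default (world) group.
                matrix_approximation_rank=1,
                start_powerSGD_iter=4,
            )
            ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
            return state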

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally")
        def test_ddp_hook_pickling_powerSGD(self):
            hook = powerSGD.powerSGD_hook
            powersgd_state = powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=4,
            )
            self._test_hook_pickling(hook, powersgd_state)


instantiate_parametrized_tests(DistributedTest._DistTestBase)