ir.py

import contextlib
import dataclasses
import functools
import itertools
import logging
import re
import textwrap
from collections import OrderedDict
from contextlib import nullcontext
from enum import Enum
from functools import partial
from inspect import signature
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch

import sympy
from sympy import Expr, Integer

import torch.fx
import torch.utils._pytree as pytree
from torch._prims_common import (
    is_boolean_dtype,
    is_float_dtype,
    make_channels_last_strides_for,
    make_contiguous_strides_for,
)
from torch.fx.experimental.symbolic_shapes import FloorDiv

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .cuda_properties import get_device_properties
from .dependencies import extract_read_writes, var_builder
from .utils import (
    argsort,
    cache_on_self,
    convert_shape_to_inductor,
    convert_shape_to_symint,
    developer_warning,
    sympy_dot,
    sympy_product,
    sympy_subs,
    sympy_symbol,
)
from .virtualized import ops, V

log = logging.getLogger(__name__)
indent = functools.partial(textwrap.indent, prefix="    ")
aten = torch.ops.aten

""" [Note: Inductor IR]

Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
lowering is registered to a particular aten operator, and expects inputs that
correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
expect Inductor TensorBox inputs.

TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
storage, and sometimes views of another Tensor's storage. Mutating tensor operations
(such as add_()) affect the underlying storage and any associated views. Other operations
(such as .t_()) update metadata about the current view but don't modify the underlying storage.

To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
reference View IR or directly reference StorageBox IRs.
Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
may take an existing TensorBox and point it to a new underlying View IR.

Tensors that directly own storage are represented as a chain of:
TensorBox -> StorageBox -> Buffer
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
(leaving the old buffer unmodified and functionalizing the operation).

Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
"""


def validate_ir(node_or_nodes):
    def _check_tensorbox(node):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        assert isinstance(
            node,
            (
                TensorBox,
                RandSeedBuffer,
                torch.fx.experimental.symbolic_shapes.Symbol,
                Expr,
            ),
        ), f"Found {type(node)}, which is not a supported top level IR node. See [Note: Inductor IR]"

    # Be picky about the accepted data structure (don't use pytree here)
    if isinstance(node_or_nodes, (List, Tuple)):
        for node in node_or_nodes:
            _check_tensorbox(node)
    else:
        _check_tensorbox(node_or_nodes)


def inverse_reorder(order):
    inv_order = dict(zip(order, range(len(order))))

    def reindex(index):
        assert len(index) == len(inv_order)
        return [index[inv_order[i]] for i in range(len(index))]

    return reindex


def same_reorder(order):
    def reindex(index):
        assert len(index) == len(order)
        return [index[order[i]] for i in range(len(index))]

    return reindex
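
# Example (illustrative): with order = [2, 0, 1],
#   same_reorder(order)([a, b, c])    == [c, a, b]
#   inverse_reorder(order)([a, b, c]) == [b, c, a]
# so composing inverse_reorder(order) after same_reorder(order) is the identity.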


def fuse_reindexing(reindex1, reindex2):
    def reindex(index):
        return reindex1(reindex2(index))

    return reindex


def stride_order2fill_order(order):
    """
    Convert stride order to fill order
    For channel last format,
    stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
    """
    lookup = {pos: idx for idx, pos in enumerate(order)}
    fill_order = [lookup[i] for i in range(len(order))]
    return fill_order


def get_stride_order(seq):
    """
    Convert strides to stride order
    """
    sorted_idx = argsort(seq)
    out = [None for _ in range(len(seq))]
    for i, elem in enumerate(sorted_idx):
        out[elem] = i
    return out
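
# Example (illustrative): a channels-last NCHW tensor has strides
# (C*H*W, 1, C*W, C), so get_stride_order gives [3, 0, 2, 1]: dim 0 has the
# largest stride and dim 1 the smallest, matching the stride_order2fill_order
# docstring above.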


def reads_from_conv(buf, var_ranges):
    """
    return:
        reads_from_conv: boolean
        the new memory_addr: sympy Expression
    """
    if buf is None:
        return False, None
    if isinstance(buf, Convolution):
        indexer = buf.layout.as_fixed().make_indexer()
        index_vars = sorted(var_ranges, key=lambda var: var.name)
        index = indexer(index_vars)
        return True, index
    # for a case like
    #   buf0 = conv(x, w)
    #   return torch.cat([buf0, buf1]), torch.cat([buf0, buf2])
    # ConcatKernel creates two bufs, buf3 and buf4.
    # buf3 has the AliasedLayout which reads from buf0 (Convolution),
    # but buf4 is a copy of buf3 which reads from buf3.
    # We want to know that buf4 also follows buf0's conv layout.
    if isinstance(buf.layout, AliasedLayout):
        reads = buf.get_read_writes().reads
        reads_bufs = [
            V.graph.name_to_buffer[r.name]
            if r.name in V.graph.name_to_buffer.keys()
            else None
            for r in reads
        ]
        for reads_buf in reads_bufs:
            read_from_conv, addr = reads_from_conv(reads_buf, var_ranges)
            if read_from_conv:
                return True, addr
    return False, None


def ir_node_to_tensor(x, guard_shape=True):
    if not guard_shape:
        shape_fn = V.graph.sizevars.size_hint
    else:

        def nop(x):
            return x

        shape_fn = nop

    size = [shape_fn(s) for s in x.get_size()]
    if is_storage_and_layout(x):
        stride = [shape_fn(s) for s in x.get_layout().stride]
    else:
        stride = make_contiguous_strides_for(size)
    dtype = x.get_dtype()
    device = x.get_device()
    size = convert_shape_to_symint(size)
    stride = convert_shape_to_symint(stride)
    t = torch.empty_strided(
        size=size, stride=stride, dtype=dtype, device=device
    ).zero_()
    return t


def layout_priority_idx(reads_bufs, memory_addrs, var_ranges):
    """
    If any read comes from a conv that needs a specific layout,
    return:
        priority_idx: indices into memory_addrs
        memory_addrs: updated with the true addr where needed
    """
    priority_idx = []
    for i, reads_buf in enumerate(reads_bufs):
        read_from_conv, mem_addr = reads_from_conv(reads_buf, var_ranges)
        if read_from_conv:
            priority_idx.append(i)
            memory_addrs[i] = mem_addr
    return priority_idx, memory_addrs


class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        if divisor != 1:
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return ModularIndexing(base / gcd, divisor / gcd, modulus)

        if isinstance(base, sympy.Add):
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
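
# Example (illustrative): ModularIndexing(i, 6, 4) computes (i // 6) % 4,
# e.g. the row of a flattened (4, 6) tensor; with all-constant arguments
# eval() folds it, so ModularIndexing(10, 2, 4) simplifies to (10 // 2) % 4 == 1.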


class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.
    """

    pass


class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        if sympy.gcd(base, divisor) == divisor:
            return CleanDiv(base, divisor)
        else:
            return FloorDiv(base + (divisor - 1), divisor)
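
# Example (illustrative): CeilDiv(10, 4) becomes FloorDiv(13, 4) == 3, while
# CeilDiv(12, 4) becomes CleanDiv(12, 4) since the division is known exact.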


def get_device_type(x):
    if getattr(x, "get_device", None):
        return get_device_type(x.get_device())
    if isinstance(x, torch.device):
        return x.type
    return None


def is_triton(x):
    return get_device_type(x) == "cuda"


def is_cpu(x):
    return get_device_type(x) == "cpu"


@dataclasses.dataclass
class IRNode:
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        yield
        IRNode._current_origins = old

    def __post_init__(self):
        self.origins = set(self._current_origins)

    def common_repr(self):
        return (
            [f"origins={self.origins}"] if hasattr(self, "origins") else ["no origins?"]
        )

    def str_helper(self, lines):
        lines = lines + self.common_repr()
        lines = indent(",\n".join(map(str, lines)))
        return f"{type(self).__name__}(\n{lines}\n)"

    def is_user_of(self, name):
        return any(name == dep.name for dep in self.get_reads())

    def get_numel(self):
        return sympy_product(self.get_size())


@dataclasses.dataclass
class Loops(IRNode):
    device: torch.device
    dtype: torch.dtype
    inner_fn: Callable
    ranges: List[Expr]

    def __str__(self, names=("ranges",)):
        return self.str_helper(
            [
                f"'{self.device.type}'",
                str(self.dtype),
                self.inner_fn_str(),
            ]
            + [f"{name}={getattr(self, name)}" for name in names]
        )

    __repr__ = __str__

    def get_dtype(self):
        return self.dtype

    def get_device(self):
        return self.device

    def get_size(self):
        return self.ranges

    def is_extern(self):
        return False

    @classmethod
    def create(cls, *args, **kwargs):
        return TensorBox.create(cls(*args, **kwargs))

    @staticmethod
    def _index(ranges, prefix="i"):
        return [
            sympy.Integer(0) if s == 1 else sympy_symbol(f"{prefix}{n}")
            for n, s in enumerate(ranges)
        ]

    @cache_on_self
    def inner_fn_str(self):
        formatter = V.KernelFormatterHandler(V.MockHandler())
        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = self.inner_fn(self._index(self.ranges))
            return formatter.getvalue(result)

    def is_zero_elements(self):
        return any(r == 0 for r in self.ranges)

    @cache_on_self
    def get_reads(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.get_reduction_type():
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                    self.get_reduction_size(),
                ).reads
            else:
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                ).reads


class Pointwise(Loops):
    def make_loader(self):
        return self.inner_fn

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        return ops.store(output_name, indexer(vars), self.inner_fn(vars))

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.dtype, loader, self.ranges)


@dataclasses.dataclass
class Scatter(Pointwise):
    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Scatter(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.output_indexer,
            self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),
            self.inner_fn(vars),
            mode=self.scatter_mode,
        )


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


@dataclasses.dataclass
class Reduction(Loops):
    reduction_ranges: List[Expr]
    reduction_type: str
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint

    def __str__(self):
        return Loops.__str__(
            self, names=("ranges", "reduction_ranges", "reduction_type")
        )

    __repr__ = __str__

    def get_reduction_size(self):
        return self.reduction_ranges

    def get_reduction_type(self):
        return self.reduction_type

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        return ops.reduction(
            output_name,
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            indexer(vars),
            self.inner_fn(vars, reduction_vars),
        )

    def index_length(self):
        return len(self.ranges) + len(self.reduction_ranges)

    @cache_on_self
    def inner_fn_str(self):
        formatter = V.KernelFormatterHandler(V.MockHandler())
        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = self.inner_fn(
                self._index(self.ranges),
                self._index(self.reduction_ranges, "r"),
            )
            return formatter.getvalue(result)

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Reduction(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )

    @staticmethod
    def num_splits(
        device,
        dst_dtype,
        src_dtype,
        inner_fn,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_numel,
    ):
        num_sm = get_device_properties(device).multi_processor_count
        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm

        def inner_reduction_splits(reduction_numel_hint, numel_hint):
            # use heuristics close to eager mode for split inner reductions;
            # we leak reduction autotune configs here and will need to refactor to avoid this later
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1

            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )

        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

        reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        # easy cases
        if numel_hint == 1:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        r = Reduction(
            device,
            dst_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            src_dtype,
            ReductionHint.DEFAULT,
        )

        def get_read_indices(r):
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=r.get_device(),
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full-size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this could also be wrong for producers with different contiguity,
            # but we hope those cases are rare
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all([r in md.index.free_symbols for r in range_vars]):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = buf.layout.stride
                        buf.decide_layout()
                        if buf.layout.stride != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        _, (_, reduction_vars), _ = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for i in indices:
            strides = V.graph.sizevars.stride_hints(i, reduction_vars)
            outer = all([s > 1 for s in strides])
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )

    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
        """Convert inner_fn from a reduction to a pointwise"""
        reduction_ranges = [
            V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
        ]

        if reduction_type == "sum":

            def combine_fn(a, b):
                return ops.add(a, b)

        elif reduction_type == "min":

            def combine_fn(a, b):
                return ops.minimum(a, b)

        elif reduction_type == "max":

            def combine_fn(a, b):
                return ops.maximum(a, b)

        elif reduction_type == "any":

            def combine_fn(a, b):
                return ops.logical_or(a, b)

        elif reduction_type == "argmin":

            def combine_fn(a, b):
                return ops.minimum(a[0], b[0]), ops.where(
                    ops.lt(b[0], a[0]), b[1], a[1]
                )

        elif reduction_type == "argmax":

            def combine_fn(a, b):
                return ops.maximum(a[0], b[0]), ops.where(
                    ops.gt(b[0], a[0]), b[1], a[1]
                )

        else:
            raise NotImplementedError(f"unknown reduction_type={reduction_type}")

        def fn(index):
            return functools.reduce(
                combine_fn,
                (
                    value_fn(index, rindex)
                    for rindex in itertools.product(
                        *[range(x) for x in reduction_ranges]
                    )
                ),
            )

        if reduction_type in ("argmin", "argmax"):
            flatten_index = FixedLayout(
                None,
                None,
                reduction_ranges,
                FlexibleLayout.contiguous_strides(reduction_ranges),
            ).make_indexer()

            def value_fn(index, rindex):
                rindex = [sympy.expand(i) for i in rindex]
                return (
                    inner_fn(index, rindex),
                    ops.index_expr(flatten_index(rindex), torch.int64),
                )

            return lambda index: fn(index)[1]
        else:
            value_fn = inner_fn
            return fn

    @classmethod
    def create(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits.keys()
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            return Pointwise.create(
                device=device,
                dtype=src_dtype,
                inner_fn=const_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
                ranges,
            )

        if is_triton(device) and reduction_type not in {"argmax", "argmin"}:
            # triton doesn't support reduce to single element well, so break it up
            hint, split = cls.num_splits(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                reduction_numel,
            )
            # intermediate reduction in split can contain complex indexing,
            # and num_splits will fail to correctly set the hint
            # reuse the passed hint if available
            if reduction_hint == ReductionHint.DEFAULT:
                reduction_hint = hint
            if split > 1:
                # triton doesn't support reduce to single element well, so break it up
                return cls.create_multilayer(
                    device,
                    dst_dtype,
                    src_dtype,
                    inner_fn,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    split,
                    reduction_hint,
                )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )

    @staticmethod
    def default_value(reduction_type, dtype):
        if reduction_type in {"max", "argmax"}:
            if is_float_dtype(dtype):
                return float("-inf")
            elif is_boolean_dtype(dtype):
                return 0
            else:
                return torch.iinfo(dtype).min
        if reduction_type in {"min", "argmin"}:
            if is_float_dtype(dtype):
                return float("inf")
            elif is_boolean_dtype(dtype):
                return 1
            else:
                return torch.iinfo(dtype).max
        return {
            "sum": 0,
            "any": 0,
        }[reduction_type]
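
    # Example (illustrative): default_value("max", torch.float32) is float("-inf")
    # and default_value("sum", torch.int64) is 0; these identity elements pad the
    # masked-out tail lanes in create_multilayer below.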

    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        reduction_numel = sympy_product(reduction_ranges)

        # TODO(jansel): convert this to dynamic shapes
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        reduction_ranges = [
            sympy.Integer(V.graph.sizevars.guard_static_shape(s))
            for s in reduction_ranges
        ]
        reduction_numel = sympy.Integer(
            V.graph.sizevars.guard_static_shape(reduction_numel)
        )

        if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
            need_mask = False
        else:
            need_mask = True

        split = sympy.Integer(split)
        block_size = FloorDiv(reduction_numel + (split - 1), split)

        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])

        def wrapper_fn(index, reduction_index):
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            indices = block_size * reduction_block + reduction_index

            def body():
                return inner_fn(new_index, reindex([indices]))

            if need_mask:
                mask = ops.lt(
                    ops.index_expr(indices, torch.int32),
                    ops.index_expr(reduction_numel, torch.int32),
                )
                return ops.masked(
                    mask, body, cls.default_value(reduction_type, dst_dtype)
                )
            else:
                return body()

        # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
        # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
        # in fp32 and not reduce precision by breaking up the kernel into multiple layers
        intermediate_dtype = (
            dst_dtype
            if dst_dtype not in (torch.float16, torch.bfloat16)
            else torch.float
        )
        intermediate = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            [*ranges, split],
            [block_size],
            reduction_type,
            reduction_hint,
        )
        intermediate.realize()
        intermediate_loader = intermediate.make_loader()

        def intermediate_fn(index, reduction_index):
            return intermediate_loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
            reduction_hint = ReductionHint.OUTER_TINY
        if (
            split <= 1024
            and numel_hint <= 256
            and reduction_hint == ReductionHint.OUTER
        ):
            reduction_hint = ReductionHint.OUTER_TINY

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                intermediate_fn,
                ranges,
                [split],
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
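
    # Example (illustrative): summing 2**20 elements with split == 64 first
    # computes 64 partial sums over block_size == 2**14 elements each, then a
    # second Reduction over the [64] partials produces the final value; lanes
    # past reduction_numel are masked with default_value when split does not
    # divide it evenly.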


def is_storage_and_layout(x):
    try:
        as_storage_and_layout(x, freeze=False)
        return True
    except NotImplementedError:
        return False


def is_contiguous_storage_and_layout(x):
    try:
        buffer, layout = as_storage_and_layout(x, freeze=False)
        return layout.is_contiguous()
    except NotImplementedError:
        return False


def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
    """Try to simplify x into a StorageBox and a Layout"""
    if isinstance(x, TensorBox):
        return as_storage_and_layout(
            x.data,
            freeze=freeze,
            want_contiguous=want_contiguous,
            stride_order=stride_order,
        )
    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
        if freeze:
            if want_contiguous:
                x.data.freeze_layout()
            elif stride_order is not None:
                x.data.freeze_layout_with_stride_order(stride_order)
            else:
                x.data.decide_layout()
        return x, x.data.layout
    if isinstance(x, ReinterpretView):
        # making the base of x contiguous or stride_ordered will not necessarily make
        # the ReinterpretView either, so don't pass along those arguments
        buffer, _ = as_storage_and_layout(
            x.data,
            freeze=freeze,
        )
        return buffer, x.layout
    raise NotImplementedError


as_contiguous_storage_and_layout = functools.partial(
    as_storage_and_layout, want_contiguous=True
)


def is_stride_order_storage_and_layout(x, stride_order):
    try:
        buffer, layout = as_storage_and_layout(x, freeze=False)
        return layout.is_stride_ordered(stride_order)
    except NotImplementedError:
        return False


@dataclasses.dataclass
class BaseView(IRNode):
    data: IRNode

    def get_dtype(self):
        return self.data.get_dtype()

    def get_device(self):
        return self.data.get_device()

    def get_name(self):
        return self.data.get_name()

    def mark_reuse(self, users):
        return self.data.mark_reuse(users)

    def has_exceeded_max_reads(self):
        return self.data.has_exceeded_max_reads()

    def realize(self):
        return self.data.realize()

    def realize_hint(self):
        return self.data.realize_hint()

    def get_storage_numel(self):
        return self.data.get_storage_numel()

    def is_extern(self):
        return self.data.is_extern()

    @cache_on_self
    def get_reads(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            ).reads

    def unwrap_view(self):
        x = self
        while isinstance(x, BaseView):
            x = x.data
        return x

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.get_dtype(), loader, self.get_size())


@dataclasses.dataclass
class ExpandView(BaseView):
    size: List[Expr]

    @staticmethod
    def _normalize_size(x, new_size):
        """Replace `-1` with correct sizes"""
        new_size = list(map(sympy.expand, new_size))
        old_size = x.get_size()
        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
        assert len(new_size) == len(old_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                assert old_size[i] is not None
                new_size[i] = old_size[i]
        return new_size

    @classmethod
    def create(cls, x, new_size):
        new_size = cls._normalize_size(x, new_size)

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            skip = len(new_size) - len(old_layout.size)
            assert skip >= 0
            new_stride = [sympy.Integer(0)] * skip
            for stride, size in zip(old_layout.stride, old_layout.size):
                new_stride.append(stride if size != 1 else sympy.Integer(0))
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                list(new_size),
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return ExpandView(x, new_size)

    def get_size(self):
        return self.size

    def make_loader(self):
        target = self.get_size()
        actual = self.data.get_size()
        skip = len(target) - len(actual)
        inner = self.data.make_loader()

        def load(index):
            index = list(index[skip:])
            assert len(index) == len(actual)
            for i in range(len(actual)):
                if actual[i] == 1:
                    # zero out broadcast dimension
                    index[i] = sympy.Integer(0)
            return inner(index)

        return load
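
    # Example (illustrative): expanding a (1, 4) tensor to (3, 4) reuses the
    # same storage with stride 0 on the broadcast dim, so load((i, j)) reads
    # element j regardless of i.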


@dataclasses.dataclass
class PermuteView(BaseView):
    dims: List[Expr]

    @classmethod
    def create(cls, x, dims):
        dims = cls._map_neg_dims(dims)
        assert set(dims) == set(range(len(dims)))

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                [old_layout.size[i] for i in dims],
                [old_layout.stride[i] for i in dims],
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return PermuteView(x, dims)

    @classmethod
    def _map_neg_dims(cls, dims):
        return [dim if dim >= 0 else len(dims) + dim for dim in dims]

    def get_size(self):
        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
        size = self.data.get_size()
        return [size[i] for i in self.dims]

    def make_loader(self):
        inner = self.data.make_loader()
        inv = {j: i for i, j in enumerate(self.dims)}
        inv = [inv[i] for i in range(len(self.dims))]
        assert set(inv) == set(range(len(self.dims)))

        def load(index):
            index = [index[i] for i in inv]
            return inner(index)

        return load


class SqueezeView(BaseView):
    @classmethod
    def create(cls, x, *, dim=None):
        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_size = []
            new_stride = []
            if dim is not None:
                assert isinstance(dim, int), "expected integer dim argument"
                assert 0 <= dim and dim < len(old_layout.size)

            for i, (size, stride) in enumerate(
                zip(old_layout.size, old_layout.stride)
            ):
                if dim is None:
                    if size != 1:
                        new_size.append(size)
                        new_stride.append(stride)
                else:
                    if i != dim:
                        new_size.append(size)
                        new_stride.append(stride)
                    else:
                        assert size == 1, "expected squeezed size to be 1"

            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        if dim is None:
            # redirect to a generic view
            return View.create(x, [s for s in x.get_size() if s != 1])
        else:
            assert x.get_size()[dim] == 1
            return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])

    @staticmethod
    def squeezer(size: Tuple[sympy.Expr, ...]):
        new_size = [s for s in size if s != 1]
        not_one = [i for i, s in enumerate(size) if s != 1]
        length = len(size)

        def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]:
            assert len(index) == len(not_one), f"{index} {not_one}"
            new_index = [sympy.Integer(0)] * length
            for idx, s in zip(not_one, index):
                new_index[idx] = s
            return tuple(new_index)

        return new_size, reindex

    def __init__(self, data):
        raise AssertionError("use SqueezeView.create()")
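
    # Example (illustrative): SqueezeView.squeezer((2, 1, 3)) returns
    # new_size == [2, 3] and a reindex with reindex([i, j]) == (i, 0, j),
    # mapping indices of the squeezed tensor back to the original rank.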


@dataclasses.dataclass
class View(BaseView):
    size: List[Expr]
    reindex: Callable

    def make_indexer(self):
        base_indexer = self.data.make_indexer()

        def indexer(idx):
            return base_indexer(self.reindex(idx))

        return indexer

    @staticmethod
    def handle_negative_index(idx, size):
        idx = sympy.expand(idx)
        size = sympy.expand(size)
        sizevars = V.graph.sizevars
        if sizevars.size_hint(idx) < 0:
            sizevars.guard_lt(idx, 0)
            idx = idx + size
        return idx

    def reindex_str(self):
        index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))]
        index_new = list(self.reindex(index_old))
        return f"lambda {', '.join(map(str, index_old))}: {index_new}"

    def __str__(self):
        return self.str_helper(
            [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
        )

    __repr__ = __str__

    @classmethod
    def create(cls, x, new_size):
        assert isinstance(new_size, (tuple, list))
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)

        if V.graph.sizevars.maybe_guard_list_equals(old_size, new_size):
            return x

        # TODO: a new FixedTransferLayout class whose output layout is constrained by the input layout
        if is_contiguous_storage_and_layout(x) and not isinstance(
            x.data, ExternKernelAlloc
        ):
            storage, old_layout = as_contiguous_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                FlexibleLayout.contiguous_strides(new_size),
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
        return cls(x, tuple(new_size), reindex)
  1142. @staticmethod
  1143. def resolve_negative_size(old_size, new_size):
  1144. new_size = [V.graph.sizevars.simplify(x) for x in new_size]
  1145. old_size = [V.graph.sizevars.simplify(x) for x in old_size]
  1146. new_size = list(new_size)
  1147. for i in range(len(new_size)):
  1148. if new_size[i] == -1:
  1149. new_size[i] = sympy.Integer(1)
  1150. new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
  1151. break
  1152. V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
  1153. return old_size, new_size
  1154. @classmethod
  1155. def dynamic_reshape_indexer(cls, old_size, new_size):
  1156. try:
  1157. reindex = cls._dynamic_reshape_indexer(old_size, new_size)
  1158. except (AssertionError, IndexError):
  1159. # optimistic algorithm failed, lets do a fallback
  1160. flat = [sympy_product(old_size)]
  1161. reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
  1162. reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
  1163. reindex = fuse_reindexing(reindex1, reindex2)
  1164. return reindex

    @staticmethod
    def _dynamic_reshape_indexer(old_size, new_size):
        """
        Perform a reshape entirely by modifying indexing math
        """
        size_hint = V.graph.sizevars.size_hint
        vars = [sympy_symbol(f"view{i}") for i in range(len(new_size))]

        stack_new = list(zip(vars, new_size))
        stack_old = list(old_size)

        view_expr = []
        while stack_new and stack_old:
            size_old = stack_old.pop()
            var, size_new = stack_new.pop()
            if size_old == 1:
                view_expr.append(sympy.Integer(0))
                stack_new.append((var, size_new))  # re-add
            elif size_new == 1:
                stack_old.append(size_old)  # re-add
            elif size_hint(size_new) == size_hint(size_old):
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) < size_hint(size_old):
                while size_hint(size_new) < size_hint(size_old):
                    var2, size_new2 = stack_new.pop()
                    var = var2 * size_new + var
                    size_new = size_new * size_new2
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) > size_hint(size_old):
                divisor = sympy.Integer(1)
                modulus = size_old
                view_expr.append(ModularIndexing(var, divisor, modulus))
                divisor = divisor * modulus
                while size_hint(size_new) > size_hint(size_old):
                    modulus = stack_old.pop()
                    view_expr.append(ModularIndexing(var, divisor, modulus))
                    divisor = divisor * modulus
                    size_old = size_old * modulus
                V.graph.sizevars.guard_equals(size_new, size_old)
            else:
                raise AssertionError()

        while stack_old:
            size_old = stack_old.pop()
            assert size_old == 1
            view_expr.append(sympy.Integer(0))

        while stack_new:
            var, size_new = stack_new.pop()
            assert size_new == 1

        view_expr = list(reversed(view_expr))
        assert len(view_expr) == len(old_size)

        def reindex(index):
            assert len(index) == len(vars), (len(index), len(vars))
            replacements = dict(zip(vars, index))
            return tuple(sympy_subs(x, replacements) for x in view_expr)

        return reindex
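
    # Worked example (illustrative, not from the original source): reshaping
    # old_size=[2, 12] into new_size=[2, 3, 4] yields
    # reindex([v0, v1, v2]) == (v0, 4*v1 + v2), so the reshape is expressed
    # purely as indexing math over the original layout.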

    def get_size(self):
        return self.size

    def make_loader(self):
        def load(index):
            return inner(self.reindex(index))

        inner = self.data.make_loader()
        return load


@dataclasses.dataclass
class ReinterpretView(BaseView):
    """Pretend our storage has a different layout"""

    layout: "Layout"

    def __post_init__(self):
        if isinstance(self.data, BaseView):
            self.data = self.data.unwrap_view()

    def __str__(self):
        return self.str_helper(
            [
                self.data,
                self.layout,
            ]
        )

    __repr__ = __str__

    def get_name(self):
        return self.data.get_name()

    def get_device(self):
        return self.layout.device

    def get_dtype(self):
        return self.layout.dtype

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(self.get_name(), indexer(index))

        return loader

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_layout(self):
        return self.layout

    def freeze_layout(self):
        pass

    def codegen_reference(self):
        size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
        stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
        offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
        as_strided = V.graph.sizevars.as_strided
        if offset != "0":
            return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
        return f"{as_strided}({self.get_name()}, {size}, {stride})"
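
# Worked example (illustrative, hypothetical buffer name): a ReinterpretView
# of "buf0" with size=[2, 3], stride=[3, 1] and a nonzero offset of 4 emits
# as_strided(buf0, (2, 3), (3, 1), 4); when the offset is 0, the offset
# argument is omitted from the generated call.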


class SliceView(View):
    @classmethod
    def create(cls, x, dim, start, end, step=1):
        step = sympy.expand(step)
        assert step > 0
        try:
            if start == 0 and end >= 2**63 and step == 1:
                return x
        except TypeError:
            pass

        sizevars = V.graph.sizevars
        new_size = list(x.get_size())

        start = cls.handle_negative_index(start, new_size[dim])
        end = cls.handle_negative_index(end, new_size[dim])

        end = sizevars.guard_min(end, new_size[dim])
        start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end)
        if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
            sizevars.guard_equals(end, new_size[dim])
            return x

        new_size[dim] = FloorDiv(end - start + (step - 1), step)

        if is_storage_and_layout(x):
            # Fast path
            storage, old_layout = as_storage_and_layout(x)
            new_stride = list(old_layout.stride)
            new_stride[dim] = new_stride[dim] * step
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset + old_layout.stride[dim] * start,
            )
            return ReinterpretView(storage, new_layout)

        def reindex(index):
            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
            index = list(index)
            index[dim] = index[dim] * step + start
            return index

        # redirect to a generic view
        return SliceView(x, size=new_size, reindex=reindex)
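
# Worked example (illustrative, not from the original source): slicing a
# dimension with start=1, end=8, step=3 produces FloorDiv(8 - 1 + 2, 3) == 3
# elements; on the fast path the dim's stride is multiplied by 3 and
# stride * 1 is added to the storage offset, while the generic fallback
# remaps index i to 3*i + 1.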


class BaseConstant(IRNode):
    def get_size(self):
        return ()

    def get_dtype(self):
        return self.dtype

    def get_device(self):
        return self.device

    def mark_reuse(self, users):
        pass

    def has_exceeded_max_reads(self):
        return False

    def get_reads(self):
        return ()

    def is_extern(self):
        return False


@dataclasses.dataclass
class Constant(BaseConstant):
    value: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(index):
            return ops.constant(self.value, self.dtype)

        return loader

    def realize(self):
        pass


@dataclasses.dataclass
class IndexingConstant(BaseConstant):
    index: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(index):
            return ops.index_expr(self.index, self.dtype)

        return loader


@dataclasses.dataclass
class Layout(IRNode):
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: List[Expr],
        offset: Expr = Integer(0),
    ):
        self.device = device
        self.dtype = dtype
        assert all(isinstance(s, (Expr, int)) for s in size)
        self.size = size
        self._stride = stride
        self.offset = offset

    @property
    def stride(self):
        return self._stride

    def __str__(self):
        offset = ""
        if self.offset != 0:
            offset = f", offset={self.offset}"
        return (
            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
            f"size={self.size}, stride={self.stride}{offset})"
        )

    __repr__ = __str__

    def is_contiguous(self):
        for left, right, size in zip(
            self.stride, FlexibleLayout.contiguous_strides(self.size), self.size
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_channels_last_contiguous(self):
        ndim = len(self.size)
        if ndim not in [4, 5]:
            return False
        for left, right, size in zip(
            self.stride, make_channels_last_strides_for(self.size), self.size
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_transposed(self):
        for left, right, size in zip(
            self.stride,
            reversed(FlexibleLayout.contiguous_strides(self.size)),
            self.size,
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_stride_ordered(self, order):
        assert len(self.stride) == len(order)
        # reorder the strides according to the given order
        stride_ordered = [None] * len(order)
        for i in range(len(order)):
            stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i])
        # check that the reordered strides are in ascending order
        for i in range(len(order) - 1):
            if stride_ordered[i] > stride_ordered[i + 1]:
                return False
        return True

    def is_channels_last_stride_ordered(self):
        # create the channels_last stride order (for NCHW or NCDHW; C is the
        # first dimension filled, i.e. it gets the smallest stride)
        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
        order = [len(order)] + order
        return self.is_stride_ordered(order)
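
    # Worked example (illustrative, not from the original source): for a 4-D
    # layout the order built above is [3, 0, 2, 1], so the check accepts
    # strides whose size hints satisfy
    # stride[1] <= stride[3] <= stride[2] <= stride[0],
    # i.e. C smallest, then W, then H, then N: channels-last (NHWC) order.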

    def as_fixed(self):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.stride,
            self.offset,
        )

    def make_indexer(self):
        assert (
            FlexibleLayout.allow_indexing
        ), f"convert {type(self).__name__} to FixedLayout first"
        return self.as_fixed().make_indexer()

    def __eq__(self, other) -> bool:
        return (
            self.device == other.device
            and self.dtype == other.dtype
            and self.size == other.size
            and self.stride == other.stride
            and self.offset == other.offset
        )


class FixedLayout(Layout):
    """A Tensor layout we cannot change"""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: List[Expr] = None,
        offset: Expr = Integer(0),
    ):
        if stride is None:
            stride = FlexibleLayout.contiguous_strides(size)
        super().__init__(
            device,
            dtype,
            size,
            stride,
            offset,
        )

    def make_indexer(self):
        """A closure containing math to read a given element"""

        def indexer(index):
            assert len(index) == len(self.stride) == len(self.size)
            result = self.offset
            for idx, stride, sz in zip(index, self.stride, self.size):
                if sz != 1:
                    result = result + idx * stride
            return result

        return indexer
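
# Worked example (illustrative, not from the original source): a FixedLayout
# with size=[s0, s1] (neither statically 1), stride=[t0, t1] and offset o
# produces an indexer mapping [i0, i1] to o + i0*t0 + i1*t1; dimensions of
# size 1 contribute nothing to the address.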


class FlexibleLayout(Layout):
    """A Tensor layout we are allowed to change"""

    allow_indexing = False

    @staticmethod
    def contiguous_strides(sizes):
        if len(sizes) == 0:
            return []
        reversed_strides = [sympy.Integer(1)]
        for size in reversed(sizes[1:]):
            reversed_strides.append(size * reversed_strides[-1])
        return list(reversed(reversed_strides))
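
    # Worked example (illustrative, not from the original source):
    # contiguous_strides([2, 3, 4]) builds [1, 4, 12] back to front and
    # returns [12, 4, 1], the standard row-major (C-contiguous) strides.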

    @staticmethod
    def fill_ordered(sizes, order):
        """
        Create a stride based on the order the dimensions should be filled in.

        In this format, channels last would be:
            [1, 3, 2, 0]
        """
        assert set(range(len(sizes))) == set(order)
        next_stride = sympy.Integer(1)
        strides = [None] * len(order)

        for i in order:
            strides[i] = next_stride
            next_stride = next_stride * sizes[i]
        return strides
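
    # Worked example (illustrative, not from the original source):
    # fill_ordered([N, C, H, W], [1, 3, 2, 0]) fills C first (stride 1), then
    # W (stride C), then H (stride C*W), then N (stride C*W*H), producing the
    # channels-last strides [C*W*H, 1, C*W, C].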

    @staticmethod
    def stride_ordered(sizes, order):
        """
        Create a stride based on the sorted order of a permuted range.

        In this format, channels last would be:
            [3, 0, 2, 1]
        """
        assert set(range(len(sizes))) == set(order)
        fill_order = stride_order2fill_order(order)
        return FlexibleLayout.fill_ordered(sizes, fill_order)

    @staticmethod
    def same_ordered(sizes, stride):
        """
        Create a stride that has the same stride order as the given stride.

        For example, if the given stride is [1000, 1, 100, 10],
        the fill order should be [1, 3, 2, 0].
        """
        assert len(sizes) == len(stride)
        stride = [V.graph.sizevars.size_hint(x) for x in stride]
        fill_order = sorted(range(len(stride)), key=stride.__getitem__)
        return FlexibleLayout.fill_ordered(sizes, fill_order)

    def as_stride_order(self, order):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.stride_ordered(self.size, order),
            self.offset,
        )

    def as_fill_order(self, order):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.fill_ordered(self.size, order),
            self.offset,
        )

    def as_same_order(self, stride):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.same_ordered(self.size, stride),
            self.offset,
        )

    def __init__(self, device, dtype, size, stride_order=None):
        if stride_order:
            strides = FlexibleLayout.fill_ordered(size, stride_order)
        else:
            strides = FlexibleLayout.contiguous_strides(size)
        super().__init__(device, dtype, size, strides)


class AliasedLayout(Layout):
    """Shares the same storage as another tensor"""

    def __init__(self, view: "ReinterpretView"):
        layout = view.get_layout()
        super().__init__(
            layout.device,
            layout.dtype,
            layout.size,
            layout.stride,
        )
        self.view = view

    def make_indexer(self):
        return self.as_fixed().make_indexer()

    def maybe_guard_aligned(self):
        offset = self.view.get_layout().offset
        if offset == 0:
            return True
        from .compile_fx import ALIGNMENT

        return V.graph.sizevars.maybe_guard_multiple_of(offset, ALIGNMENT)


class MutationLayout(Layout):
    def __init__(self, target: IRNode):
        super().__init__(
            target.get_device(),
            target.get_dtype(),
            target.get_size(),
            None,  # type: ignore[arg-type]
        )
        self.target = target

    @Layout.stride.getter
    def stride(self):
        return self.real_layout().stride

    def real_layout(self):
        if isinstance(self.target, MutationLayout):
            return self.target.real_layout()
        return self.target.data.layout

    @classmethod
    def realize_into(cls, src, dst):
        dst.realize()
        V.graph.realize_users_of(dst.get_name())

        if isinstance(src, TensorBox):
            src = src.data

        if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()):
            need_copy = True
        else:
            src.realize()
            need_copy = not isinstance(src.data.layout, FlexibleLayout)

        if need_copy:
            src = Pointwise.create(
                device=src.get_device(),
                dtype=src.get_dtype(),
                inner_fn=src.make_loader(),
                ranges=[
                    V.graph.sizevars.guard_equals(a, b)
                    for a, b in zip(src.get_size(), dst.get_size())
                ],
            ).data
            src.realize()

        assert isinstance(src.data.layout, FlexibleLayout)
        src.data.layout = MutationLayout(dst)
        return src.data

    def as_fixed(self):
        return self

    def make_indexer(self):
        return self.target.make_indexer()
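
# Explanatory note (added for clarity, not from the original source):
# MutationLayout.realize_into leaves `src` realized with a MutationLayout
# pointing at `dst`, so stores into `src` write directly into `dst`'s storage.
# A Pointwise copy is introduced first whenever `src` already reads `dst` or
# its layout has been frozen, since only a FlexibleLayout can be retargeted
# this way.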


@dataclasses.dataclass
class Buffer(IRNode):
    name: str
    layout: Layout

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_name(self):
        assert self.name
        return self.name

    def get_device(self):
        return self.layout.device

    def get_dtype(self):
        return getattr(self.layout, "dtype", None)

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def get_layout(self):
        return self.layout

    def get_storage_numel(self):
        return self.get_numel()

    def is_extern(self):
        return False

    def freeze_layout(self):
        if not isinstance(self.layout, MultiOutputLayout):
            self.layout = self.layout.as_fixed()

    def freeze_layout_with_stride_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_stride_order(order)

    def freeze_layout_with_fill_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_fill_order(order)

    def freeze_layout_with_same_order(self, stride):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_same_order(stride)

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(self.name, indexer(index))

        return loader

    def is_no_op(self):
        return False

    def codegen_reference(self):
        return self.get_name()

    def decide_layout(self):
        pass

    def get_alias_names(self):
        if isinstance(self.layout, AliasedLayout):
            return [self.layout.view.get_name()]
        return ()

    def get_mutation_names(self):
        if isinstance(self.layout, MutationLayout):
            return [self.layout.target.get_name()]
        return ()

    @cache_on_self
    def get_read_writes(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            )

    def get_reads(self):
        return self.get_read_writes().reads

    def realize(self):
        pass


class InputBuffer(Buffer):
    pass


class ConstantBuffer(InputBuffer):
    override_device = None

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(
                V.graph.constant_name(self.name, self.override_device), indexer(index)
            )

        return loader

    def constant_to_device(self, device):
        return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout)


class RandSeedBuffer(ConstantBuffer):
    def codegen_reference(self):
        # Clone makes sure that if we pass this from forwards to backwards,
        # the value does not get clobbered by the time backwards is run.
        return self.get_name() + ".clone()"


class NoneAsConstantBuffer(IRNode):
    def codegen_reference(self):
        return "None"

    def cpp_wrapper_codegen_reference(self):
        return "at::Tensor()"


class ShapeAsConstantBuffer(IRNode):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def codegen_reference(self):
        return str(V.graph.sizevars.simplify(self.shape))


@dataclasses.dataclass
class ComputedBuffer(Buffer):
    data: Loops

    @cache_on_self
    def get_read_writes(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.data.get_reduction_type():
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                    self.data.get_reduction_size(),
                )
            else:
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                )

    def get_store_function(self):
        indexer = self.layout.as_fixed().make_indexer()
        if self.data.get_reduction_type():
            return partial(self.data.store_reduction, self.name, indexer)
        else:
            return partial(self.data.store_output, self.name, indexer)

    def get_fill_order(self):
        """
        If our layout is still flexible, try to determine the stride order based on the stride orders of reads.

        TODO(jansel): A better algorithm here would look at downstream consumers of this
                      value and try to do global graph-level layout optimization.
                      This is also something just begging to be autotuned.
        """
        if isinstance(self.layout, FlexibleLayout):
            _, (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
                self.data.get_size(), self.data.get_reduction_size()
            )
            reads = self.get_read_writes().reads
            reads_bufs = [
                V.graph.name_to_buffer[r.name]
                if r.name in V.graph.name_to_buffer.keys()
                else None
                for r in reads
            ]
            priority_idx = []
            for i, reads_buf in enumerate(reads_bufs):
                if (
                    isinstance(reads_buf, Convolution)
                    and reads_buf.kernel != "aten.convolution"
                ):
                    # prioritize the layout order of Conv reads
                    priority_idx.append(i)
            # only consider reads to buffers of the same size
            reads = [
                sympy_subs(
                    r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
                )
                for r in reads
            ]

            if reads:
                stride_lengths = [
                    V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads
                ]
                from .scheduler import pick_loop_order

                return pick_loop_order(stride_lengths, self.get_size(), priority_idx)

        return None

    def decide_layout(self):
        if isinstance(self.layout, FlexibleLayout):
            order = self.get_fill_order()
            if order:
                self.freeze_layout_with_fill_order(order)
            else:
                self.freeze_layout()

    def simplify_and_reorder(self):
        """
        This is the main place where we do loop transformations in a
        backend-agnostic way.

        Here we:
            1) Remove any 1-sized dimensions
            2) Fuse contiguous dimensions together
            3) Reorder dimensions based on stride orders
        """
        _, args, var_ranges = dependencies.index_vars_squeeze(
            self.data.get_size(), self.data.get_reduction_size(), prefix="q"
        )
        with patch.object(ConstantBuffer, "override_device", self.get_device()):
            body = LoopBody(
                self.get_store_function(),
                (args if self.get_reduction_type() else args[:1]),
                var_ranges,
            )
        index_formulas = [*body.indexing_exprs.values()]
        reads_bufs = [
            V.graph.name_to_buffer[reads_name]
            if reads_name in V.graph.name_to_buffer.keys()
            else None
            for reads_name in body.reads_name2expr.keys()
        ]
        priority_idx = []
        memory_addrs = [
            *body.reads_name2expr.values(),
            *body.writes_name2expr.values(),
        ]
        index_vars = []
        reduce_vars = []
        index_size = []
        reduce_size = []
        for v, s in var_ranges.items():
            if v in args[0]:
                assert not reduce_vars
                index_vars.append(v)
                index_size.append(s)
            else:
                assert v in args[1]
                reduce_vars.append(v)
                reduce_size.append(s)

        # the reordering_reindex used in the reads' simplify_reorder_and_tile
        reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
        for i, reads_buf in enumerate(reads_bufs):
            if isinstance(reads_buf, ComputedBuffer) and hasattr(
                reads_buf, "iter_reordering_reindex"
            ):
                reordering_reindex[i] = reads_buf.iter_reordering_reindex

        def simplify_and_reorder(x_vars, sizes, reordering_reindex=None):
            sizes, reindex0, reindex1 = self._apply_loop_reordering(
                x_vars, sizes, memory_addrs, reordering_reindex, priority_idx
            )
            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
            x_vars = reindex0(x_vars)
            sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
                x_vars,
                sizes,
                index_prevent_reordering(index_formulas, x_vars, sizes),
            )
            x_vars = prune(x_vars)
            # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
            # x_vars = prune(x_vars)
            # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
            reindex = fuse_reindexing(reindex1, reindex2)
            return sizes, reindex, reindex1

        iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
            index_vars, index_size, reordering_reindex
        )
        reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
            reduce_vars, reduce_size
        )

        # remember the reordering order
        self.iter_reordering_reindex = iter_reordering_reindex
        # retrace the loop body with simplification and reordering applied
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            iter_ranges, reduce_ranges, prefix="z"
        )
        body = LoopBody(
            body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
        )
        return (iter_ranges, reduce_ranges), body
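
    # Illustrative example (hypothetical shapes, not from the original
    # source): a pointwise kernel over iteration sizes [s0, 1, s1] whose reads
    # are all contiguous would first drop the 1-sized dimension and then fuse
    # the remaining contiguous dimensions, yielding iter_ranges == [s0 * s1]
    # and a reindex that maps the flat loop variable back to the original
    # multi-dimensional index.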

    @staticmethod
    def _apply_loop_reordering(
        index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None
    ):
        """
        Shuffle the order of loops around to hopefully improve performance.
        """
        from .scheduler import pick_loop_order

        if priority_idx is None:
            priority_idx = []

        try:
            strides = [
                V.graph.sizevars.stride_hints(expr, index_vars) for expr in memory_addrs
            ]
            assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
                index_vars
            )
            # consider both the layout (strides) and the reordering (reordering_reindex)
            if reordering_reindex is not None:
                for i in range(len(memory_addrs)):
                    try:
                        strides[i] = reordering_reindex[i](strides[i])
                    # if len(order) != len(strides), do not reorder
                    except AssertionError:
                        pass
            order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
        except Exception:
            if config.debug:
                log.warning(
                    f"Did not simplify complex index:\n{dict(zip(index_vars, sizes))}\n{memory_addrs}"
                )
            order = list(range(len(sizes)))
        sizes = [sizes[i] for i in order]
        return sizes, same_reorder(order), inverse_reorder(order)

    def get_reduction_size(self):
        return self.data.get_reduction_size()

    def get_reduction_type(self):
        return self.data.get_reduction_type()

    def is_no_op(self):
        return self.data.is_zero_elements()

    def should_allocate(self):
        return True

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        return self.data.constant_to_device(device)


class TemplateBuffer(Buffer):
    """
    Represents a Triton (and, in the future, other types of) template operator
    that we can fuse an epilogue onto.
    """

    def __init__(self, layout, inputs, make_kernel_render):
        super().__init__(name=None, layout=layout)
        self.inputs = InputsKernel.unwrap_storage(inputs)
        self.make_kernel_render = make_kernel_render
        self.name = V.graph.register_buffer(self)

    def get_read_writes(self):
        return self.normalized_read_writes()

    @cache_on_self
    def normalized_read_writes(self):
        name = self.get_name()
        indexer = self.layout.make_indexer()

        def dummy(index, rindex):
            assert len(rindex) == 0
            return ops.store(name, indexer(index), "fake")

        deps = dependencies.extract_read_writes(
            dummy, self.get_size(), (), normalize=True
        )
        deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
        return deps

    def get_reduction_size(self):
        return 1

    def get_reduction_type(self):
        return None

    def is_no_op(self):
        return False

    def should_allocate(self):
        return True

    def simplify_and_reorder(self):
        return (
            (
                self.get_size(),
                (),
            ),
            None,
        )


@dataclasses.dataclass
class InputsKernel(Buffer):
    inputs: List[Buffer]

    def get_read_writes(self):
        return dependencies.ReadWrites(
            {dependencies.StarDep(x.get_name()) for x in self.inputs},
            {dependencies.StarDep(self.get_name())},
            set(),
            [],
            None,
        )

    @staticmethod
    def unwrap_storage(inputs):
        inputs_new = []
        for x in inputs:
            if isinstance(x, TensorBox):
                x = x.data
            if isinstance(x, StorageBox):
                x = x.data
            if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
                x = ExternKernel.realize_input(x)
            assert isinstance(x, (Buffer, ReinterpretView)), x
            inputs_new.append(x)
        return inputs_new

    def is_extern(self):
        return True


class NopKernel(InputsKernel):
    def is_no_op(self):
        return True


class ConcatKernel(NopKernel):
    """
    There isn't actually a real kernel for concat; we just change the
    storage for the upstream data.
    """

    @classmethod
    def create(cls, inputs, dim):
        device = inputs[0].get_device()
        dtype = inputs[0].get_dtype()
        new_size = list(inputs[0].get_size())
        offsets_start = [0]
        offsets_end = [new_size[dim]]
        assert 0 <= dim < len(new_size)
        for i in range(1, len(inputs)):
            input_size = inputs[i].get_size()
            offsets_start.append(new_size[dim])
            assert len(input_size) == len(new_size)
            assert inputs[i].get_dtype() == dtype
            assert inputs[i].get_device() == device
            for j in range(len(new_size)):
                if j == dim:
                    new_size[j] = new_size[j] + input_size[j]
                else:
                    new_size[j] = V.graph.sizevars.guard_equals(
                        new_size[j], input_size[j]
                    )
            offsets_end.append(new_size[dim])

        output_stride = FlexibleLayout.contiguous_strides(new_size)
        # If any of the inputs is in CL format, use CL format for the output
        for i in range(len(inputs)):
            x = inputs[i]
            if is_storage_and_layout(x):
                layout = x.get_layout()
                if (
                    isinstance(layout, FixedLayout)
                    and layout.is_channels_last_contiguous()
                ):
                    # use CL stride for the output
                    output_stride = make_channels_last_strides_for(new_size)
                    break

        kernel = ConcatKernel(
            name=None,
            layout=FixedLayout(
                device=device,
                dtype=dtype,
                size=new_size,
                stride=output_stride,
            ),
            inputs=[],
        )
        kernel = StorageBox(kernel)
        for i in range(len(inputs)):
            kernel.data.inputs.append(
                cls.realize_into(
                    inputs[i],
                    SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
                )
            )
        kernel.data.name = V.graph.register_buffer(kernel.data)
        kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
        return kernel
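
    # Worked example (illustrative, not from the original source):
    # concatenating inputs of size [2, 3] and [4, 3] along dim=0 gives
    # new_size == [6, 3], offsets_start == [0, 2] and offsets_end == [2, 6];
    # each input is then realized into the matching SliceView of the output.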

    @classmethod
    def realize_into(cls, src, dst):
        # Attempt to turn this into a ReinterpretView rather than assert.
        # This has concessions around layout, as as_storage_and_layout
        # can cause us to go from flexible to fixed layout.
        if not isinstance(dst, ReinterpretView):
            if is_storage_and_layout(dst):
                storage, layout = as_storage_and_layout(dst)
                dst = ReinterpretView(storage, layout)
        assert isinstance(dst, ReinterpretView), dst
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.realize_into(src.data, dst)
        if isinstance(src, StorageBox):
            src.realize()
            # ExternKernelAlloc has specific requirements for the output
            # layout, so we should create a copy in that case
            if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
                src.data, ExternKernelAlloc
            ):
                src.data.layout = AliasedLayout(dst)
                return src.data
        # introduce a copy
        pw = Pointwise.create(
            device=src.get_device(),
            dtype=src.get_dtype(),
            inner_fn=src.make_loader(),
            ranges=[
                V.graph.sizevars.guard_equals(a, b)
                for a, b in zip(src.get_size(), dst.get_size())
            ],
        )
        return cls.realize_into(pw, dst)

    def should_allocate(self):
        return True


@dataclasses.dataclass
class ExternKernel(InputsKernel):
    constant_args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    output_view: Optional[ReinterpretView] = None

    def decide_layout(self):
        if isinstance(self.layout, FlexibleLayout):
            self.apply_constraint()
            self.freeze_layout()

    def codegen(self, wrapper):
        raise NotImplementedError

    @staticmethod
    def copy_input(x):
        pw = Pointwise.create(
            device=x.get_device(),
            dtype=x.get_dtype(),
            inner_fn=x.make_loader(),
            ranges=x.get_size(),
        )
        pw.realize()
        return pw

    @classmethod
    def process_kernel(cls, kernel, *args, **kwargs):
        binded_args = signature(kernel).bind(*args, **kwargs).arguments

        args_flat, args_spec = pytree.tree_flatten(binded_args)

        is_arg_tensor = []
        tensor_args = []
        non_tensor_args = []
        for arg in args_flat:
            is_arg_tensor.append(isinstance(arg, IRNode))
            if is_arg_tensor[-1]:
                tensor_args.append(arg)
            else:
                if isinstance(arg, sympy.Expr):
                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                non_tensor_args.append(arg)

        def unflatten_args(new_tensor_args, new_non_tensor_args):
            result = []
            it_tensors = iter(new_tensor_args)
            it_non_tensors = iter(new_non_tensor_args)
            for is_tensor in is_arg_tensor:
                if is_tensor:
                    result.append(next(it_tensors))
                else:
                    result.append(next(it_non_tensors))
            result = pytree.tree_unflatten(result, args_spec)
            return result.get("args", []), result.get("kwargs", {})

        tensor_args = [cls.realize_input(x) for x in tensor_args]

        # freeze layout otherwise our output stride calculation might
        # become incorrect
        for x in tensor_args:
            if is_storage_and_layout(x):
                as_storage_and_layout(x, freeze=True)

        # We don't have generic shape formulas, so just burn in the
        # shapes and run an example input.
        # TODO(jansel): replace this with dynamic shape formulas
        example_args = []
        for x in tensor_args:
            example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)

        return example_output, tensor_args, non_tensor_args, unflatten_args

    @classmethod
    def convert_to_reinterpret_view(cls, x):
        """
        In order to pass this to an extern kernel we need a
        ReinterpretView, not a View.  This allows us to avoid some
        unneeded copies.
        """
        assert isinstance(x, BaseView)
        if isinstance(x, ReinterpretView):
            return x

        x.unwrap_view().freeze_layout()
        rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
        assert len(rw.reads) == 1

        index = V.graph.sizevars.simplify_with_ranges(
            list(rw.reads)[0].index, rw.var_ranges
        )
        strides = V.graph.sizevars.stride_vars(index, rw.range_vars)
        offset = V.graph.sizevars.offset_var(index, rw.range_vars)
        expected = sympy_dot(rw.range_vars, strides) + offset

        if index != expected:
            log.debug(
                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
                strides,
                offset,
                index,
            )
            raise NotImplementedError()

        return ReinterpretView(
            data=x.data,
            layout=FixedLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=x.get_size(),
                stride=strides,
                offset=offset,
            ),
        )
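
    # Worked example (illustrative, not from the original source): if the
    # view's single read has index 3*i0 + i1 + 4 over range_vars [i0, i1],
    # then strides == [3, 1] and offset == 4, and the affine check above
    # passes; a non-affine index (e.g. one using ModularIndexing) raises
    # NotImplementedError, and the caller falls back to a copy instead.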

    @classmethod
    def realize_input(cls, x):
        if x is None:
            return NoneAsConstantBuffer()
        if isinstance(x, (sympy.Expr, int)):
            return ShapeAsConstantBuffer(x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
            return cls.realize_input(x.data)
        if isinstance(x, ReinterpretView):
            return x
        if isinstance(x, BaseView):
            x.realize()
            if is_storage_and_layout(x.unwrap_view()) and not isinstance(
                x.unwrap_view().data, ExternKernelAlloc
            ):
                try:
                    return cls.convert_to_reinterpret_view(x)
                except NotImplementedError:
                    pass
        if isinstance(x, StorageBox):
            # TODO(jansel): impose layout preference on realized buffer
            x.realize()
            return x
        return cls.copy_input(x)

    @classmethod
    def require_stride1(cls, x):
        if is_storage_and_layout(x):
            if len(x.get_stride()) == 0:
                return x
            for stride in x.get_stride():
                if stride == 1:
                    return x
        return cls.copy_input(x)

    @classmethod
    def require_stride_order(cls, x, order):
        if x.get_numel() == 0:  # Layout doesn't matter
            return x

        # require x to have a layout that is stride-ordered according to order
        if is_storage_and_layout(x):
            if isinstance(x.get_layout(), FlexibleLayout):
                # freeze the FlexibleLayout to a FixedLayout with the given stride order
                as_storage_and_layout(
                    x, freeze=True, want_contiguous=False, stride_order=order
                )
                return x
            elif isinstance(
                x.get_layout(), FixedLayout
            ) and x.get_layout().is_stride_ordered(order):
                return x
            elif isinstance(x.get_layout(), MutationLayout):
                if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                    raise AssertionError(
                        "the MutationLayout's real layout shouldn't be FlexibleLayout"
                    )
                elif isinstance(
                    x.get_layout().real_layout(), FixedLayout
                ) and x.get_layout().real_layout().is_stride_ordered(order):
                    return x

        # TODO - Storage to InputBuffer
        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
            return x
        x = cls.copy_input(x)
        as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
        assert is_stride_order_storage_and_layout(x, order)
        return x

    @classmethod
    def require_contiguous(cls, x):
        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))
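
    # Worked example (illustrative, not from the original source): for a 3-D
    # input this requests stride order [2, 1, 0], which is_stride_ordered
    # accepts exactly when stride[2] <= stride[1] <= stride[0] (by size hint),
    # i.e. the innermost dimension has the smallest stride, as in a
    # C-contiguous tensor.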

    def apply_constraint(self):
        pass

    def codegen_args(self):
        args = [x.codegen_reference() for x in self.inputs]
        args.extend(map(repr, self.constant_args))
        return args

    def codegen_kwargs(self):
        kwargs = []
        if self.kwargs:
            kwargs = [f"{k}={repr(v)}" for k, v in self.kwargs.items()]
        return kwargs

    def cpp_wrapper_codegen_kwargs(self):
        kwargs = []
        if self.kwargs:
            for arg_name in self.ordered_kwargs_for_cpp_kernel:
                assert arg_name in self.kwargs, (
                    "arg %s not found in self.kwargs" % arg_name
                )
                v = self.kwargs.get(arg_name)
                kwargs.append(repr(v))
        return kwargs

    def codegen_size_asserts(self, wrapper):
        if config.size_asserts:
            size = V.graph.sizevars.codegen_shape_tuple(self.get_size())
            stride = V.graph.sizevars.codegen_shape_tuple(self.get_stride())
            wrapper.writeline(
                f"assert_size_stride({self.get_name()}, {size}, {stride})"
            )

    def get_group_stride(self):
        """
        get output sizes and strides, for template_codegen
        """
        _size = self.get_size()
        _stride = self.get_stride()
        # iter_ranges = _size of the output tensor, reduce_range = [] because there is no reduction
        return [_size, []], _stride

    def canonicalize(self):
        """
        Manually get the canonicalization of the output index
        """
        # manually generate index formula for conv
        sizevars = V.graph.sizevars
        sizes = self.get_size()
        strides = self.get_stride()
        strides = [sizevars.size_hint(x) for x in strides]
        index_vars = [sympy_symbol(f"d{i}") for i in range(len(sizes))]
        # reorder index vars according to stride
        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        lookup = {pos: idx for idx, pos in enumerate(index_order)}
        order = [lookup[i] for i in range(len(lookup))]
        index_vars = [index_vars[i] for i in order]
        indexer = self.make_indexer()
        index = indexer(index_vars)

        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, [index]
        )

        # assign a new variable to each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        _, add_var = var_builder("c")
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)
        return index, tuple(new_sizes)

    def __str__(self):
        lines = [
            f"{field.name}={getattr(self, field.name)}"
            for field in dataclasses.fields(self)
        ]
        return self.str_helper(lines)


@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
    output_view: Optional[ReinterpretView] = None

    def codegen(self, wrapper):
        args = self.codegen_args()
        from torch._inductor.codegen.wrapper import CppWrapperCodeGen

        if isinstance(wrapper, CppWrapperCodeGen):
            kwargs = self.cpp_wrapper_codegen_kwargs()
        else:
            kwargs = self.codegen_kwargs()
        if kwargs:
            args.extend(kwargs)

        wrapper.generate_extern_kernel_out(
            self.output_view,
            self.codegen_reference(),
            args,
            self.kernel,
            self.cpp_kernel,
        )

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        kernel=None,
        cpp_kernel=None,
    ):
        super().__init__(
            None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
        )
        self.output_view = output_view
        self.name = V.graph.register_buffer(self)
        if kernel is not None:
            self.kernel = kernel
        self.cpp_kernel = cpp_kernel

    def should_allocate(self):
        return True


class ExternKernelAlloc(ExternKernel):
    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def __init__(self, layout, inputs, constant_args=()):
        super().__init__(None, layout, self.unwrap_storage(inputs), constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return False

    def apply_constraint(self):
        raise NotImplementedError


class InplaceBernoulliFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly
    """

    kernel = "aten.bernoulli_"

    def codegen(self, wrapper):
        (x,) = [t.codegen_reference() for t in self.inputs]
        wrapper.writeline(
            f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        assert isinstance(self.layout, MutationLayout)
        return (self.layout.target.get_name(),)

    def __init__(self, x, *constant_args):
        super().__init__(
            None,
            MutationLayout(x),
            self.unwrap_storage([x]),
            constant_args,
        )
        self.name = V.graph.register_buffer(self)


class IndexPutFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation and indices properly
    """

    kernel = "aten.index_put_"

    def codegen(self, wrapper):
        (x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs]
        indices = []
        iter_valid_indices = iter(valid_indices)
        for i, _ in enumerate(self.indices):
            if self.indices[i] is not None:
                indices.append(next(iter_valid_indices))
            else:
                indices.append("None")
        wrapper.writeline(
            f"{self.kernel}({x}, [{','.join(indices)}], {values}, {repr(self.constant_args[0])})"
        )

    def should_allocate(self):
        return False

    def __init__(self, x, indices, values, accumulate):
        self.indices = indices
        valid_indices = [i for i in indices if i is not None]
        tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
        super().__init__(
            None,
            MutationLayout(x),
            self.unwrap_storage(tensors),
            [accumulate],
        )
        self.name = V.graph.register_buffer(self)


class DeviceCopy(ExternKernelOut):
    @classmethod
    def create(cls, x, device):
        if not x.is_extern() and all(
            (r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads()
        ):
            return x.constant_to_device(device)

        V.graph.device_types.add(device.type)
        V.graph.device_types.add(x.get_device().type)

        developer_warning("DeviceCopy in input program")
        return DeviceCopy(
            FlexibleLayout(
                device=device,
                dtype=x.get_dtype(),
                size=x.get_size(),
            ),
            [cls.realize_input(x)],
        )

    def codegen(self, wrapper):
        args = self.codegen_args()
        assert len(args) == 1
        if self.output_view:
            wrapper.writeline(
                f"{self.output_view.codegen_reference()}.copy_({args[0]})"
            )
        else:
            wrapper.writeline(f"{self.codegen_reference()}.copy_({args[0]})")


class DynamicScalar(IRNode):
    """
    The result of a call to aten._local_scalar_dense.

    This is not yet implemented.  The one model (so far) that calls this
    (fastNLP_Bert) does not actually use the result.  So we expect this
    node to get dead code eliminated.
    """

    def get_reads(self):
        return ()


@dataclasses.dataclass
class FallbackKernel(ExternKernelAlloc):
    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
    ):
        super().__init__(
            layout,
            tuple(tensor_args),
            tuple(nontensor_args),
        )
        if getattr(torch.ops.aten, kernel.__name__, None) is kernel:
            self.kernel = f"aten.{kernel.__name__}"
        else:
            self.kernel = (
                f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
            )
        self.unflatten_args = unflatten_args
        self.kwargs = {} if kwargs is None else kwargs
        V.graph.warn_fallback(self.kernel)

    def codegen_args(self):
        @dataclasses.dataclass
        class Shim:
            ref: Any

            def __repr__(self):
                return self.ref

        def gen_kwarg(k, v):
            return f"{k}={repr(v)}"

        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
        constant_args = [Shim(repr(x)) for x in self.constant_args]

        args, kwargs = self.unflatten_args(tensor_args, constant_args)
        return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()]

    @classmethod
    def create(cls, kernel, *args, **kwargs):
        fake_incorrect_kernels = (
            aten._fft_r2c.default,
            aten._fft_r2c.out,
            aten._fft_c2r.default,
            aten._fft_c2c.default,
            aten._fft_c2c.out,
            aten._linalg_svd.default,
            aten._linalg_svd.U,
            aten._fused_moving_avg_obs_fq_helper_functional,
        )
        context = (
            V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
        )
        with context:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, *args, **kwargs)

        assert tensor_args or isinstance(
            example_output, torch.Tensor
        ), "Not sure where to find device info"
        packed = FallbackKernel(
            MultiOutputLayout(
                tensor_args[0].get_device() if tensor_args else example_output.device
            ),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
            kwargs,
        )

        def generate_output(output, index=""):
            if isinstance(output, (list, tuple)):
                return type(output)(
                    generate_output(output[i], f"{index}[{i}]")
                    for i in range(len(output))
                )
            elif isinstance(output, torch.Tensor):
                return MultiOutput(
                    FixedLayout(
                        output.device,
                        output.dtype,
                        convert_shape_to_inductor(output.size()),
                        convert_shape_to_inductor(output.stride()),
                    ),
                    packed,
                    index,
                )
            elif isinstance(output, int):
                return output
            else:
                assert output is None, "FallbackKernel output type is not supported"
                return None

        return generate_output(example_output)

    def apply_constraint(self):
        return super().apply_constraint()


@dataclasses.dataclass
class MultiOutputLayout(IRNode):
    device: torch.device


class MultiOutput(ExternKernel):
    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
        )
        self.codegen_size_asserts(wrapper)

    def __init__(self, layout, input, index: str):
        super().__init__(None, layout, [input], ())
        self.name = V.graph.register_buffer(self)
        self.index = index

    def should_allocate(self):
        return False


class Convolution(ExternKernelAlloc):
    kernel = "aten.convolution"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        preferred_stride_order=None,
        kernel="aten.convolution",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel
        self.preferred_stride_order = preferred_stride_order

    def codegen(self, wrapper):
        if self.kernel.startswith("triton_ops."):
            wrapper.header.writeline("from torch._inductor import triton_ops")
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        stride_: List[int],
        padding_: List[int],
        dilation_: List[int],
        transposed: bool,
        output_padding_: List[int],
        groups: int,
    ):
        with V.graph.fake_mode:
            x_fake = ir_node_to_tensor(x, guard_shape=True)
            weight_fake = ir_node_to_tensor(weight, guard_shape=True)
            bias_fake = (
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
            )
            output = torch.ops.aten.convolution(
                x_fake,
                weight_fake,
                bias_fake,
                stride_,
                padding_,
                dilation_,
                transposed,
                output_padding_,
                groups,
            )
            req_stride_order = get_stride_order(output.stride())

        weight = cls.require_stride_order(weight, req_stride_order)
        x = cls.require_stride_order(x, req_stride_order)

        stride = tuple(stride_)
        padding = tuple(padding_)
        dilation = tuple(dilation_)
        assert isinstance(transposed, bool)
        output_padding = tuple(output_padding_)
        assert isinstance(groups, int)

        output_size = output.shape

        weight_shape = [
            sympy.Integer(V.graph.sizevars.guard_static_shape(s))
            for s in weight.get_size()
        ]
        _, _, *kernel_size = weight_shape

        # choose the runtime kernel
        config_conv = "aten"
        if (
            config_conv == "aten"
            or len(kernel_size) != 2  # triton conv only supports conv2d
            or not is_triton(x.get_device())
            or transposed
            or groups != 1
            # or x.get_dtype() == torch.float16
            # or x.get_dtype() == torch.bfloat16
        ):
            kernel = "aten.convolution"
        elif config_conv == "triton":
            kernel = "triton_ops.conv"
        else:
            assert config_conv == "autotune"
            from .codegen.autotuner import tuned_conv

            kernel = tuned_conv(
                x.get_size(),
                weight.get_size(),
                x.get_stride(),
                weight.get_stride(),
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
                x.get_device(),
                x.get_dtype(),
            )

        # for conv2d or conv3d, prefer the channels-last format
        transform_x_layout = False
        if kernel == "triton_ops.conv":
            output_layout_str = "torch.channels_last"
        else:
            output_layout_str = (
                "torch.contiguous_format"
                if output.is_contiguous()
                else "torch.channels_last"
            )

        if config.tune_layout and len(x.get_size()) == 4:
            from .codegen.autotuner import tuned_conv_layout

            faster_output_layout_str = tuned_conv_layout(
                kernel,
                x.get_size(),
                weight.get_size(),
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
                x.get_device(),
                x.get_dtype(),
            )
            if faster_output_layout_str != output_layout_str:
                output_layout_str = faster_output_layout_str
                transform_x_layout = True

        if output_layout_str == "torch.channels_last":
            stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
            if len(stride_order) < len(output_size):
                # add the batch dim if it exists
                stride_order = [len(stride_order)] + stride_order
            strides = make_channels_last_strides_for(output_size)
        else:
            stride_order = list(reversed(range(len(output_size))))
            strides = make_contiguous_strides_for(output_size)

        if transform_x_layout:
            x = cls.require_stride_order(x, stride_order)

        output_layout = FixedLayout(
            x.get_device(),
            x.get_dtype(),
            convert_shape_to_inductor(output_size),
            convert_shape_to_inductor(strides),
        )

        if bias is not None:
            return Convolution(
                output_layout,
                (x, weight, bias),
                (stride, padding, dilation, transposed, output_padding, groups),
                stride_order,
                kernel,
            )
        else:
            return Convolution(
                output_layout,
                (x, weight),
                (bias, stride, padding, dilation, transposed, output_padding, groups),
                stride_order,
                kernel,
            )
    def map_args(self):
        # x, w, bias
        in_args = [x.codegen_reference() for x in self.inputs]
        # stride, padding, dilation, transposed, output_padding, groups
        const_args = self.constant_args
        if len(in_args) < 3:
            # otherwise, bias=None is the first constant_args
            const_args = const_args[1:]

        inout_dict = OrderedDict(
            [
                ("x", f"{in_args[0]}"),
                ("w", f"{in_args[1]}"),
                ("y", f"{self.get_name()}"),
            ]
        )
        args_dict = OrderedDict(
            [
                ("stride_xn", f"{self.inputs[0].get_stride()[0]}"),
                ("stride_xc", f"{self.inputs[0].get_stride()[1]}"),
                ("stride_xh", f"{self.inputs[0].get_stride()[2]}"),
                ("stride_xw", f"{self.inputs[0].get_stride()[3]}"),
                ("stride_wn", f"{self.inputs[1].get_stride()[0]}"),
                ("stride_wc", f"{self.inputs[1].get_stride()[1]}"),
                ("stride_wh", f"{self.inputs[1].get_stride()[2]}"),
                ("stride_ww", f"{self.inputs[1].get_stride()[3]}"),
                ("stride_yn", f"{self.get_stride()[0]}"),
                ("stride_yc", f"{self.get_stride()[1]}"),
                ("stride_yh", f"{self.get_stride()[2]}"),
                ("stride_yw", f"{self.get_stride()[3]}"),
                (
                    # bias, when present, is inputs[2]
                    "stride_biasn",
                    f"{self.inputs[2].get_stride()[0]}"
                    if len(in_args) >= 3
                    else "None",
                ),
                # ("delta_x_ptr", "None"),
                ("BATCH", f"{self.inputs[0].get_size()[0]}"),
                ("IN_C", f"{self.inputs[0].get_size()[1]}"),
                ("IN_H", f"{self.inputs[0].get_size()[2]}"),
                ("IN_W", f"{self.inputs[0].get_size()[3]}"),
                ("KERNEL_N", f"{self.inputs[1].get_size()[0]}"),
                ("KERNEL_H", f"{self.inputs[1].get_size()[2]}"),
                ("KERNEL_W", f"{self.inputs[1].get_size()[3]}"),
                ("OUT_H", f"{self.get_size()[2]}"),
                ("OUT_W", f"{self.get_size()[3]}"),
                ("stride_h", f"{const_args[0][0]}"),
                ("stride_w", f"{const_args[0][1]}"),
                ("padding_h", f"{const_args[1][0]}"),
                ("padding_w", f"{const_args[1][1]}"),
                ("dilation_h", f"{const_args[2][0]}"),
                ("dilation_w", f"{const_args[2][1]}"),
                # ("transposed", f"{const_args[3]}"),
                ("output_padding_h", f"{const_args[4][0]}"),
                ("output_padding_w", f"{const_args[4][1]}"),
                ("groups", f"{const_args[5]}"),
            ]
        )

        # accumulator type
        ACC_TYPE = (
            "tl.float32"
            if self.inputs[0].get_dtype()
            in [torch.float16, torch.bfloat16, torch.float32]
            else "tl.int32"
        )
        CONV1X1_NHWC = (
            "True"
            if self.inputs[0].get_stride()[1] == 1
            and self.inputs[1].get_size()[2] == 1
            and self.inputs[1].get_size()[3] == 1
            else "False"
        )
        # dict for tl.constexpr
        const_dict = OrderedDict(
            [
                ("ACC_TYPE", ACC_TYPE),
                ("CONV1X1_NHWC", CONV1X1_NHWC),
            ]
        )
        # dict for non-kernel args (e.g. delta_x_ptr)
        other_dict = OrderedDict(
            [
                ("device", f'"{self.inputs[0].get_device()}"'),
            ]
        )

        return inout_dict, args_dict, const_dict, other_dict
    def get_template_tiling(self):
        n, c, h, w = self.get_size()
        return (
            n * h * w,
            c,
            sympy.Integer(1),
        )
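    # Illustrative example (not in the original source): for an output of size
    # [8, 64, 32, 32] the template tiling is (8 * 32 * 32, 64, 1) == (8192, 64, 1),
    # i.e. batch-times-spatial elements on one axis and channels on the other.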

def _prepare_convolution_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
    padding_: List[int],
    stride_: List[int],
    dilation_: List[int],
    groups: int,
    transposed: bool = False,
    output_padding_: Optional[List[int]] = None,
):
    """
    This helper prepares the inputs, layout, and constant args for a
    convolution post-op fusion's create function: it decides the output
    layout (channels first or channels last) and realizes the inputs.
    It only supports the CPU device, since the conv post-op fusion kernel
    is currently implemented only for CPU.
    """
    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
    def _conv_input_size(
        output_size, weight_size, padding, output_padding, stride, dilation, groups
    ):
        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
        dim = len(output_size)
        assert dim > 2, "Expect input dim > 2"

        BATCH_DIM = 0
        WEIGHT_INPUT_CHANNELS_DIM = 1
        input_size = []
        input_size.append(output_size[BATCH_DIM])
        input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
        for d in range(2, dim):
            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
            input_size_d = (
                (output_size[d] - 1) * stride[d - 2]
                - (padding[d - 2] * 2)
                + kernel
                + output_padding[d - 2]
            )
            input_size.append(input_size_d)
        return list(map(int, input_size))
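    # Worked example (illustrative, not in the original source): with kernel 3,
    # stride 2, padding 1, dilation 1, and output_padding 1, a spatial output
    # of 14 maps back to an input of
    #   (14 - 1) * 2 - 2 * 1 + ((3 - 1) * 1 + 1) + 1 == 28,
    # which is how the deconv output size is recovered below.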
    # The size of prepacked_weight is the prepacked weight size of deconv:
    #   Groups > 1:  [g*o, i/g, ...]
    #   Groups == 1: [o, i, ...]
    # Returns original weight size in [i, o, ...]
    def _original_deconv_weight_size(
        prepacked_weight,
        groups,
    ):
        prepacked_weight_size = prepacked_weight.size()
        dim = len(prepacked_weight_size)
        assert dim > 2, "Expect weight dim > 2"
        if groups > 1:
            weight_size = []
            weight_size.append(prepacked_weight_size[1] * groups)
            # integer division: sizes must stay ints
            weight_size.append(prepacked_weight_size[0] // groups)
            for d in range(2, dim):
                weight_size.append(prepacked_weight_size[d])
        else:
            weight_size = prepacked_weight.transpose(0, 1).size()
        return weight_size
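    # Illustrative example (not in the original source): with groups == 2 and a
    # prepacked weight of size [8, 3, 3, 3] (i.e. [g*o, i/g, kh, kw] with o=4,
    # i=6), this returns [3 * 2, 8 // 2, 3, 3] == [6, 4, 3, 3], the [i, o, kh, kw]
    # layout expected by _conv_input_size above.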
    stride = tuple(stride_)
    padding = tuple(padding_)
    dilation = tuple(dilation_)
    assert isinstance(groups, int)
    output_padding = tuple(output_padding_) if output_padding_ else (0, 0)

    with V.graph.fake_mode:
        x_fake = ir_node_to_tensor(x, guard_shape=True)
        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
        if transposed:
            # When transposed, the size of the prepacked oneDNN weight is different
            # from the PyTorch weight. We're not able to run aten conv with such
            # size. We infer the output size from the input params here:
            weight_size = _original_deconv_weight_size(weight_fake, groups)
            input_size = x_fake.size()
            output_size = _conv_input_size(
                input_size,
                weight_size,
                padding,
                output_padding,
                stride,
                dilation,
                groups,
            )
        else:
            bias_fake = (
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
            )
            output = torch.ops.aten.convolution(
                x_fake,
                weight_fake,
                bias_fake,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            output_size = output.size()

        req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
        req_stride_order = [len(req_stride_order)] + req_stride_order
        output_stride = make_channels_last_strides_for(output_size)

    x = cls.require_stride_order(x, req_stride_order)
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]

    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        convert_shape_to_inductor(output_size),
        convert_shape_to_inductor(output_stride),
    )
    constant_args = [padding, stride, dilation, groups]

    if transposed:
        constant_args.insert(1, output_padding)

    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout, req_stride_order
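
# Summary of the returned args above (editorial, inferred from the code):
#   conv,   bias given: inputs=[x, w, b], constant_args=[padding, stride, dilation, groups]
#   conv,   bias None:  inputs=[x, w],    constant_args=[None, padding, stride, dilation, groups]
#   deconv, bias given: inputs=[x, w, b], constant_args=[padding, output_padding, stride, dilation, groups]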

class ConvolutionUnary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._convolution_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        attr,
        scalars,
        algorithm,
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise"
        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        constant_args = constant_args + [attr, scalars, algorithm]
        return ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )
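
# Illustrative sketch (hypothetical buffer names, not output of the original
# source): for a fused conv + relu, codegen above emits a wrapper line shaped like
#   buf1 = torch.ops.mkldnn._convolution_pointwise(
#       buf0, arg1_1, arg2_1, (1, 1), (1, 1), (1, 1), 1, 'relu', [], '')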

class ConvolutionBinary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._convolution_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise.binary",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List],
        unary_algorithm: Optional[str],
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ]
        return ConvolutionBinary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

class ConvolutionBinaryInplace(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"

    def __init__(
        self,
        kernel_layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
    ):
        super().__init__(kernel_layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    def get_mutation_names(self):
        assert isinstance(self.layout, MutationLayout)
        return (self.layout.target.get_name(),)
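    # Editorial note: `create` below wraps the result in MutationLayout(inputs[1]),
    # i.e. the in-place op writes its result into `other` rather than allocating
    # a fresh buffer, which is why users of `other` are realized first via
    # V.graph.realize_users_of.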
    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List],
        unary_algorithm: Optional[str],
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
        (inputs, constant_args, _, _) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        other = cls.realize_input(other)
        V.graph.realize_users_of(other.get_name())
        inputs.insert(1, other)
        constant_args = constant_args + [
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ]
        return ConvolutionBinaryInplace(
            kernel_layout=MutationLayout(inputs[1]),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

class MKLPackedLinear(ExternKernelAlloc):
    kernel = "torch.ops.mkl._mkl_linear"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkl._mkl_linear",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    @classmethod
    def create(cls, x, packed_w, orig_w, batch_size):
        kernel = "torch.ops.mkl._mkl_linear"
        x = cls.require_stride1(cls.realize_input(x))
        orig_w = cls.require_stride1(cls.realize_input(orig_w))
        *m, _ = x.get_size()
        oc, _ = orig_w.get_size()
        output_size = list(m) + [oc]
        output_stride = make_contiguous_strides_for(output_size)
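        # e.g. (illustrative) x of size [128, 64] and orig_w of size [32, 64]
        # give output_size [128, 32] with contiguous strides [32, 1]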
        inputs = [x, packed_w, orig_w]
        bias = None
        constant_args = [bias, batch_size]

        return MKLPackedLinear(
            layout=FixedLayout(
                x.get_device(), x.get_dtype(), output_size, output_stride
            ),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

class LinearUnary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._linear_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._linear_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    @classmethod
    def create(cls, x, w, b, attr, scalars, algorithm):
        kernel = "torch.ops.mkldnn._linear_pointwise"
        x = cls.require_stride1(cls.realize_input(x))
        w = cls.require_stride1(cls.realize_input(w))

        *m, ic = x.get_size()
        oc, ic = w.get_size()

        inputs = [x, w]
        constant_args = [attr, scalars, algorithm]
        if b is not None:
            b = cls.require_stride1(cls.realize_input(b))
            inputs.append(b)
        else:
            constant_args.insert(0, b)

        return LinearUnary(
            layout=FlexibleLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=list(m) + [oc],
            ),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

    def apply_constraint(self):
        pass

class LinearBinary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._linear_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._linear_pointwise.binary",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    @classmethod
    def create(cls, x, y, w, b, attr):
        kernel = "torch.ops.mkldnn._linear_pointwise.binary"
        x = cls.require_stride1(cls.realize_input(x))
        y = cls.require_stride1(cls.realize_input(y))
        w = cls.require_stride1(cls.realize_input(w))

        *m, ic = x.get_size()
        oc, ic = w.get_size()

        inputs = [x, y, w]
        constant_args = [attr]
        if b is not None:
            b = cls.require_stride1(cls.realize_input(b))
            inputs.append(b)
        else:
            constant_args.insert(0, b)

        return LinearBinary(
            layout=FlexibleLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=list(m) + [oc],
            ),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

    def apply_constraint(self):
        pass

class ConvolutionTransposeUnary(ExternKernelAlloc):
    kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        output_padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups_: int,
        attr,
        scalars,
        algorithm,
    ):
        kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
        transposed = True
        (
            inputs,
            constant_args,
            kernel_layout,
            _,
        ) = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,
            bias,
            padding_,
            stride_,
            dilation_,
            groups_,
            transposed,
            output_padding_,
        )
        constant_args = constant_args + [attr, scalars, algorithm]
        return ConvolutionTransposeUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

@dataclasses.dataclass
class MutableBox(IRNode):
    """
    TensorBox / StorageBox allow in-place mutation of Tensors
    """

    data: IRNode

    def __getattr__(self, name):
        fn = getattr(self.data, name)
        if callable(fn):
            return fn
        raise AttributeError(f"{type(self.data).__name__}.{name} not callable")

    def __str__(self):
        if isinstance(self.data, MutableBox):
            line0 = f"{type(self).__name__}({type(self.data).__name__}("
            endl = "))"
            inner = self.data.data
        else:
            line0 = f"{type(self).__name__}("
            inner = self.data
            endl = ")"

        lines = [
            line0,
            indent(str(inner)),
            endl,
        ]
        return "\n".join(lines)

    __repr__ = __str__
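    # Illustrative __str__ output (not in the original source) for a TensorBox
    # wrapping a StorageBox wrapping a Pointwise:
    #   TensorBox(StorageBox(
    #     Pointwise(...)
    #   ))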

class TensorBox(MutableBox):
    @staticmethod
    def create(data):
        return TensorBox(StorageBox(data))

class StorageBox(MutableBox):
    def is_input_buffer(self):
        if isinstance(self.data, (InputBuffer, ReinterpretView)):
            return self.data.get_name() in V.graph.graph_inputs
        return False

    def realize(self):
        if isinstance(
            self.data,
            (
                ComputedBuffer,
                InputsKernel,
                InputBuffer,
                ReinterpretView,
                TemplateBuffer,
            ),
        ):
            return self.data.get_name()
        assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
        self.data = ComputedBuffer(
            name=None,
            layout=FlexibleLayout(
                device=self.data.get_device(),
                dtype=self.data.get_dtype(),
                size=self.data.get_size(),
            ),
            data=self.data,
        )
        self.data.name = V.graph.register_buffer(self.data)
        self.data.origins = self.origins
        return self.data.name
    def realize_hint(self):
        """
        Called on buffers we expect to be forced to realize later.
        """
        if isinstance(self.data, (Pointwise, Reduction)) and self.num_reads() > 1:
            self.realize()

    def has_exceeded_max_reads(self):
        return isinstance(self.data, Pointwise) and (
            self.num_reads() > config.realize_acc_reads_threshold
            or len(self.inner_fn_str()) > config.realize_bytes_threshold
        )
    def mark_reuse(self, users):
        """
        A heuristic to decide if we should realize a tensor
        that is used multiple times.
        """

        def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
            """
            Heuristic for realizing the reused result of heavy ops on CPU.
            """
            heavy_ops = ["exp"]  # a list of heavy ops
            fn_str = loops.inner_fn_str()
            return any((op + "(") in fn_str for op in heavy_ops)

        if (
            users > 1
            and isinstance(self.data, (Pointwise, Reduction))
            and (
                self.num_reads() > config.realize_reads_threshold
                or len(self.inner_fn_str()) > config.realize_bytes_threshold
                or (is_cpu(self.data) and should_realize_on_cpu(self.data))
            )
        ):
            self.realize()
    @cache_on_self
    def num_reads(self):
        data = self.data
        if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
            return 1
        if isinstance(data, ComputedBuffer):
            read_writes = data.get_read_writes()
        else:
            assert isinstance(data, (Pointwise, Reduction)), type(data)
            read_writes = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=data.get_device(),
                    dtype=data.get_dtype(),
                    size=data.get_size(),
                ),
                data=data,
            ).get_read_writes()
        return len(read_writes.reads)

class InterpreterShim(torch.fx.Interpreter):
    def __init__(self, graph, submodules):
        """
        We don't call super() here to avoid constructing a
        GraphModule which is very expensive (it does codegen).
        """
        self.module = self
        self.graph = graph
        self.submodules = submodules
        self.garbage_collect_values = False
        self.env = {}
        self.fetch_attr = submodules.__getitem__
        self.name = "InterpreterShim"
        self.current_node = None

    def run_node(self, n: torch.fx.Node) -> Any:
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)

class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph.  Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    def __init__(self, fn, args, var_ranges):
        super().__init__()
        self.var_ranges = var_ranges
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.reads = []
        self.writes = []
        self.reads_name2expr = {}
        self.writes_name2expr = {}
        self.other = []
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.root_block = LoopBodyBlock(self, fn, args)
        self.indexing = None
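    # Usage sketch (illustrative, hypothetical names): given a body function
    # `fn(index)` that issues ops.load/ops.store calls, capture it once and
    # re-evaluate it at concrete indices:
    #   body = LoopBody(fn, [[i0, i1]], {i0: 128, i1: 64})
    #   result = body([x0, x1])  # substitutes x0/x1 into the captured indexing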
    def debug_str(self):
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
        getattr(self, category).append(expr)
        if buf_name is not None:
            getattr(self, f"{category}_name2expr")[buf_name] = expr
        if expr not in self.indexing_exprs_name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        return self.indexing_exprs_name[expr]

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self):
        name = f"indirect{len(self.indirect_vars)}"
        var = sympy_symbol(name)
        self.indirect_vars.append(var)
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        return self.indexing[name]

    def __call__(self, *indices):
        index = list(itertools.chain(*indices))
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        assert all(v not in self.var_ranges for v in index)
        replacements = dict(zip(self.var_ranges.keys(), index))
        self.indexing = {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }
        result = self.root_block()
        self.indexing = None
        return result

class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable, args: List[Any]):
        self.body = body

        def add_index(expr, category, buf_name=None):
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (self.body.add_index_expr(expr, category, buf_name),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):
            self.name = "CaptureIndexing"
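            # Editorial note: the line above runs at class-body time, where
            # `self` resolves to the enclosing LoopBodyBlock being constructed,
            # so it tags that instance with the name "CaptureIndexing".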
            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, "reads", name)
                return self._inner.load(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(index, "writes", name)
                return self._inner.store(name, index, value, mode)

            def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
                index = add_index(index, "writes", name)
                return self._inner.reduction(
                    name, dtype, src_dtype, reduction_type, index, value
                )

            def index_expr(self, index, dtype):
                if isinstance(index, (int, sympy.Integer)):
                    return ops.constant(int(index), dtype)
                index = add_index(index, "other")
                return self._inner.index_expr(index, dtype)

            @staticmethod
            def masked(mask_proxy, masked_body: Callable, other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """

                def shim(mask, other):
                    return V.ops.masked(mask, subblock, other)

                name = self.body.add_submodule(shim, "masked_subblock")
                subblock = LoopBodyBlock(self.body, masked_body, [])
                self.body.subblocks[name] = subblock
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def indirect_indexing(index_proxy):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                def set_indirect(new_var):
                    self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))

                var = self.body.add_indirect()
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
        from .sizevars import SimplifyIndexing

        with V.set_ops_handler(
            SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
        ):
            tracer.create_proxy("output", "output", (fn(*args),), {})
        self.graph = tracer.graph

    def __call__(self):
        graph = self.graph
        submodules = self.body.submodules
        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )
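
# Illustrative debug_str output (hypothetical, not from the original source),
# for a body that loads one buffer and adds a constant:
#   def body(self, ops):
#       get_index = self.get_index('index0')
#       load = ops.load('buf0', get_index)
#       constant = ops.constant(1.0, torch.float32)
#       add = ops.add(load, constant)
#       return add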