12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190 |
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import re
- import textwrap
- from collections import OrderedDict
- from contextlib import nullcontext
- from enum import Enum
- from functools import partial
- from inspect import signature
- from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
- from unittest.mock import patch
- import sympy
- from sympy import Expr, Integer
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._prims_common import (
- is_boolean_dtype,
- is_float_dtype,
- make_channels_last_strides_for,
- make_contiguous_strides_for,
- )
- from torch.fx.experimental.symbolic_shapes import FloorDiv
- from . import config, dependencies
- from .codegen.common import index_prevent_reordering
- from .cuda_properties import get_device_properties
- from .dependencies import extract_read_writes, var_builder
- from .utils import (
- argsort,
- cache_on_self,
- convert_shape_to_inductor,
- convert_shape_to_symint,
- developer_warning,
- sympy_dot,
- sympy_product,
- sympy_subs,
- sympy_symbol,
- )
- from .virtualized import ops, V
- log = logging.getLogger(__name__)
- indent = functools.partial(textwrap.indent, prefix=" ")
- aten = torch.ops.aten
- """ [Note: Inductor IR]
- Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
- lowering is registered to a particular aten operator, and expects inputs that
- correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
- expect Inductor TensorBox inputs.
- TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
- storage, and sometimes views of another Tensor's storage. Mutating tensor operations
- (such as add_()) affect the underlying storage and any associated views. Other operations
- (such as .t_()) update metadata about the current view but don't modify the underlying storage.
- To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
- TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
- output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
- reference View IR or directly reference StorageBox IRs.
- Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
- may take an existing TensorBox and point it to a new underlying View IR.
- Tensors that directly own storage are represented as a chain of:
- TensorBox -> StorageBox -> Buffer
- where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
- If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
- (leaving the old buffer unmodified and functionalizing the operation).
- Tensors backed by views add one more indirection to the IR.
- TensorBox -> View -> StorageBox -> Buffer
- In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
- For metadata mutation (e.g. as_strided_) we swing the TensorBox pointer.
- """
def validate_ir(node_or_nodes):
    """Assert that ``node_or_nodes`` (one node, or a list/tuple of nodes)
    contains only supported top-level IR node types.

    See [Note: Inductor IR] for what counts as a valid top-level node.
    Raises AssertionError naming the offending type otherwise.
    """

    def _check_tensorbox(node):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        assert isinstance(
            node,
            (
                TensorBox,
                RandSeedBuffer,
                torch.fx.experimental.symbolic_shapes.Symbol,
                Expr,
            ),
        ), f"Found {type(node)}, which is not a supported top level IR node. See [Note: Inductor IR]"

    # Be picky about the accepted data structure (don't use pytree here).
    # Use builtin list/tuple here: typing.List/typing.Tuple as isinstance
    # targets are deprecated and only work by accident.
    if isinstance(node_or_nodes, (list, tuple)):
        for node in node_or_nodes:
            _check_tensorbox(node)
    else:
        _check_tensorbox(node_or_nodes)
def inverse_reorder(order):
    """Return a reindex function that applies the inverse of the
    permutation ``order`` to an index list."""
    position_of = {dim: pos for pos, dim in enumerate(order)}

    def reindex(index):
        assert len(index) == len(position_of)
        return [index[position_of[i]] for i in range(len(index))]

    return reindex
def same_reorder(order):
    """Return a reindex function that permutes an index list by ``order``
    (element ``i`` of the output is ``index[order[i]]``)."""

    def reindex(index):
        assert len(index) == len(order)
        return [index[dim] for dim in order]

    return reindex
def fuse_reindexing(reindex1, reindex2):
    """Compose two reindexing functions: the result applies ``reindex2``
    first, then ``reindex1``."""

    def composed(index):
        return reindex1(reindex2(index))

    return composed
def stride_order2fill_order(order):
    """Convert a stride order into a fill order.

    For the channels-last format, stride order [3, 0, 2, 1] maps to
    fill order [1, 3, 2, 0].
    """
    dim_of_rank = {rank: dim for dim, rank in enumerate(order)}
    return [dim_of_rank[rank] for rank in range(len(order))]
def get_stride_order(seq):
    """
    Convert strides to stride order
    """
    # argsort gives the positions of the strides in increasing order; invert
    # that permutation so each dim is labeled with its rank.
    order = [None] * len(seq)
    for rank, pos in enumerate(argsort(seq)):
        order[pos] = rank
    return order
def reads_from_conv(buf, var_ranges):
    """
    return:
        if reads_from_conv: boolean
        the new memory_addr: Sympy Expression
    """
    if buf is None:
        return False, None
    if isinstance(buf, Convolution):
        # Direct hit: compute the address expression using the conv's fixed
        # layout, indexed by var_ranges' variables sorted by name.
        indexer = buf.layout.as_fixed().make_indexer()
        index_vars = sorted(var_ranges, key=lambda var: var.name)
        index = indexer(index_vars)
        return True, index
    # for case like
    # buf0 = conv(x, w)
    # return torch.cat([buf0, buf1]), torch.cat([buf0, buf2])
    # Because of ConcatKernel, it will create two bufs buf3 and 4
    # buf3 has the AliasedLayout which reads from buf0(Convolution)
    # but buf4 is a copy of buf3 which reads from buf3
    # we want to know that buf4 also follows buf0 conv's layout
    if isinstance(buf.layout, AliasedLayout):
        reads = buf.get_read_writes().reads
        # Resolve each read name to its buffer; None for reads that are not
        # graph buffers (e.g. graph inputs).
        reads_bufs = [
            V.graph.name_to_buffer[r.name]
            if r.name in V.graph.name_to_buffer.keys()
            else None
            for r in reads
        ]
        # Recurse through the alias chain: the first transitive read of a
        # Convolution determines the layout/address.
        for reads_buf in reads_bufs:
            read_from_conv, addr = reads_from_conv(reads_buf, var_ranges)
            if read_from_conv:
                return True, addr
    return False, None
def ir_node_to_tensor(x, guard_shape=True):
    """Create a zero-initialized ``torch.Tensor`` mirroring the IR node
    ``x`` (same size, stride, dtype and device).

    With ``guard_shape=False``, symbolic dimensions are replaced by their
    concrete size hints instead of being kept symbolic.
    """
    if guard_shape:

        def shape_fn(s):
            # Identity: keep symbolic shapes untouched.
            return s

    else:
        shape_fn = V.graph.sizevars.size_hint

    size = [shape_fn(s) for s in x.get_size()]
    if is_storage_and_layout(x):
        # The node carries an explicit layout: preserve its strides.
        stride = [shape_fn(s) for s in x.get_layout().stride]
    else:
        # No layout available: assume a contiguous tensor.
        stride = make_contiguous_strides_for(size)

    return torch.empty_strided(
        size=convert_shape_to_symint(size),
        stride=convert_shape_to_symint(stride),
        dtype=x.get_dtype(),
        device=x.get_device(),
    ).zero_()
def layout_priority_idx(reads_bufs, memory_addrs, var_ranges):
    """
    if reads from conv that needs to use specific layout
    return:
        priority_idx regarding memory_addrs idx
        memory_addrs - update memory_addrs with the true addr if needed
    """
    priority_idx = []
    for idx, candidate in enumerate(reads_bufs):
        from_conv, true_addr = reads_from_conv(candidate, var_ranges)
        if from_conv:
            # This read follows a convolution's layout: record its position
            # and substitute the conv-layout address in place.
            priority_idx.append(idx)
            memory_addrs[idx] = true_addr
    return priority_idx, memory_addrs
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        # sympy auto-evaluation hook: returning None keeps the expression
        # symbolic; returning a value simplifies it.
        if base == 0 or modulus == 1:
            # (0 // b) % c == 0, and (anything) % 1 == 0.
            return sympy.Integer(0)
        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            # All-constant: fold to a concrete integer.
            return (base // divisor) % modulus
        if divisor != 1:
            # Cancel a common factor g: ((a*g) // (b*g)) % c == (a // b) % c.
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return ModularIndexing(base / gcd, divisor / gcd, modulus)
        if isinstance(base, sympy.Add):
            # Drop addends that are multiples of modulus*divisor, since they
            # vanish under the outer "% modulus" after the division.
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)
            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)
        if isinstance(base, FloorDiv):
            # ((a // b) // d) % c == (a // (b*d)) % c for positive operands.
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.
    """

    # Behaves exactly like FloorDiv; the subclass only records the
    # exact-division guarantee (see CeilDiv.__new__).
    pass
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        # No CeilDiv instance is ever created: when divisor exactly divides
        # base, ceil == floor and the division is exact (CleanDiv); otherwise
        # use the identity ceil(a / b) == floor((a + b - 1) / b).
        if sympy.gcd(base, divisor) == divisor:
            return CleanDiv(base, divisor)
        else:
            return FloorDiv(base + (divisor - 1), divisor)
def get_device_type(x):
    """Return the device type string (e.g. "cuda", "cpu") of ``x``.

    ``x`` may be an IR node exposing ``get_device()`` or a ``torch.device``;
    anything else yields ``None``.
    """
    getter = getattr(x, "get_device", None)
    if getter:
        return get_device_type(getter())
    return x.type if isinstance(x, torch.device) else None
def is_triton(x):
    """Whether ``x`` lives on a CUDA device (handled by the Triton backend)."""
    device_type = get_device_type(x)
    return device_type == "cuda"
def is_cpu(x):
    """Whether ``x`` lives on the CPU."""
    device_type = get_device_type(x)
    return device_type == "cpu"
@dataclasses.dataclass
class IRNode:
    """Base class for all Inductor IR nodes.  See [Note: Inductor IR]."""

    # FX nodes currently being lowered; snapshotted into ``self.origins``
    # on every node constructed inside a ``current_origins`` context.
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        """Temporarily extend the origin set attributed to newly created nodes."""
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        # try/finally so the previous origin set is restored even when the
        # body raises; otherwise origins leak into unrelated lowerings.
        try:
            yield
        finally:
            IRNode._current_origins = old

    def __post_init__(self):
        # Snapshot the origins active at construction time.
        self.origins = set(self._current_origins)

    def common_repr(self):
        # Repr fragment shared by all nodes (appended by str_helper).
        return (
            [f"origins={self.origins}"] if hasattr(self, "origins") else ["no origins?"]
        )

    def str_helper(self, lines):
        # Render "TypeName(\n    line,\n    line\n)" for __str__ implementations.
        lines = lines + self.common_repr()
        lines = indent(",\n".join(map(str, lines)))
        return f"{type(self).__name__}(\n{lines}\n)"

    def is_user_of(self, name):
        # True if this node reads the buffer called ``name``.
        return any(name == dep.name for dep in self.get_reads())

    def get_numel(self):
        # Total element count: the product of this node's sizes.
        return sympy_product(self.get_size())
@dataclasses.dataclass
class Loops(IRNode):
    """Base class for IR nodes defined as a loop nest over ``ranges`` that
    evaluates ``inner_fn`` at each index."""

    device: torch.device
    dtype: torch.dtype
    # Maps a list of index expressions to an ops.* value for one element.
    inner_fn: Callable
    # Extent of each (pointwise) loop dimension.
    ranges: List[Expr]

    def __str__(self, names=("ranges",)):
        return self.str_helper(
            [
                f"'{self.device.type}'",
                str(self.dtype),
                self.inner_fn_str(),
            ]
            + [f"{name}={getattr(self, name)}" for name in names]
        )

    __repr__ = __str__

    def get_dtype(self):
        return self.dtype

    def get_device(self):
        return self.device

    def get_size(self):
        return self.ranges

    def is_extern(self):
        return False

    @classmethod
    def create(cls, *args, **kwargs):
        # Lowerings should produce TensorBox-wrapped nodes; see [Note: Inductor IR].
        return TensorBox.create(cls(*args, **kwargs))

    @staticmethod
    def _index(ranges, prefix="i"):
        # Symbolic index variables ("i0", "i1", ...); dims of extent 1 get a
        # constant 0 index instead of a variable.
        return [
            sympy.Integer(0) if s == 1 else sympy_symbol(f"{prefix}{n}")
            for n, s in enumerate(ranges)
        ]

    @cache_on_self
    def inner_fn_str(self):
        # Render inner_fn symbolically for debug printing (__str__ above).
        formatter = V.KernelFormatterHandler(V.MockHandler())
        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = self.inner_fn(self._index(self.ranges))
            return formatter.getvalue(result)

    def is_zero_elements(self):
        # True when any dimension has extent 0 (an empty tensor).
        return any(r == 0 for r in self.ranges)

    @cache_on_self
    def get_reads(self):
        # Collect the read dependencies by tracing the loader; reductions
        # also pass their reduction index space.
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.get_reduction_type():
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                    self.get_reduction_size(),
                ).reads
            else:
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                ).reads
class Pointwise(Loops):
    """An elementwise loop nest: no reduction dimensions."""

    def make_loader(self):
        # inner_fn already computes one output element per index.
        return self.inner_fn

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        return ops.store(output_name, indexer(vars), self.inner_fn(vars))

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        device_loader = patch.object(ConstantBuffer, "override_device", device)(
            self.make_loader()
        )
        return Pointwise(device, self.dtype, device_loader, self.ranges)
@dataclasses.dataclass
class Scatter(Pointwise):
    """Pointwise computation whose stores are routed through
    ``output_indexer`` (optionally with a non-default store mode)."""

    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        device_loader = patch.object(ConstantBuffer, "override_device", device)(
            self.make_loader()
        )
        return Scatter(
            device,
            self.dtype,
            device_loader,
            self.ranges,
            self.output_indexer,
            self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),
            self.inner_fn(vars),
            mode=self.scatter_mode,
        )
class ReductionHint(Enum):
    # Scheduling hint attached to reductions (Reduction.reduction_hint);
    # presumably distinguishes inner vs. outer reduction layouts for codegen
    # heuristics (cf. inner_reduction_splits/outer_reduction_splits) — the
    # consumers are outside this file.
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3
class TileHint(Enum):
    # Tiling hint for codegen; consumers are outside this file.
    SQUARE = 0
    DEFAULT = 1
@dataclasses.dataclass
class Reduction(Loops):
    # Iteration space that is reduced over (in addition to Loops.ranges,
    # which describes the output/pointwise dims).
    reduction_ranges: List[Expr]
    # Reduction kind, e.g. "sum"/"max" — interpreted by ops.reduction.
    reduction_type: str
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    # Scheduling hint, see ReductionHint.
    reduction_hint: ReductionHint

    def __str__(self):
        return Loops.__str__(
            self, names=("ranges", "reduction_ranges", "reduction_type")
        )

    __repr__ = __str__

    def get_reduction_size(self):
        return self.reduction_ranges

    def get_reduction_type(self):
        return self.reduction_type

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        # Reduce inner_fn(vars, reduction_vars) over reduction_vars and
        # store the result at indexer(vars).
        return ops.reduction(
            output_name,
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            indexer(vars),
            self.inner_fn(vars, reduction_vars),
        )

    def index_length(self):
        # Total number of index variables: pointwise dims + reduction dims.
        return len(self.ranges) + len(self.reduction_ranges)

    @cache_on_self
    def inner_fn_str(self):
        # Render inner_fn symbolically ("i*" for pointwise dims, "r*" for
        # reduction dims) for debug printing.
        formatter = V.KernelFormatterHandler(V.MockHandler())
        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = self.inner_fn(
                self._index(self.ranges),
                self._index(self.reduction_ranges, "r"),
            )
            return formatter.getvalue(result)

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Reduction(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )
- @staticmethod
- def num_splits(
- device,
- dst_dtype,
- src_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- reduction_type,
- reduction_numel,
- ):
- num_sm = get_device_properties(device).multi_processor_count
- min_elements_per_thread = 32
- max_elements_per_thread = 512
- threads_per_sm = 2048
- min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
- max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
- def inner_reduction_splits(reduction_numel_hint, numel_hint):
- # do heuristics that's close to eager mode for split inner reduction
- # we leak reduction autotune configs here, and will need to refactor to avoid this later
- num_warps = 8
- num_threads = 32 * num_warps
- if numel_hint >= 2 * num_sm: # don't split if there are enough outputs
- return 1
- if reduction_numel_hint <= 8192:
- return 1
- if reduction_numel_hint * numel_hint <= min_elements_per_device:
- split_size = min_elements_per_thread
- elif reduction_numel_hint * numel_hint < max_elements_per_device:
- target_blocks = num_sm * threads_per_sm // (2 * num_threads)
- blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
- tmp_split_size = (
- reduction_numel_hint + num_threads * blocks_per_output - 1
- ) // (num_threads * blocks_per_output)
- divisors = sympy.divisors(reduction_numel_hint)
- closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
- if abs(closest - tmp_split_size) < 30:
- # prefer even splits, but never smalle than min_elements_per_thread
- split_size = max(closest, min_elements_per_thread)
- else:
- split_size = tmp_split_size
- else:
- divisors = sympy.divisors(reduction_numel_hint)
- closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
- if abs(closest - max_elements_per_thread) < 50:
- # prefer even splits
- split_size = closest
- else:
- split_size = max_elements_per_thread
- return (reduction_numel_hint + split_size * num_threads - 1) // (
- split_size * num_threads
- )
- def outer_reduction_splits(reduction_numel_hint, numel_hint):
- # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
- # extend to even smaller number of outputs
- num_warps = 8
- num_threads = num_warps * 32
- rvals_per_thread = 4 # comes from heuristics, refactor to not leak here
- xvals_per_block = 128
- xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
- if reduction_numel_hint * numel_hint < min_elements_per_device:
- split_size = min_elements_per_thread
- elif reduction_numel_hint * numel_hint < max_elements_per_device:
- target_blocks = num_sm * threads_per_sm // (num_threads)
- target_blocks = (target_blocks + xblocks - 1) // xblocks
- tmp_split_size = (
- reduction_numel_hint + rvals_per_thread * target_blocks - 1
- ) // (rvals_per_thread * target_blocks)
- divisors = sympy.divisors(reduction_numel_hint)
- closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
- if abs(tmp_split_size - closest) < 20:
- split_size = max(closest, min_elements_per_thread)
- else:
- split_size = tmp_split_size
- else:
- divisors = sympy.divisors(reduction_numel_hint)
- closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
- if abs(closest - max_elements_per_thread) < 50:
- # prefer even splits
- split_size = closest
- else:
- split_size = max_elements_per_thread
- return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
- rvals_per_thread * split_size
- )
- reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
- numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
- # easy cases
- if numel_hint == 1:
- return ReductionHint.INNER, inner_reduction_splits(
- reduction_numel_hint, numel_hint
- )
- if (
- reduction_numel_hint <= min_elements_per_thread
- or numel_hint >= num_sm * 2 * 32
- ):
- return ReductionHint.DEFAULT, 1
- r = Reduction(
- device,
- dst_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- reduction_type,
- src_dtype,
- ReductionHint.DEFAULT,
- )
- def get_read_indices(r):
- cb = ComputedBuffer(
- name=None,
- layout=FlexibleLayout(
- device=r.get_device(),
- dtype=r.get_dtype(),
- size=r.get_size(),
- ),
- data=r,
- )
- read_writes = cb.get_read_writes()
- # try finding the full size producer
- # TODO this will fail for something like ((1, N) * (N, 1)).sum()
- # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
- range_vars = [
- r
- for r in read_writes.range_vars
- if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
- ]
- indices = []
- changed = False
- for md in sorted(read_writes.reads, key=lambda x: x.name):
- if all([r in md.index.free_symbols for r in range_vars]):
- indices.append(md.index)
- if md.name in V.graph.name_to_buffer:
- buf = V.graph.name_to_buffer[md.name]
- original_stride = buf.layout.stride
- buf.decide_layout()
- if buf.layout.stride != original_stride:
- changed = True
- return indices, changed
- indices, changed = get_read_indices(r)
- if changed:
- indices, _ = get_read_indices(r)
- if len(indices) == 0:
- # TODO determine splits when all inputs are broadcast
- return ReductionHint.DEFAULT, 1
- _, (_, reduction_vars), _ = dependencies.index_vars_squeeze(
- r.get_size(), r.get_reduction_size()
- )
- num_outer = 0
- num_inner = 0
- for i in indices:
- strides = V.graph.sizevars.stride_hints(i, reduction_vars)
- outer = all([s > 1 for s in strides])
- if outer:
- num_outer += 1
- else:
- num_inner += 1
- if num_inner > num_outer:
- return ReductionHint.INNER, inner_reduction_splits(
- reduction_numel_hint, numel_hint
- )
- else:
- return ReductionHint.OUTER, outer_reduction_splits(
- reduction_numel_hint, numel_hint
- )
    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
        """Convert inner_fn from a reduction to a pointwise op that
        evaluates the entire (statically sized) reduction inline."""
        # Unrolling requires concrete sizes; guard makes them static ints.
        reduction_ranges = [
            V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
        ]

        # Select the binary combine op for this reduction type.
        if reduction_type == "sum":

            def combine_fn(a, b):
                return ops.add(a, b)

        elif reduction_type == "min":

            def combine_fn(a, b):
                return ops.minimum(a, b)

        elif reduction_type == "max":

            def combine_fn(a, b):
                return ops.maximum(a, b)

        elif reduction_type == "any":

            def combine_fn(a, b):
                return ops.logical_or(a, b)

        elif reduction_type == "argmin":

            # a/b are (value, index) pairs; keep the index of the smaller value.
            def combine_fn(a, b):
                return ops.minimum(a[0], b[0]), ops.where(
                    ops.lt(b[0], a[0]), b[1], a[1]
                )

        elif reduction_type == "argmax":

            # a/b are (value, index) pairs; keep the index of the larger value.
            def combine_fn(a, b):
                return ops.maximum(a[0], b[0]), ops.where(
                    ops.gt(b[0], a[0]), b[1], a[1]
                )

        else:
            raise NotImplementedError(f"unknown reduction_type={reduction_type}")

        def fn(index):
            # Fold combine_fn over every point of the reduction iteration space.
            # NOTE: value_fn is bound below, before fn can be called.
            return functools.reduce(
                combine_fn,
                (
                    value_fn(index, rindex)
                    for rindex in itertools.product(
                        *[range(x) for x in reduction_ranges]
                    )
                ),
            )

        if reduction_type in ("argmin", "argmax"):
            # Carry (value, flattened_index) pairs; the final result is the index.
            flatten_index = FixedLayout(
                None,
                None,
                reduction_ranges,
                FlexibleLayout.contiguous_strides(reduction_ranges),
            ).make_indexer()

            def value_fn(index, rindex):
                rindex = [sympy.expand(i) for i in rindex]
                return (
                    inner_fn(index, rindex),
                    ops.index_expr(flatten_index(rindex), torch.int64),
                )

            return lambda index: fn(index)[1]
        else:
            value_fn = inner_fn
            return fn
    @classmethod
    def create(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        """Build a reduction node, lowering degenerate cases to Pointwise
        and splitting large reductions into multiple layers where profitable.
        """
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # Reducing over zero elements: emit the reduction identity constant.
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits.keys()
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            return Pointwise.create(
                device=device,
                dtype=src_dtype,
                inner_fn=const_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                # The only index over a single element is 0.
                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            # Small static reduction: unroll it into a pointwise op.
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
                ranges,
            )

        if is_triton(device) and reduction_type not in {"argmax", "argmin"}:
            # triton doesn't support reduce to single element well, so break it up
            hint, split = cls.num_splits(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                reduction_numel,
            )
            # intermediate reduction in split can contain complex indexing,
            # and num_splits will fail to correctly set the hint
            # reuse the passed hint if available
            if reduction_hint == ReductionHint.DEFAULT:
                reduction_hint = hint
            if split > 1:
                # triton doesn't support reduce to single element well, so break it up
                return cls.create_multilayer(
                    device,
                    dst_dtype,
                    src_dtype,
                    inner_fn,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    split,
                    reduction_hint,
                )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
- @staticmethod
- def default_value(reduction_type, dtype):
- if reduction_type in {"max", "argmax"}:
- if is_float_dtype(dtype):
- return float("-inf")
- elif is_boolean_dtype(dtype):
- return 0
- else:
- return torch.iinfo(dtype).min
- if reduction_type in {"min", "argmin"}:
- if is_float_dtype(dtype):
- return float("inf")
- elif is_boolean_dtype(dtype):
- return 1
- else:
- return torch.iinfo(dtype).max
- return {
- "sum": 0,
- "any": 0,
- }[reduction_type]
    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        reduction_numel = sympy_product(reduction_ranges)

        # TODO(jansel): convert this to dynamic shapes
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        reduction_ranges = [
            sympy.Integer(V.graph.sizevars.guard_static_shape(s))
            for s in reduction_ranges
        ]
        reduction_numel = sympy.Integer(
            V.graph.sizevars.guard_static_shape(reduction_numel)
        )

        # Masking is only needed when split doesn't divide the numel evenly.
        if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
            need_mask = False
        else:
            need_mask = True

        split = sympy.Integer(split)
        block_size = FloorDiv(reduction_numel + (split - 1), split)

        # View the (possibly multi-dim) reduction domain as one flat range.
        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])

        def wrapper_fn(index, reduction_index):
            """First-layer loader: reads one element of one block."""
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            # Flat position within the original reduction domain.
            indices = block_size * reduction_block + reduction_index

            def body():
                return inner_fn(new_index, reindex([indices]))

            if need_mask:
                # Out-of-range lanes read the reduction identity instead.
                mask = ops.lt(
                    ops.index_expr(indices, torch.int32),
                    ops.index_expr(reduction_numel, torch.int32),
                )
                return ops.masked(
                    mask, body, cls.default_value(reduction_type, dst_dtype)
                )
            else:
                return body()

        # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
        # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
        # in fp32 and not reduce precision by breaking up the kernel into multiple layers
        intermediate_dtype = (
            dst_dtype
            if dst_dtype not in (torch.float16, torch.bfloat16)
            else torch.float
        )
        intermediate = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            [*ranges, split],
            [block_size],
            reduction_type,
            reduction_hint,
        )
        intermediate.realize()
        intermediate_loader = intermediate.make_loader()

        def intermediate_fn(index, reduction_index):
            # Second layer reads the per-block partial results.
            return intermediate_loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
            reduction_hint = ReductionHint.OUTER_TINY
        if (
            split <= 1024
            and numel_hint <= 256
            and reduction_hint == ReductionHint.OUTER
        ):
            reduction_hint = ReductionHint.OUTER_TINY

        # Second layer: reduce the `split` partial results per output element.
        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                intermediate_fn,
                ranges,
                [split],
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
def is_storage_and_layout(x):
    """Return True if `x` can be simplified to a StorageBox + Layout pair."""
    try:
        as_storage_and_layout(x, freeze=False)
    except NotImplementedError:
        return False
    return True
def is_contiguous_storage_and_layout(x):
    """Return True if `x` simplifies to storage whose layout is contiguous."""
    try:
        _, layout = as_storage_and_layout(x, freeze=False)
    except NotImplementedError:
        return False
    return layout.is_contiguous()
def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
    """Try to simplify x into a StorageBox and a Layout

    freeze: if True, pin down the buffer's layout while unwrapping.
    want_contiguous / stride_order: how to pin it (mutually exclusive in
    practice; want_contiguous wins).
    Raises NotImplementedError when x cannot be simplified.
    """
    if isinstance(x, TensorBox):
        # Unwrap one box level and retry.
        return as_storage_and_layout(
            x.data,
            freeze=freeze,
            want_contiguous=want_contiguous,
            stride_order=stride_order,
        )
    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
        if freeze:
            if want_contiguous:
                x.data.freeze_layout()
            elif stride_order is not None:
                x.data.freeze_layout_with_stride_order(stride_order)
            else:
                x.data.decide_layout()
        return x, x.data.layout
    if isinstance(x, ReinterpretView):
        # making the base of x contiguous or stride_ordered will not necessarily make
        # the ReinterpretedView either, so dont pass along those arguments
        buffer, _ = as_storage_and_layout(
            x.data,
            freeze=freeze,
        )
        return buffer, x.layout
    raise NotImplementedError
# Shorthand: as_storage_and_layout() that also freezes the layout as contiguous.
as_contiguous_storage_and_layout = functools.partial(
    as_storage_and_layout, want_contiguous=True
)
def is_stride_order_storage_and_layout(x, stride_order):
    """Return True if `x` simplifies to storage whose layout follows `stride_order`."""
    try:
        _, layout = as_storage_and_layout(x, freeze=False)
    except NotImplementedError:
        return False
    return layout.is_stride_ordered(stride_order)
@dataclasses.dataclass
class BaseView(IRNode):
    """Base class for view nodes that wrap another IRNode; most queries
    delegate to the wrapped `data` node."""

    data: IRNode

    def get_dtype(self):
        return self.data.get_dtype()

    def get_device(self):
        return self.data.get_device()

    def get_name(self):
        return self.data.get_name()

    def mark_reuse(self, users):
        return self.data.mark_reuse(users)

    def has_exceeded_max_reads(self):
        return self.data.has_exceeded_max_reads()

    def realize(self):
        return self.data.realize()

    def realize_hint(self):
        return self.data.realize_hint()

    def get_storage_numel(self):
        return self.data.get_storage_numel()

    def is_extern(self):
        return self.data.is_extern()

    @cache_on_self
    def get_reads(self):
        # Temporarily allow indexing through flexible layouts while
        # tracing this view's loader to extract its reads.
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            ).reads

    def unwrap_view(self):
        """Strip all nested views, returning the underlying node."""
        x = self
        while isinstance(x, BaseView):
            x = x.data
        return x

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.get_dtype(), loader, self.get_size())
@dataclasses.dataclass
class ExpandView(BaseView):
    """Broadcasting view: size-1 dims (and newly added leading dims) of
    `data` are expanded to `size` by reading index 0 / stride 0."""

    size: List[Expr]

    @staticmethod
    def _normalize_size(x, new_size):
        """Replace `-1` with correct sizes"""
        new_size = list(map(sympy.expand, new_size))
        old_size = x.get_size()
        # Left-pad old_size with None placeholders for new leading dims.
        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
        assert len(new_size) == len(old_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                assert old_size[i] is not None
                new_size[i] = old_size[i]
        return new_size

    @classmethod
    def create(cls, x, new_size):
        new_size = cls._normalize_size(x, new_size)

        if is_storage_and_layout(x):
            # Fast path: express the broadcast with stride-0 dimensions.
            storage, old_layout = as_storage_and_layout(x)
            skip = len(new_size) - len(old_layout.size)
            assert skip >= 0
            new_stride = [sympy.Integer(0)] * skip
            for stride, size in zip(old_layout.stride, old_layout.size):
                new_stride.append(stride if size != 1 else sympy.Integer(0))
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                list(new_size),
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return ExpandView(x, new_size)

    def get_size(self):
        return self.size

    def make_loader(self):
        target = self.get_size()
        actual = self.data.get_size()
        skip = len(target) - len(actual)
        inner = self.data.make_loader()

        def load(index):
            # Drop the added leading dims, then zero out broadcast dims.
            index = list(index[skip:])
            assert len(index) == len(actual)
            for i in range(len(actual)):
                if actual[i] == 1:
                    # zero out broadcast dimension
                    index[i] = sympy.Integer(0)
            return inner(index)

        return load
@dataclasses.dataclass
class PermuteView(BaseView):
    """View that reorders the dimensions of `data` according to `dims`."""

    dims: List[Expr]

    @classmethod
    def create(cls, x, dims):
        # Normalize negative dims, then require a true permutation.
        dims = cls._map_neg_dims(dims)
        assert set(dims) == set(range(len(dims)))

        if is_storage_and_layout(x):
            # Fast path: permute the layout's sizes/strides directly.
            storage, old_layout = as_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                [old_layout.size[i] for i in dims],
                [old_layout.stride[i] for i in dims],
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return PermuteView(x, dims)

    @classmethod
    def _map_neg_dims(cls, dims):
        # Translate negative dim indices (e.g. -1) to their positive form.
        return [dim if dim >= 0 else len(dims) + dim for dim in dims]

    def get_size(self):
        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
        size = self.data.get_size()
        return [size[i] for i in self.dims]

    def make_loader(self):
        inner = self.data.make_loader()
        # Build the inverse permutation of self.dims.
        inv = {j: i for i, j in enumerate(self.dims)}
        inv = [inv[i] for i in range(len(self.dims))]
        assert set(inv) == set(range(len(self.dims)))

        def load(index):
            index = [index[i] for i in inv]
            return inner(index)

        return load
class SqueezeView(BaseView):
    """Remove size-1 dimensions. Never instantiated directly: `create()`
    always returns a ReinterpretView or a generic View (see `__init__`)."""

    @classmethod
    def create(cls, x, *, dim=None):
        if is_storage_and_layout(x):
            # Fast path: drop the squeezed dims from the layout directly.
            storage, old_layout = as_storage_and_layout(x)
            new_size = []
            new_stride = []
            if dim is not None:
                assert isinstance(dim, int), "expected integer dim argument"
                assert 0 <= dim and dim < len(old_layout.size)

            for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
                if dim is None:
                    # Squeeze every size-1 dim.
                    if size != 1:
                        new_size.append(size)
                        new_stride.append(stride)
                else:
                    # Squeeze only `dim`, which must have size 1.
                    if i != dim:
                        new_size.append(size)
                        new_stride.append(stride)
                    else:
                        assert size == 1, "expected squeezed size to be 1"

            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        if dim is None:
            # redirect to a generic view
            return View.create(x, [s for s in x.get_size() if s != 1])
        else:
            assert x.get_size()[dim] == 1
            return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])

    @staticmethod
    def squeezer(size: Tuple[sympy.Expr, ...]):
        """Return (squeezed_size, reindex) where reindex maps squeezed
        indices back to indices into the original `size` (0 for size-1 dims).
        """
        new_size = [s for s in size if s != 1]
        not_one = [i for i, s in enumerate(size) if s != 1]
        length = len(size)

        def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]:
            assert len(index) == len(not_one), f"{index} {not_one}"
            new_index = [sympy.Integer(0)] * length
            for idx, s in zip(not_one, index):
                new_index[idx] = s
            # NOTE: returns a tuple despite the List annotation.
            return tuple(new_index)

        return new_size, reindex

    def __init__(self, data):
        raise AssertionError("use SqueezeView.create()")
@dataclasses.dataclass
class View(BaseView):
    """Generic reshape view: `reindex` maps an index in `size` to an index
    into `data`."""

    size: List[Expr]
    reindex: Callable

    def make_indexer(self):
        base_indexer = self.data.make_indexer()

        def indexer(idx):
            return base_indexer(self.reindex(idx))

        return indexer

    @staticmethod
    def handle_negative_index(idx, size):
        """Wrap a (hinted) negative index around to `idx + size`."""
        idx = sympy.expand(idx)
        size = sympy.expand(size)
        sizevars = V.graph.sizevars
        if sizevars.size_hint(idx) < 0:
            sizevars.guard_lt(idx, 0)
            idx = idx + size
        return idx

    def reindex_str(self):
        # Human-readable form of the reindex function, for __str__/debugging.
        index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))]
        index_new = list(self.reindex(index_old))
        return f"lambda {', '.join(map(str, index_old))}: {index_new}"

    def __str__(self):
        return self.str_helper(
            [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
        )

    __repr__ = __str__

    @classmethod
    def create(cls, x, new_size):
        assert isinstance(new_size, (tuple, list))
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)

        if V.graph.sizevars.maybe_guard_list_equals(old_size, new_size):
            # Same shape: the reshape is a no-op.
            return x

        # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
        if is_contiguous_storage_and_layout(x) and not isinstance(
            x.data, ExternKernelAlloc
        ):
            # Fast path: contiguous storage reshapes by reinterpreting strides.
            storage, old_layout = as_contiguous_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                FlexibleLayout.contiguous_strides(new_size),
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
        return cls(x, tuple(new_size), reindex)

    @staticmethod
    def resolve_negative_size(old_size, new_size):
        """Replace a single -1 in new_size with the implied dimension."""
        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
        old_size = [V.graph.sizevars.simplify(x) for x in old_size]

        new_size = list(new_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                new_size[i] = sympy.Integer(1)
                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
                break

        V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
        return old_size, new_size

    @classmethod
    def dynamic_reshape_indexer(cls, old_size, new_size):
        """Build a reindex fn from new_size indices to old_size indices."""
        try:
            reindex = cls._dynamic_reshape_indexer(old_size, new_size)
        except (AssertionError, IndexError):
            # optimistic algorithm failed, lets do a fallback
            # (flatten to 1-D, then reshape to the target).
            flat = [sympy_product(old_size)]
            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
            reindex = fuse_reindexing(reindex1, reindex2)
        return reindex

    @staticmethod
    def _dynamic_reshape_indexer(old_size, new_size):
        """
        Perform a reshape entirely by modifying indexing math
        """
        size_hint = V.graph.sizevars.size_hint
        vars = [sympy_symbol(f"view{i}") for i in range(len(new_size))]

        # Walk both shapes right-to-left, matching dimensions by size.
        stack_new = list(zip(vars, new_size))
        stack_old = list(old_size)

        view_expr = []
        while stack_new and stack_old:
            size_old = stack_old.pop()
            var, size_new = stack_new.pop()
            if size_old == 1:
                view_expr.append(sympy.Integer(0))
                stack_new.append((var, size_new))  # re-add
            elif size_new == 1:
                stack_old.append(size_old)  # re-add
            elif size_hint(size_new) == size_hint(size_old):
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) < size_hint(size_old):
                # Several new dims merge into one old dim.
                while size_hint(size_new) < size_hint(size_old):
                    var2, size_new2 = stack_new.pop()
                    var = var2 * size_new + var
                    size_new = size_new * size_new2
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) > size_hint(size_old):
                # One new dim splits across several old dims.
                divisor = sympy.Integer(1)
                modulus = size_old
                view_expr.append(ModularIndexing(var, divisor, modulus))
                divisor = divisor * modulus
                while size_hint(size_new) > size_hint(size_old):
                    modulus = stack_old.pop()
                    view_expr.append(ModularIndexing(var, divisor, modulus))
                    divisor = divisor * modulus
                    size_old = size_old * modulus
                V.graph.sizevars.guard_equals(size_new, size_old)
            else:
                raise AssertionError()

        # Leftover dims on either side must all be size 1.
        while stack_old:
            size_old = stack_old.pop()
            assert size_old == 1
            view_expr.append(sympy.Integer(0))

        while stack_new:
            var, size_new = stack_new.pop()
            assert size_new == 1

        view_expr = list(reversed(view_expr))
        assert len(view_expr) == len(old_size)

        def reindex(index):
            assert len(index) == len(vars), (len(index), len(vars))
            replacements = dict(zip(vars, index))
            return tuple(sympy_subs(x, replacements) for x in view_expr)

        return reindex

    def get_size(self):
        return self.size

    def make_loader(self):
        def load(index):
            return inner(self.reindex(index))

        inner = self.data.make_loader()
        return load
@dataclasses.dataclass
class ReinterpretView(BaseView):
    """Pretend our storage has a different layout"""

    layout: "Layout"

    def __post_init__(self):
        # Always alias the underlying storage node, never another view.
        if isinstance(self.data, BaseView):
            self.data = self.data.unwrap_view()

    def __str__(self):
        return self.str_helper(
            [
                self.data,
                self.layout,
            ]
        )

    __repr__ = __str__

    def get_name(self):
        return self.data.get_name()

    def get_device(self):
        return self.layout.device

    def get_dtype(self):
        return self.layout.dtype

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def make_loader(self):
        def loader(index):
            # Index through our own layout, but load from the base buffer.
            indexer = self.layout.make_indexer()
            return ops.load(self.get_name(), indexer(index))

        return loader

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_layout(self):
        return self.layout

    def freeze_layout(self):
        # Layout is already fixed by construction; nothing to do.
        pass

    def codegen_reference(self):
        # Emit an as_strided(...) expression over the base buffer,
        # omitting the offset argument when it is zero.
        size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
        stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
        offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
        as_strided = V.graph.sizevars.as_strided
        if offset != "0":
            return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
        return f"{as_strided}({self.get_name()}, {size}, {stride})"
class SliceView(View):
    """View implementing `x[..., start:end:step, ...]` along `dim`."""

    @classmethod
    def create(cls, x, dim, start, end, step=1):
        step = sympy.expand(step)
        assert step > 0
        try:
            # Huge `end` is how "slice to the end" is encoded; whole-dim
            # slices with step 1 are no-ops. TypeError if args are symbolic.
            if start == 0 and end >= 2**63 and step == 1:
                return x
        except TypeError:
            pass

        sizevars = V.graph.sizevars
        new_size = list(x.get_size())

        start = cls.handle_negative_index(start, new_size[dim])
        end = cls.handle_negative_index(end, new_size[dim])

        # Clamp: start <= end <= size[dim].
        end = sizevars.guard_min(end, new_size[dim])
        start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end)
        if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
            sizevars.guard_equals(end, new_size[dim])
            return x

        new_size[dim] = FloorDiv(end - start + (step - 1), step)

        if is_storage_and_layout(x):
            # Fast path
            storage, old_layout = as_storage_and_layout(x)
            new_stride = list(old_layout.stride)
            new_stride[dim] = new_stride[dim] * step
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset + old_layout.stride[dim] * start,
            )
            return ReinterpretView(storage, new_layout)

        def reindex(index):
            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
            index = list(index)
            index[dim] = index[dim] * step + start
            return index

        # redirect to a generic view
        return SliceView(x, size=new_size, reindex=reindex)
class BaseConstant(IRNode):
    """Common behavior for scalar constant IR nodes: zero-dim size, no
    reads, no storage. `dtype`/`device` come from dataclass subclasses."""

    def get_size(self):
        return ()

    def get_dtype(self):
        return self.dtype

    def get_device(self):
        return self.device

    def mark_reuse(self, users):
        pass

    def has_exceeded_max_reads(self):
        return False

    def get_reads(self):
        return ()

    def is_extern(self):
        return False
@dataclasses.dataclass
class Constant(BaseConstant):
    """Scalar constant with a fixed value."""

    value: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(_index):
            # The index is irrelevant: every element is the same constant.
            return ops.constant(self.value, self.dtype)

        return loader

    def realize(self):
        # Constants need no backing buffer, so realization is a no-op.
        pass
@dataclasses.dataclass
class IndexingConstant(BaseConstant):
    """Scalar constant given by a (fixed) index expression."""

    index: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(_index):
            # Emits the stored index expression as a value of `dtype`.
            return ops.index_expr(self.index, self.dtype)

        return loader
@dataclasses.dataclass
class Layout(IRNode):
    """Device/dtype/size/stride/offset description of a tensor's memory."""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: List[Expr],
        offset: Expr = Integer(0),
    ):
        self.device = device
        self.dtype = dtype
        assert all(isinstance(s, (Expr, int)) for s in size)
        self.size = size
        self._stride = stride
        self.offset = offset

    @property
    def stride(self):
        return self._stride

    def __str__(self):
        offset = ""
        if self.offset != 0:
            offset = f", offset={self.offset}"
        return (
            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
            f"size={self.size}, stride={self.stride}{offset})"
        )

    __repr__ = __str__

    def is_contiguous(self):
        # Size-1 dims may carry any stride without breaking contiguity.
        for left, right, size in zip(
            self.stride, FlexibleLayout.contiguous_strides(self.size), self.size
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_channels_last_contiguous(self):
        # Only defined for 4-D (NCHW) / 5-D (NCDHW) tensors.
        ndim = len(self.size)
        if ndim not in [4, 5]:
            return False
        for left, right, size in zip(
            self.stride, make_channels_last_strides_for(self.size), self.size
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_transposed(self):
        # Strides match the reverse of contiguous (row-major) order.
        for left, right, size in zip(
            self.stride,
            reversed(FlexibleLayout.contiguous_strides(self.size)),
            self.size,
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_stride_ordered(self, order):
        """True if stride hints are ascending when permuted by `order`."""
        assert len(self.stride) == len(order)
        # reorder the stride given order
        stride_ordered = [None] * len(order)
        for i in range(len(order)):
            stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i])
        # check if it is in ascending order
        for i in range(len(order) - 1):
            if stride_ordered[i] > stride_ordered[i + 1]:
                return False
        return True

    def is_channels_last_stride_ordered(self):
        # create channels_last order(NCHW, NCDHW, the C is the first order).
        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
        order = [len(order)] + order
        return self.is_stride_ordered(order)

    def as_fixed(self):
        """Copy this layout into an immutable FixedLayout."""
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.stride,
            self.offset,
        )

    def make_indexer(self):
        # Indexing a non-fixed layout is only legal while
        # FlexibleLayout.allow_indexing is patched on.
        assert (
            FlexibleLayout.allow_indexing
        ), f"convert {type(self).__name__} to FixedLayout first"
        return self.as_fixed().make_indexer()

    def __eq__(self, other) -> bool:
        return (
            self.device == other.device
            and self.dtype == other.dtype
            and self.size == other.size
            and self.stride == other.stride
            and self.offset == other.offset
        )
class FixedLayout(Layout):
    """A Tensor layout we cannot change"""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: List[Expr] = None,
        offset: Expr = Integer(0),
    ):
        if stride is None:
            # Default to row-major contiguous strides.
            stride = FlexibleLayout.contiguous_strides(size)
        super().__init__(
            device,
            dtype,
            size,
            stride,
            offset,
        )

    def make_indexer(self):
        """A closure containing math to read a given element"""

        def indexer(index):
            assert len(index) == len(self.stride) == len(self.size)
            result = self.offset
            for idx, stride, sz in zip(index, self.stride, self.size):
                if sz != 1:
                    # size-1 dims contribute nothing to the address
                    result = result + idx * stride
            return result

        return indexer
class FlexibleLayout(Layout):
    """A Tensor layout we are allowed to change"""

    # Flipped on (via patch.object) only where indexing a still-flexible
    # layout is deliberately allowed, e.g. while extracting reads.
    allow_indexing = False

    @staticmethod
    def contiguous_strides(sizes):
        """Row-major strides for `sizes` (innermost stride is 1)."""
        if len(sizes) == 0:
            return []
        reversed_strides = [sympy.Integer(1)]
        for size in reversed(sizes[1:]):
            reversed_strides.append(size * reversed_strides[-1])
        return list(reversed(reversed_strides))

    @staticmethod
    def fill_ordered(sizes, order):
        """
        Create a stride based on the order the dimensions should be filled in.

        In this format, channels last would be:
            [1, 3, 2, 0]
        """
        assert set(range(len(sizes))) == set(order)
        next_stride = sympy.Integer(1)
        strides = [None] * len(order)

        for i in order:
            strides[i] = next_stride
            next_stride = next_stride * sizes[i]
        return strides

    @staticmethod
    def stride_ordered(sizes, order):
        """
        Create a stride based on the sorted order of a permuted range.

        In this format, channels last would be:
            [3, 0, 2, 1]
        """
        assert set(range(len(sizes))) == set(order)
        fill_order = stride_order2fill_order(order)
        return FlexibleLayout.fill_ordered(sizes, fill_order)

    @staticmethod
    def same_ordered(sizes, stride):
        """
        Create a stride that has the same stride order as given stride

        For example, if given stride is [1000, 1, 100, 10],
        the fill order should be [1, 3, 2, 0]
        """
        assert len(sizes) == len(stride)
        # Order dims by their (hinted) stride, smallest first.
        stride = [V.graph.sizevars.size_hint(x) for x in stride]
        fill_order = sorted(range(len(stride)), key=stride.__getitem__)
        return FlexibleLayout.fill_ordered(sizes, fill_order)

    def as_stride_order(self, order):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.stride_ordered(self.size, order),
            self.offset,
        )

    def as_fill_order(self, order):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.fill_ordered(self.size, order),
            self.offset,
        )

    def as_same_order(self, stride):
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            self.same_ordered(self.size, stride),
            self.offset,
        )

    def __init__(self, device, dtype, size, stride_order=None):
        if stride_order:
            strides = FlexibleLayout.fill_ordered(size, stride_order)
        else:
            strides = FlexibleLayout.contiguous_strides(size)
        super().__init__(device, dtype, size, strides)
class AliasedLayout(Layout):
    """Shares the same storage as another tensor"""

    def __init__(self, view: "ReinterpretView"):
        layout = view.get_layout()
        super().__init__(
            layout.device,
            layout.dtype,
            layout.size,
            layout.stride,
        )
        self.view = view

    def make_indexer(self):
        return self.as_fixed().make_indexer()

    def maybe_guard_aligned(self):
        """Try to guard that the aliased view's storage offset preserves
        the compiler's buffer alignment; returns False if it cannot."""
        offset = self.view.get_layout().offset
        if offset == 0:
            return True
        # Imported lazily — presumably to avoid a circular import; confirm.
        from .compile_fx import ALIGNMENT

        return V.graph.sizevars.maybe_guard_multiple_of(offset, ALIGNMENT)
class MutationLayout(Layout):
    """Layout of a buffer that writes its result into `target`'s storage."""

    def __init__(self, target: IRNode):
        super().__init__(
            target.get_device(),
            target.get_dtype(),
            target.get_size(),
            None,  # type: ignore[arg-type]
        )
        self.target = target

    @Layout.stride.getter
    def stride(self):
        # Strides are those of the mutated target's real layout.
        return self.real_layout().stride

    def real_layout(self):
        # Follow chained mutations down to the underlying buffer's layout.
        if isinstance(self.target, MutationLayout):
            return self.target.real_layout()
        return self.target.data.layout

    @classmethod
    def realize_into(cls, src, dst):
        """Realize `src` such that its result is stored into `dst`."""
        dst.realize()
        V.graph.realize_users_of(dst.get_name())

        if isinstance(src, TensorBox):
            src = src.data

        # Copy when src is not a StorageBox with a still-flexible layout,
        # or when src reads dst (presumably to avoid clobbering its own
        # input by writing in place — confirm).
        if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()):
            need_copy = True
        else:
            src.realize()
            need_copy = not isinstance(src.data.layout, FlexibleLayout)

        if need_copy:
            src = Pointwise.create(
                device=src.get_device(),
                dtype=src.get_dtype(),
                inner_fn=src.make_loader(),
                ranges=[
                    V.graph.sizevars.guard_equals(a, b)
                    for a, b in zip(src.get_size(), dst.get_size())
                ],
            ).data
            src.realize()

        assert isinstance(src.data.layout, FlexibleLayout)
        src.data.layout = MutationLayout(dst)
        return src.data

    def as_fixed(self):
        return self

    def make_indexer(self):
        return self.target.make_indexer()
@dataclasses.dataclass
class Buffer(IRNode):
    """A named value with a concrete Layout; base class for materialized buffers."""

    # Graph-unique buffer name (may be None until registered with the graph).
    name: str
    layout: Layout

    def make_indexer(self):
        """Return a callable mapping a logical index to a flat storage offset expression."""
        return self.layout.make_indexer()

    def get_name(self):
        assert self.name
        return self.name

    def get_device(self):
        return self.layout.device

    def get_dtype(self):
        # Some layouts (e.g. MultiOutputLayout) carry no dtype; return None then.
        return getattr(self.layout, "dtype", None)

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def get_layout(self):
        return self.layout

    def get_storage_numel(self):
        return self.get_numel()

    def is_extern(self):
        return False

    def freeze_layout(self):
        """Fix the layout via as_fixed() so strides can no longer change (skips MultiOutputLayout)."""
        if not isinstance(self.layout, MultiOutputLayout):
            self.layout = self.layout.as_fixed()

    def freeze_layout_with_stride_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_stride_order(order)

    def freeze_layout_with_fill_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_fill_order(order)

    def freeze_layout_with_same_order(self, stride):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_same_order(stride)

    def make_loader(self):
        """Return a function that emits an ops.load of this buffer at a given index."""

        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(self.name, indexer(index))

        return loader

    def is_no_op(self):
        return False

    def codegen_reference(self):
        return self.get_name()

    def decide_layout(self):
        # Subclasses (e.g. ComputedBuffer, ExternKernel) override to pick a layout.
        pass

    def get_alias_names(self):
        if isinstance(self.layout, AliasedLayout):
            return [self.layout.view.get_name()]
        return ()

    def get_mutation_names(self):
        if isinstance(self.layout, MutationLayout):
            return [self.layout.target.get_name()]
        return ()

    @cache_on_self
    def get_read_writes(self):
        # allow_indexing permits symbolic indexing while extracting dependencies.
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            )

    def get_reads(self):
        return self.get_read_writes().reads

    def realize(self):
        # Already materialized; nothing to do.
        pass
class InputBuffer(Buffer):
    """Marker subclass of Buffer; behavior is inherited unchanged."""

    pass
class ConstantBuffer(InputBuffer):
    """Input buffer backed by a graph constant; loads go through V.graph.constant_name."""

    # When set (see ComputedBuffer.simplify_and_reorder), constants are
    # resolved for this device instead of their original one.
    override_device = None

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(
                V.graph.constant_name(self.name, self.override_device), indexer(index)
            )

        return loader

    def constant_to_device(self, device):
        """Return a copy of this constant buffer remapped to `device`."""
        return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout)
class RandSeedBuffer(ConstantBuffer):
    """Constant buffer holding a random seed; referenced via a defensive clone."""

    def codegen_reference(self):
        # Clone makes sure if we pass this from forwards to backwards
        # the value does not get clobbered by the time backwards is run.
        return self.get_name() + ".clone()"
class NoneAsConstantBuffer(IRNode):
    """Stand-in IR node that codegens as a literal None (or empty at::Tensor in C++)."""

    def codegen_reference(self):
        return "None"

    def cpp_wrapper_codegen_reference(self):
        return "at::Tensor()"
class ShapeAsConstantBuffer(IRNode):
    """IR node wrapping a (possibly symbolic) shape expression; codegens its simplified value."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def codegen_reference(self):
        return str(V.graph.sizevars.simplify(self.shape))
@dataclasses.dataclass
class ComputedBuffer(Buffer):
    """A buffer produced by computing a Loops node (Pointwise/Reduction) into storage."""

    data: Loops

    @cache_on_self
    def get_read_writes(self):
        """Extract the read/write dependency set of the store function."""
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.data.get_reduction_type():
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                    self.data.get_reduction_size(),
                )
            else:
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                )

    def get_store_function(self):
        """Return a callable that stores this buffer's value (reduction-aware)."""
        indexer = self.layout.as_fixed().make_indexer()
        if self.data.get_reduction_type():
            return partial(self.data.store_reduction, self.name, indexer)
        else:
            return partial(self.data.store_output, self.name, indexer)

    def get_fill_order(self):
        """
        If our layout is still flexible, try to determine the stride order based on stride orders of reads.

        TODO(jansel): A better algorithm here would look at downstream consumers of this
                      value and try to do global graph-level layout optimization.
                      This is also something just begging to be autotuned.
        """
        if isinstance(self.layout, FlexibleLayout):
            _, (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
                self.data.get_size(), self.data.get_reduction_size()
            )
            reads = self.get_read_writes().reads
            reads_bufs = [
                V.graph.name_to_buffer[r.name]
                if r.name in V.graph.name_to_buffer.keys()
                else None
                for r in reads
            ]
            priority_idx = []
            for i, reads_buf in enumerate(reads_bufs):
                if (
                    isinstance(reads_buf, Convolution)
                    and reads_buf.kernel != "aten.convolution"
                ):
                    # prioritize Conv layout order
                    priority_idx.append(i)
            # only consider reads to buffer of same size
            # (zero out reduction variables so only the iteration index remains)
            reads = [
                sympy_subs(
                    r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
                )
                for r in reads
            ]

            if reads:
                stride_lengths = [
                    V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads
                ]
                from .scheduler import pick_loop_order

                return pick_loop_order(stride_lengths, self.get_size(), priority_idx)

        return None

    def decide_layout(self):
        """Freeze a flexible layout, preferring the fill order implied by our reads."""
        if isinstance(self.layout, FlexibleLayout):
            order = self.get_fill_order()
            if order:
                self.freeze_layout_with_fill_order(order)
            else:
                self.freeze_layout()

    def simplify_and_reorder(self):
        """
        This is a main place where we do loop transformations in a
        backend-agnostic way.

        Here we:
            1) Remove any 1 dimensions
            2) Fuse contiguous dimensions together
            3) Reorder dimensions based on stride orders

        Returns ((iter_ranges, reduce_ranges), loop_body).
        """
        _, args, var_ranges = dependencies.index_vars_squeeze(
            self.data.get_size(), self.data.get_reduction_size(), prefix="q"
        )
        with patch.object(ConstantBuffer, "override_device", self.get_device()):
            body = LoopBody(
                self.get_store_function(),
                (args if self.get_reduction_type() else args[:1]),
                var_ranges,
            )
        index_formulas = [*body.indexing_exprs.values()]
        reads_bufs = [
            V.graph.name_to_buffer[reads_name]
            if reads_name in V.graph.name_to_buffer.keys()
            else None
            for reads_name in body.reads_name2expr.keys()
        ]
        priority_idx = []
        memory_addrs = [
            *body.reads_name2expr.values(),
            *body.writes_name2expr.values(),
        ]
        # Split the squeezed variables back into iteration vs reduction groups.
        index_vars = []
        reduce_vars = []
        index_size = []
        reduce_size = []
        for v, s in var_ranges.items():
            if v in args[0]:
                assert not reduce_vars
                index_vars.append(v)
                index_size.append(s)
            else:
                assert v in args[1]
                reduce_vars.append(v)
                reduce_size.append(s)

        # the reordering_reindex in reads' simplify_reorder_and_tile
        reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
        for i, reads_buf in enumerate(reads_bufs):
            if isinstance(reads_buf, ComputedBuffer) and hasattr(
                reads_buf, "iter_reordering_reindex"
            ):
                reordering_reindex[i] = reads_buf.iter_reordering_reindex

        def simplify_and_reorder(x_vars, sizes, reordering_reindex=None):
            # Reorder loops for locality, then simplify (fuse/prune) the ranges.
            sizes, reindex0, reindex1 = self._apply_loop_reordering(
                x_vars, sizes, memory_addrs, reordering_reindex, priority_idx
            )
            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
            x_vars = reindex0(x_vars)
            sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
                x_vars,
                sizes,
                index_prevent_reordering(index_formulas, x_vars, sizes),
            )
            x_vars = prune(x_vars)
            # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
            # x_vars = prune(x_vars)
            # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
            reindex = fuse_reindexing(reindex1, reindex2)
            return sizes, reindex, reindex1

        iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
            index_vars, index_size, reordering_reindex
        )
        reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
            reduce_vars, reduce_size
        )

        # remember the reordering order
        self.iter_reordering_reindex = iter_reordering_reindex
        # retrace the loop body with simplification and reordering applied
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            iter_ranges, reduce_ranges, prefix="z"
        )
        body = LoopBody(
            body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
        )
        return (iter_ranges, reduce_ranges), body

    @staticmethod
    def _apply_loop_reordering(
        index_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None
    ):
        """
        Shuffle the order of loops around to hopefully improve performance.

        Returns (reordered_sizes, forward_reindex, inverse_reindex).
        """
        from .scheduler import pick_loop_order

        if priority_idx is None:
            priority_idx = []

        try:
            strides = [
                V.graph.sizevars.stride_hints(expr, index_vars) for expr in memory_addrs
            ]
            assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
                index_vars
            )
            # consider both layout(strides) and reordering(reordering_reindex)
            if reordering_reindex is not None:
                for i in range(len(memory_addrs)):
                    try:
                        strides[i] = reordering_reindex[i](strides[i])
                    # if len(order) != len(strides), do not reorder
                    except AssertionError:
                        pass
            order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
        except Exception:
            if config.debug:
                # Lazy %-style args: the (potentially expensive) dict/zip formatting
                # only happens when the warning is actually emitted.
                log.warning(
                    "Did not simplify complex index:\n%s\n%s",
                    dict(zip(index_vars, sizes)),
                    memory_addrs,
                )
            order = list(range(len(sizes)))
        sizes = [sizes[i] for i in order]
        return sizes, same_reorder(order), inverse_reorder(order)

    def get_reduction_size(self):
        return self.data.get_reduction_size()

    def get_reduction_type(self):
        return self.data.get_reduction_type()

    def is_no_op(self):
        return self.data.is_zero_elements()

    def should_allocate(self):
        return True

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        return self.data.constant_to_device(device)
class TemplateBuffer(Buffer):
    """
    Represents a Triton (in the future other type) of template operator
    that we can fuse an epilogue onto.
    """

    def __init__(self, layout, inputs, make_kernel_render):
        super().__init__(name=None, layout=layout)
        self.inputs = InputsKernel.unwrap_storage(inputs)
        self.make_kernel_render = make_kernel_render
        self.name = V.graph.register_buffer(self)

    def get_read_writes(self):
        return self.normalized_read_writes()

    @cache_on_self
    def normalized_read_writes(self):
        """Compute writes from a dummy store, then replace reads with StarDeps on all inputs."""
        name = self.get_name()
        indexer = self.layout.make_indexer()

        def dummy(index, rindex):
            assert len(rindex) == 0
            return ops.store(name, indexer(index), "fake")

        deps = dependencies.extract_read_writes(
            dummy, self.get_size(), (), normalize=True
        )
        deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
        return deps

    def get_reduction_size(self):
        return 1

    def get_reduction_type(self):
        return None

    def is_no_op(self):
        return False

    def should_allocate(self):
        return True

    def simplify_and_reorder(self):
        # Templates keep their sizes as-is; no loop simplification/reordering.
        return (
            (
                self.get_size(),
                (),
            ),
            None,
        )
@dataclasses.dataclass
class InputsKernel(Buffer):
    """A buffer computed from a list of input buffers (base for extern/concat kernels)."""

    inputs: List[Buffer]

    def get_read_writes(self):
        # Whole-buffer (Star) dependencies on every input and on our own output.
        return dependencies.ReadWrites(
            {dependencies.StarDep(x.get_name()) for x in self.inputs},
            {dependencies.StarDep(self.get_name())},
            set(),
            [],
            None,
        )

    @staticmethod
    def unwrap_storage(inputs):
        """Strip TensorBox/StorageBox wrappers, realizing views into concrete buffers."""
        inputs_new = []
        for x in inputs:
            if isinstance(x, TensorBox):
                x = x.data
            if isinstance(x, StorageBox):
                x = x.data
            if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
                x = ExternKernel.realize_input(x)
            assert isinstance(x, (Buffer, ReinterpretView)), x
            inputs_new.append(x)
        return inputs_new

    def is_extern(self):
        return True
class NopKernel(InputsKernel):
    """An InputsKernel that generates no code of its own."""

    def is_no_op(self):
        return True
class ConcatKernel(NopKernel):
    """
    There isn't actually a real kernel for concat, we just change the
    storage for the upstream data.
    """

    @classmethod
    def create(cls, inputs, dim):
        """Build the concatenated buffer along `dim` and realize each input into its slice."""
        device = inputs[0].get_device()
        dtype = inputs[0].get_dtype()
        new_size = list(inputs[0].get_size())
        # Per-input [start, end) offsets along the concat dimension.
        offsets_start = [0]
        offsets_end = [new_size[dim]]
        assert 0 <= dim < len(new_size)
        for i in range(1, len(inputs)):
            input_size = inputs[i].get_size()
            offsets_start.append(new_size[dim])
            assert len(input_size) == len(new_size)
            assert inputs[i].get_dtype() == dtype
            assert inputs[i].get_device() == device
            for j in range(len(new_size)):
                if j == dim:
                    new_size[j] = new_size[j] + input_size[j]
                else:
                    # Non-concat dims must match; guard the equality.
                    new_size[j] = V.graph.sizevars.guard_equals(
                        new_size[j], input_size[j]
                    )
            offsets_end.append(new_size[dim])

        output_stride = FlexibleLayout.contiguous_strides(new_size)
        # If any of the inputs is in CL format, use CL format for the output
        for i in range(len(inputs)):
            x = inputs[i]
            if is_storage_and_layout(x):
                layout = x.get_layout()
                if (
                    isinstance(layout, FixedLayout)
                    and layout.is_channels_last_contiguous()
                ):
                    # use CL stride for the output
                    output_stride = make_channels_last_strides_for(new_size)
                    break

        kernel = ConcatKernel(
            name=None,
            layout=FixedLayout(
                device=device,
                dtype=dtype,
                size=new_size,
                stride=output_stride,
            ),
            inputs=[],
        )
        kernel = StorageBox(kernel)
        for i in range(len(inputs)):
            kernel.data.inputs.append(
                cls.realize_into(
                    inputs[i],
                    SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
                )
            )
        kernel.data.name = V.graph.register_buffer(kernel.data)
        kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)

        return kernel

    @classmethod
    def realize_into(cls, src, dst):
        """Make `src` write directly into the slice `dst`, inserting a copy when aliasing is unsafe."""
        # Attempt to turn this into a ReinterpretView rather than assert.
        # This has concessions around layout, as as_storage_and_layout
        # can cause us to go from flexible to fixed layout.
        if not isinstance(dst, ReinterpretView):
            if is_storage_and_layout(dst):
                storage, layout = as_storage_and_layout(dst)
                dst = ReinterpretView(storage, layout)
        assert isinstance(dst, ReinterpretView), dst
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.realize_into(src.data, dst)
        if isinstance(src, StorageBox):
            src.realize()
            # ExternKernelAlloc has specific requirements for output layout, should create a copy
            if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
                src.data, ExternKernelAlloc
            ):
                src.data.layout = AliasedLayout(dst)
                return src.data
        # introduce a copy
        pw = Pointwise.create(
            device=src.get_device(),
            dtype=src.get_dtype(),
            inner_fn=src.make_loader(),
            ranges=[
                V.graph.sizevars.guard_equals(a, b)
                for a, b in zip(src.get_size(), dst.get_size())
            ],
        )
        return cls.realize_into(pw, dst)

    def should_allocate(self):
        return True
@dataclasses.dataclass
class ExternKernel(InputsKernel):
    """Base class for kernels implemented by calling an external (eager/aten) op."""

    constant_args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    output_view: Optional[ReinterpretView] = None

    def decide_layout(self):
        if isinstance(self.layout, FlexibleLayout):
            self.apply_constraint()
            self.freeze_layout()

    def codegen(self, wrapper):
        raise NotImplementedError

    @staticmethod
    def copy_input(x):
        """Materialize a realized pointwise copy of `x`."""
        pw = Pointwise.create(
            device=x.get_device(),
            dtype=x.get_dtype(),
            inner_fn=x.make_loader(),
            ranges=x.get_size(),
        )
        pw.realize()
        return pw

    @classmethod
    def process_kernel(cls, kernel, *args, **kwargs):
        """
        Bind and flatten the call arguments, realize tensor inputs, then run
        `kernel` on example tensors to get an example output.

        Returns (example_output, tensor_args, non_tensor_args, unflatten_args).
        """
        binded_args = signature(kernel).bind(*args, **kwargs).arguments
        args_flat, args_spec = pytree.tree_flatten(binded_args)

        is_arg_tensor = []
        tensor_args = []
        non_tensor_args = []
        for arg in args_flat:
            is_arg_tensor.append(isinstance(arg, IRNode))
            if is_arg_tensor[-1]:
                tensor_args.append(arg)
            else:
                if isinstance(arg, sympy.Expr):
                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                non_tensor_args.append(arg)

        def unflatten_args(new_tensor_args, new_non_tensor_args):
            # Re-interleave tensor/non-tensor streams back into the original structure.
            result = []
            it_tensors = iter(new_tensor_args)
            it_non_tensors = iter(new_non_tensor_args)
            for is_tensor in is_arg_tensor:
                if is_tensor:
                    result.append(next(it_tensors))
                else:
                    result.append(next(it_non_tensors))
            result = pytree.tree_unflatten(result, args_spec)
            return result.get("args", []), result.get("kwargs", {})

        tensor_args = [cls.realize_input(x) for x in tensor_args]

        # freeze layout otherwise our output stride calculation might
        # become incorrect
        for x in tensor_args:
            if is_storage_and_layout(x):
                as_storage_and_layout(x, freeze=True)

        # We don't have generic shape formulas, so just burn in the
        # shapes and run an example input.
        # TODO(jansel): replace this with dynamic shape formulas
        example_args = []
        for x in tensor_args:
            example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)

        return example_output, tensor_args, non_tensor_args, unflatten_args

    @classmethod
    def convert_to_reinterpret_view(cls, x):
        """
        In order to pass this to an extern kernel we need a
        ReinterpretView not a View.  This allows us to avoid some
        unneeded copies.

        Raises NotImplementedError when the view's index is not affine.
        """
        assert isinstance(x, BaseView)
        if isinstance(x, ReinterpretView):
            return x

        x.unwrap_view().freeze_layout()
        rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
        assert len(rw.reads) == 1

        index = V.graph.sizevars.simplify_with_ranges(
            list(rw.reads)[0].index, rw.var_ranges
        )
        strides = V.graph.sizevars.stride_vars(index, rw.range_vars)
        offset = V.graph.sizevars.offset_var(index, rw.range_vars)
        # Only valid if the read index is exactly an affine strided access.
        expected = sympy_dot(rw.range_vars, strides) + offset

        if index != expected:
            log.debug(
                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
                strides,
                offset,
                index,
            )
            raise NotImplementedError()

        return ReinterpretView(
            data=x.data,
            layout=FixedLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=x.get_size(),
                stride=strides,
                offset=offset,
            ),
        )

    @classmethod
    def realize_input(cls, x):
        """Turn an arbitrary IR value into something an extern kernel can consume."""
        if x is None:
            return NoneAsConstantBuffer()
        if isinstance(x, (sympy.Expr, int)):
            return ShapeAsConstantBuffer(x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
            return cls.realize_input(x.data)
        if isinstance(x, ReinterpretView):
            return x
        if isinstance(x, BaseView):
            x.realize()
            if is_storage_and_layout(x.unwrap_view()) and not isinstance(
                x.unwrap_view().data, ExternKernelAlloc
            ):
                try:
                    return cls.convert_to_reinterpret_view(x)
                except NotImplementedError:
                    # Fall through to a copy below.
                    pass
        if isinstance(x, StorageBox):
            # TODO(jansel): impose layout preference on realized buffer
            x.realize()
            return x
        return cls.copy_input(x)

    @classmethod
    def require_stride1(cls, x):
        """Return `x` (or a copy) guaranteed to have some dimension with stride 1."""
        if is_storage_and_layout(x):
            if len(x.get_stride()) == 0:
                return x
            for stride in x.get_stride():
                if stride == 1:
                    return x
        return cls.copy_input(x)

    @classmethod
    def require_stride_order(cls, x, order):
        """Return `x` (or a copy) whose layout is fixed with the given stride order."""
        if x.get_numel() == 0:  # Layout doesn't matter
            return x

        # require x to have the layout as strided_ordered as order
        if is_storage_and_layout(x):
            if isinstance(x.get_layout(), FlexibleLayout):
                # fix flexiblelayout to be FixedLayout with stride_order
                as_storage_and_layout(
                    x, freeze=True, want_contiguous=False, stride_order=order
                )
                return x
            elif isinstance(
                x.get_layout(), FixedLayout
            ) and x.get_layout().is_stride_ordered(order):
                return x
            elif isinstance(x.get_layout(), MutationLayout):
                if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                    raise AssertionError(
                        "the MutationLayout's real layout shouldn't be FlexibleLayout"
                    )
                elif isinstance(
                    x.get_layout().real_layout(), FixedLayout
                ) and x.get_layout().real_layout().is_stride_ordered(order):
                    return x

        # TODO - Storage to InputBuffer
        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
            return x
        # Fall back: copy and freeze the copy with the requested order.
        x = cls.copy_input(x)
        as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
        assert is_stride_order_storage_and_layout(x, order)
        return x

    @classmethod
    def require_contiguous(cls, x):
        # Contiguous == stride order [n-1, ..., 1, 0].
        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

    def apply_constraint(self):
        pass

    def codegen_args(self):
        """Render positional args: input references followed by repr'd constants."""
        args = [x.codegen_reference() for x in self.inputs]
        args.extend(map(repr, self.constant_args))
        return args

    def codegen_kwargs(self):
        kwargs = []
        if self.kwargs:
            kwargs = [f"{k}={repr(v)}" for k, v in self.kwargs.items()]
        return kwargs

    def cpp_wrapper_codegen_kwargs(self):
        """Render kwargs positionally in the order the C++ kernel expects."""
        kwargs = []
        if self.kwargs:
            for arg_name in self.ordered_kwargs_for_cpp_kernel:
                assert arg_name in self.kwargs, (
                    "arg %s not found in self.kwargs" % arg_name
                )
                v = self.kwargs.get(arg_name)
                kwargs.append(repr(v))
        return kwargs

    def codegen_size_asserts(self, wrapper):
        """Optionally emit a runtime assert on the output's size/stride."""
        if config.size_asserts:
            size = V.graph.sizevars.codegen_shape_tuple(self.get_size())
            stride = V.graph.sizevars.codegen_shape_tuple(self.get_stride())
            wrapper.writeline(
                f"assert_size_stride({self.get_name()}, {size}, {stride})"
            )

    def get_group_stride(self):
        """
        get output sizes and strides, for template_codegen
        """
        _size = self.get_size()
        _stride = self.get_stride()
        # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
        return [_size, []], _stride

    def canonicalize(self):
        """
        Manually get canonicalization of the output index
        """
        # manually generate index formula for conv
        sizevars = V.graph.sizevars
        sizes = self.get_size()
        strides = self.get_stride()
        strides = [sizevars.size_hint(x) for x in strides]
        index_vars = [sympy_symbol(f"d{i}") for i in range(len(sizes))]
        # reorder index vars according to stride
        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        lookup = {pos: idx for idx, pos in enumerate(index_order)}
        order = [lookup[i] for i in range(len(lookup))]
        index_vars = [index_vars[i] for i in order]
        indexer = self.make_indexer()
        index = indexer(index_vars)

        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, [index]
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        _, add_var = var_builder("c")
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
        index = sympy_subs(sympy.expand(index), replacement)
        return index, tuple(new_sizes)

    def __str__(self):
        lines = [
            f"{field.name}={getattr(self, field.name)}"
            for field in dataclasses.fields(self)
        ]
        return self.str_helper(lines)
@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
    """Extern kernel following the `op(..., out=buffer)` calling convention."""

    output_view: Optional[ReinterpretView] = None

    def codegen(self, wrapper):
        args = self.codegen_args()

        from torch._inductor.codegen.wrapper import CppWrapperCodeGen

        # C++ wrapper needs kwargs rendered positionally; Python wrapper keeps k=v.
        if isinstance(wrapper, CppWrapperCodeGen):
            kwargs = self.cpp_wrapper_codegen_kwargs()
        else:
            kwargs = self.codegen_kwargs()
        if kwargs:
            args.extend(kwargs)

        wrapper.generate_extern_kernel_out(
            self.output_view,
            self.codegen_reference(),
            args,
            self.kernel,
            self.cpp_kernel,
        )

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        kernel=None,
        cpp_kernel=None,
    ):
        super().__init__(
            None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
        )
        self.output_view = output_view
        self.name = V.graph.register_buffer(self)
        if kernel is not None:
            self.kernel = kernel
        self.cpp_kernel = cpp_kernel

    def should_allocate(self):
        # The `out=` buffer must be allocated by the caller.
        return True
class ExternKernelAlloc(ExternKernel):
    """Extern kernel that allocates and returns its own output tensor."""

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def __init__(self, layout, inputs, constant_args=()):
        super().__init__(None, layout, self.unwrap_storage(inputs), constant_args)
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        # The kernel itself allocates the output; no buffer allocation needed.
        return False

    def apply_constraint(self):
        raise NotImplementedError
class InplaceBernoulliFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly
    """

    kernel = "aten.bernoulli_"

    def codegen(self, wrapper):
        (x,) = [t.codegen_reference() for t in self.inputs]
        wrapper.writeline(
            f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})"
        )

    def should_allocate(self):
        # Mutates `x` in place; no new output buffer.
        return False

    def get_mutation_names(self):
        assert isinstance(self.layout, MutationLayout)
        return (self.layout.target.get_name(),)

    def __init__(self, x, *constant_args):
        super().__init__(
            None,
            MutationLayout(x),
            self.unwrap_storage([x]),
            constant_args,
        )
        self.name = V.graph.register_buffer(self)
class IndexPutFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation and indices properly
    """

    kernel = "aten.index_put_"

    def codegen(self, wrapper):
        (x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs]
        # Re-expand the compacted index list: None entries were filtered out of
        # self.inputs (see __init__) and must be re-inserted positionally here.
        indices = []
        iter_valid_indices = iter(valid_indices)
        for i, _ in enumerate(self.indices):
            if self.indices[i] is not None:
                indices.append(next(iter_valid_indices))
            else:
                indices.append("None")
        wrapper.writeline(
            f"{self.kernel}({x}, [{','.join(indices)}], {values}, {repr(self.constant_args[0])})"
        )

    def should_allocate(self):
        return False

    def __init__(self, x, indices, values, accumulate):
        self.indices = indices
        valid_indices = [i for i in indices if i is not None]
        tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
        super().__init__(
            None,
            MutationLayout(x),
            self.unwrap_storage(tensors),
            [accumulate],
        )
        self.name = V.graph.register_buffer(self)
class DeviceCopy(ExternKernelOut):
    """Cross-device copy; constant-folds to a device move when all reads are constants."""

    @classmethod
    def create(cls, x, device):
        # If x only reads graph constants (deps that carry an `index`), move the
        # constants instead of emitting a runtime copy.
        if not x.is_extern() and all(
            (r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads()
        ):
            return x.constant_to_device(device)

        V.graph.device_types.add(device.type)
        V.graph.device_types.add(x.get_device().type)

        developer_warning("DeviceCopy in input program")
        return DeviceCopy(
            FlexibleLayout(
                device=device,
                dtype=x.get_dtype(),
                size=x.get_size(),
            ),
            [cls.realize_input(x)],
        )

    def codegen(self, wrapper):
        args = self.codegen_args()
        assert len(args) == 1
        if self.output_view:
            wrapper.writeline(
                f"{self.output_view.codegen_reference()}.copy_({args[0]})"
            )
        else:
            wrapper.writeline(f"{self.codegen_reference()}.copy_({args[0]})")
class DynamicScalar(IRNode):
    """
    The result of a call to aten._local_scalar_dense.

    This is not yet implemented.  The one model (so far) that calls this
    (fastNLP_Bert) does not actually use the result.  So we expect this
    node to get dead code eliminated.
    """

    def get_reads(self):
        # No dependencies, so DCE can drop this node.
        return ()
@dataclasses.dataclass
class FallbackKernel(ExternKernelAlloc):
    """Generic fallback that runs an arbitrary (unlowered) op via eager execution."""

    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
    ):
        super().__init__(
            layout,
            tuple(tensor_args),
            tuple(nontensor_args),
        )
        # Qualify the kernel name: plain aten ops get the short "aten.X" form.
        if getattr(torch.ops.aten, kernel.__name__, None) is kernel:
            self.kernel = f"aten.{kernel.__name__}"
        else:
            self.kernel = (
                f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
            )
        self.unflatten_args = unflatten_args
        self.kwargs = {} if kwargs is None else kwargs
        V.graph.warn_fallback(self.kernel)

    def codegen_args(self):
        @dataclasses.dataclass
        class Shim:
            # Wrapper whose repr() is the raw codegen text, so tree_unflatten
            # output can be rendered with repr() uniformly.
            ref: Any

            def __repr__(self):
                return self.ref

        def gen_kwarg(k, v):
            return f"{k}={repr(v)}"

        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
        constant_args = [Shim(repr(x)) for x in self.constant_args]

        args, kwargs = self.unflatten_args(tensor_args, constant_args)
        return list(map(repr, args)) + [gen_kwarg(k, v) for k, v in kwargs.items()]

    @classmethod
    def create(cls, kernel, *args, **kwargs):
        # Ops whose fake/meta implementations are known to be wrong; run those
        # outside fake mode.
        fake_incorrect_kernels = (
            aten._fft_r2c.default,
            aten._fft_r2c.out,
            aten._fft_c2r.default,
            aten._fft_c2c.default,
            aten._fft_c2c.out,
            aten._linalg_svd.default,
            aten._linalg_svd.U,
            aten._fused_moving_avg_obs_fq_helper_functional,
        )
        context = (
            V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
        )
        with context:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, *args, **kwargs)

        assert tensor_args or isinstance(
            example_output, torch.Tensor
        ), "Not sure where to find device info"
        packed = FallbackKernel(
            MultiOutputLayout(
                tensor_args[0].get_device() if tensor_args else example_output.device
            ),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
            kwargs,
        )

        def generate_output(output, index=""):
            # Recursively wrap each tensor in the (possibly nested) output with
            # a MultiOutput node indexed by the accumulated subscript string.
            if isinstance(output, (list, tuple)):
                return type(output)(
                    generate_output(output[i], f"{index}[{i}]")
                    for i in range(len(output))
                )
            elif isinstance(output, torch.Tensor):
                return MultiOutput(
                    FixedLayout(
                        output.device,
                        output.dtype,
                        convert_shape_to_inductor(output.size()),
                        convert_shape_to_inductor(output.stride()),
                    ),
                    packed,
                    index,
                )
            elif isinstance(output, int):
                return output
            else:
                assert output is None, "FallbackKernel output type is not supported"
                return None

        return generate_output(example_output)

    def apply_constraint(self):
        return super().apply_constraint()
@dataclasses.dataclass
class MultiOutputLayout(IRNode):
    """Placeholder layout for a kernel whose outputs are unpacked by MultiOutput nodes."""

    device: torch.device
class MultiOutput(ExternKernel):
    """Selects one output of a multi-output kernel via a subscript string like "[0]"."""

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
        )
        self.codegen_size_asserts(wrapper)

    def __init__(self, layout, input, index: str):
        super().__init__(None, layout, [input], ())
        self.name = V.graph.register_buffer(self)
        # Subscript string applied to the packed kernel's result (e.g. "[0][1]").
        self.index = index

    def should_allocate(self):
        # Just an indexing expression; no buffer allocation.
        return False
class Convolution(ExternKernelAlloc):
    """
    Extern convolution node. Dispatches either to aten.convolution or to the
    Triton conv implementation (triton_ops.conv), and picks an output layout
    (contiguous vs. channels-last) for the result.
    """

    kernel = "aten.convolution"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        preferred_stride_order=None,
        kernel="aten.convolution",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel
        self.preferred_stride_order = preferred_stride_order

    def codegen(self, wrapper):
        # Triton kernels live in a module the generated wrapper must import.
        if self.kernel.startswith("triton_ops."):
            wrapper.header.writeline("from torch._inductor import triton_ops")
        wrapper.writeline(
            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
        )
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        stride_: List[int],
        padding_: List[int],
        dilation_: List[int],
        transposed: bool,
        output_padding_: List[int],
        groups: int,
    ):
        """Build a Convolution node, choosing the kernel and output layout."""
        # Run the real aten op on fake tensors to learn the output's
        # size and stride without doing any actual compute.
        with V.graph.fake_mode:
            x_fake = ir_node_to_tensor(x, guard_shape=True)
            weight_fake = ir_node_to_tensor(weight, guard_shape=True)
            bias_fake = (
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
            )
            output = torch.ops.aten.convolution(
                x_fake,
                weight_fake,
                bias_fake,
                stride_,
                padding_,
                dilation_,
                transposed,
                output_padding_,
                groups,
            )
            req_stride_order = get_stride_order(output.stride())

        # Force inputs into the stride order inferred from the fake run.
        weight = cls.require_stride_order(weight, req_stride_order)
        x = cls.require_stride_order(x, req_stride_order)

        stride = tuple(stride_)
        padding = tuple(padding_)
        dilation = tuple(dilation_)
        assert isinstance(transposed, bool)
        output_padding = tuple(output_padding_)
        assert isinstance(groups, int)
        output_size = output.shape

        # Weight dims are guarded static; leading two are (out_c, in_c/groups).
        weight_shape = [
            sympy.Integer(V.graph.sizevars.guard_static_shape(s))
            for s in weight.get_size()
        ]
        _, _, *kernel_size = weight_shape

        # choose runtime kernel
        # NOTE(review): config_conv is hard-coded to "aten", making the
        # "triton"/"autotune" branches below unreachable — confirm intent.
        config_conv = "aten"
        if (
            config_conv == "aten"
            or len(kernel_size) != 2  # triton conv only supports conv2d
            or not is_triton(x.get_device())
            or transposed
            or groups != 1
            # or x.get_dtype() == torch.float16
            # or x.get_dtype() == torch.bfloat16
        ):
            kernel = "aten.convolution"
        elif config_conv == "triton":
            kernel = "triton_ops.conv"
        else:
            assert config_conv == "autotune"
            from .codegen.autotuner import tuned_conv

            kernel = tuned_conv(
                x.get_size(),
                weight.get_size(),
                x.get_stride(),
                weight.get_stride(),
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
                x.get_device(),
                x.get_dtype(),
            )

        # for conv2d or conv3d, prefer channels last format
        transform_x_layout = False
        if kernel == "triton_ops.conv":
            output_layout_str = "torch.channels_last"
        else:
            output_layout_str = (
                "torch.contiguous_format"
                if output.is_contiguous()
                else "torch.channels_last"
            )

        # Optionally let the autotuner override the layout choice (4-D only).
        if config.tune_layout and len(x.get_size()) == 4:
            from .codegen.autotuner import tuned_conv_layout

            faster_output_layout_str = tuned_conv_layout(
                kernel,
                x.get_size(),
                weight.get_size(),
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
                x.get_device(),
                x.get_dtype(),
            )
            if faster_output_layout_str != output_layout_str:
                output_layout_str = faster_output_layout_str
                transform_x_layout = True

        if output_layout_str == "torch.channels_last":
            # stride order [0, d, d-1, ..., 1]: channels innermost.
            stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
            if len(stride_order) < len(output_size):
                # add batch dim if it exists
                stride_order = [len(stride_order)] + stride_order
            strides = make_channels_last_strides_for(output_size)
        else:
            stride_order = list(reversed(range(len(output_size))))
            strides = make_contiguous_strides_for(output_size)

        if transform_x_layout:
            x = cls.require_stride_order(x, stride_order)

        output_layout = FixedLayout(
            x.get_device(),
            x.get_dtype(),
            convert_shape_to_inductor(output_size),
            convert_shape_to_inductor(strides),
        )

        # A present bias rides in `inputs`; a None bias leads constant_args.
        if bias is not None:
            return Convolution(
                output_layout,
                (x, weight, bias),
                (stride, padding, dilation, transposed, output_padding, groups),
                stride_order,
                kernel,
            )
        else:
            return Convolution(
                output_layout,
                (x, weight),
                (bias, stride, padding, dilation, transposed, output_padding, groups),
                stride_order,
                kernel,
            )

    def map_args(self):
        """Build the (inout, args, constexpr, other) dicts for the Triton conv
        template; assumes 4-D NCHW-shaped x/w/y buffers."""
        # x, w, bias
        in_args = [x.codegen_reference() for x in self.inputs]
        # stride, padding, dilation, transposed, output_padding, groups
        const_args = self.constant_args
        if len(in_args) < 3:
            # otherwise, bias=None is the first constant_args
            const_args = const_args[1:]

        inout_dict = OrderedDict(
            [
                ("x", f"{in_args[0]}"),
                ("w", f"{in_args[1]}"),
                ("y", f"{self.get_name()}"),
            ]
        )
        args_dict = OrderedDict(
            [
                ("stride_xn", f"{self.inputs[0].get_stride()[0]}"),
                ("stride_xc", f"{self.inputs[0].get_stride()[1]}"),
                ("stride_xh", f"{self.inputs[0].get_stride()[2]}"),
                ("stride_xw", f"{self.inputs[0].get_stride()[3]}"),
                ("stride_wn", f"{self.inputs[1].get_stride()[0]}"),
                ("stride_wc", f"{self.inputs[1].get_stride()[1]}"),
                ("stride_wh", f"{self.inputs[1].get_stride()[2]}"),
                ("stride_ww", f"{self.inputs[1].get_stride()[3]}"),
                ("stride_yn", f"{self.get_stride()[0]}"),
                ("stride_yc", f"{self.get_stride()[1]}"),
                ("stride_yh", f"{self.get_stride()[2]}"),
                ("stride_yw", f"{self.get_stride()[3]}"),
                (
                    "stride_biasn",
                    # NOTE(review): reads the stride of inputs[0] (x), not the
                    # bias input (inputs[2]) — looks suspicious; confirm.
                    f"{self.inputs[0].get_stride()[0]}"
                    if len(in_args) >= 3
                    else "None",
                ),
                # ("delta_x_ptr", "None"),
                ("BATCH", f"{self.inputs[0].get_size()[0]}"),
                ("IN_C", f"{self.inputs[0].get_size()[1]}"),
                ("IN_H", f"{self.inputs[0].get_size()[2]}"),
                ("IN_W", f"{self.inputs[0].get_size()[3]}"),
                ("KERNEL_N", f"{self.inputs[1].get_size()[0]}"),
                ("KERNEL_H", f"{self.inputs[1].get_size()[2]}"),
                ("KERNEL_W", f"{self.inputs[1].get_size()[3]}"),
                ("OUT_H", f"{self.get_size()[2]}"),
                ("OUT_W", f"{self.get_size()[3]}"),
                ("stride_h", f"{const_args[0][0]}"),
                ("stride_w", f"{const_args[0][1]}"),
                ("padding_h", f"{const_args[1][0]}"),
                ("padding_w", f"{const_args[1][1]}"),
                ("dilation_h", f"{const_args[2][0]}"),
                ("dilation_w", f"{const_args[2][1]}"),
                # ("transposed", f"{const_args[3]}"),
                ("output_padding_h", f"{const_args[4][0]}"),
                ("output_padding_w", f"{const_args[4][1]}"),
                ("groups", f"{const_args[5]}"),
            ]
        )
        # accumulator type
        ACC_TYPE = (
            "tl.float32"
            if self.inputs[0].get_dtype()
            in [torch.float16, torch.bfloat16, torch.float32]
            else "tl.int32"
        )
        # 1x1 kernel on an NHWC input enables the fast matmul-like path.
        CONV1X1_NHWC = (
            "True"
            if self.inputs[0].get_stride()[1] == 1
            and self.inputs[1].get_size()[2] == 1
            and self.inputs[1].get_size()[3] == 1
            else "False"
        )
        # dict for tl.constexpr
        const_dict = OrderedDict(
            [
                ("ACC_TYPE", ACC_TYPE),
                ("CONV1X1_NHWC", CONV1X1_NHWC),
            ]
        )
        # dict for non-kernel args (e.g. delta_x_ptr)
        other_dict = OrderedDict(
            [
                ("device", f'"{self.inputs[0].get_device()}"'),
            ]
        )
        return inout_dict, args_dict, const_dict, other_dict

    def get_template_tiling(self):
        # Template tiles the output as (N*H*W, C, 1).
        n, c, h, w = self.get_size()
        return (
            n * h * w,
            c,
            sympy.Integer(1),
        )
def _prepare_convolution_fusion_create(
    cls,
    x: "TensorBox",
    weight: "TensorBox",
    bias: "TensorBox",
    padding_: List[int],
    stride_: List[int],
    dilation_: List[int],
    groups: int,
    transposed: bool = False,
    output_padding_: Optional[List[int]] = None,
):
    """
    This function is a helper function to prepare inputs, layout and constant args
    for convolution post-op fusion's create function, including deciding the output
    layout (channels first or channels last), realizing inputs and make them etc. The
    function only supports the CPU device since conv post-op fusion kernel is only
    supported on CPU right now.

    Returns (inputs, constant_args, kernel_layout, req_stride_order).
    """

    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
    def _conv_input_size(
        output_size, weight_size, padding, output_padding, stride, dilation, groups
    ):
        # Inverts the conv output-size formula per spatial dim: used to get
        # the deconv (transposed conv) output size from its input size.
        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
        dim = len(output_size)
        assert dim > 2, "Expect input dim > 2"
        BATCH_DIM = 0
        WEIGHT_INPUT_CHANNELS_DIM = 1
        input_size = []
        input_size.append(output_size[BATCH_DIM])
        input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
        for d in range(2, dim):
            # effective (dilated) kernel extent
            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
            input_size_d = (
                (output_size[d] - 1) * stride[d - 2]
                - (padding[d - 2] * 2)
                + kernel
                + output_padding[d - 2]
            )
            input_size.append(input_size_d)
        return list(map(int, input_size))

    # The size of prepacked_weight is the prepacked weight size of deconv:
    #   Groups > 1:  [g*o, i/g, ...]
    #   Groups == 1: [o, i, ...]
    # Returns original weight size in [i, o, ...]
    def _original_deconv_weight_size(
        prepacked_weight,
        groups,
    ):
        prepacked_weight_size = prepacked_weight.size()
        dim = len(prepacked_weight_size)
        assert dim > 2, "Expect weight dim > 2"
        if groups > 1:
            weight_size = []
            weight_size.append(prepacked_weight_size[1] * groups)
            # NOTE(review): true division yields a float; presumably always
            # exactly divisible — confirm, or `//` would be tighter.
            weight_size.append(prepacked_weight_size[0] / groups)
            for d in range(2, dim):
                weight_size.append(prepacked_weight_size[d])
        else:
            weight_size = prepacked_weight.transpose(0, 1).size()
        return weight_size

    stride = tuple(stride_)
    padding = tuple(padding_)
    dilation = tuple(dilation_)
    assert isinstance(groups, int)
    # NOTE(review): the (0, 0) default assumes a 2-D convolution — confirm
    # transposed 3-D never reaches here without an explicit output_padding_.
    output_padding = tuple(output_padding_) if output_padding_ else (0, 0)
    # Use fake tensors to infer the output size without real compute.
    with V.graph.fake_mode:
        x_fake = ir_node_to_tensor(x, guard_shape=True)
        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
        if transposed:
            # When transposed, the size of the prepacked oneDNN weight is different
            # from the PyTorch weight. We're not able to run aten conv with such
            # size. We infer the output size from the input params here:
            weight_size = _original_deconv_weight_size(weight_fake, groups)
            input_size = x_fake.size()
            output_size = _conv_input_size(
                input_size,
                weight_size,
                padding,
                output_padding,
                stride,
                dilation,
                groups,
            )
        else:
            bias_fake = (
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
            )
            output = torch.ops.aten.convolution(
                x_fake,
                weight_fake,
                bias_fake,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            output_size = output.size()

        # Channels-last order: batch dim first, then [0, d, d-1, ..., 1].
        req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
        req_stride_order = [len(req_stride_order)] + req_stride_order
        output_stride = make_channels_last_strides_for(output_size)

    x = cls.require_stride_order(x, req_stride_order)
    # conv post-op fusion kernels are only implemented for CPU
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
    inputs = [x, weight]

    kernel_layout = FixedLayout(
        x.get_device(),
        x.get_dtype(),
        convert_shape_to_inductor(output_size),
        convert_shape_to_inductor(output_stride),
    )
    constant_args = [padding, stride, dilation, groups]
    if transposed:
        constant_args.insert(1, output_padding)

    # A present bias rides in `inputs`; a None bias leads constant_args.
    if bias is not None:
        inputs.append(bias)
    else:
        constant_args.insert(0, bias)
    return inputs, constant_args, kernel_layout, req_stride_order
class ConvolutionUnary(ExternKernelAlloc):
    """mkldnn convolution fused with a unary pointwise post-op (CPU only)."""

    kernel = "torch.ops.mkldnn._convolution_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        attr,
        scalars,
        algorithm,
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise"
        inputs, constant_args, kernel_layout, _ = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # Post-op description is appended after the conv constant args.
        return ConvolutionUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=[*constant_args, attr, scalars, algorithm],
            kernel=kernel,
        )
class ConvolutionBinary(ExternKernelAlloc):
    """mkldnn convolution fused with a binary (+ optional unary) post-op."""

    kernel = "torch.ops.mkldnn._convolution_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise.binary",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List],
        unary_algorithm: Optional[str],
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
        (
            inputs,
            constant_args,
            kernel_layout,
            req_stride_order,
        ) = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # `other` must match the conv output's stride order; it sits right
        # after x in the inputs list.
        other = cls.require_stride_order(other, req_stride_order)
        inputs.insert(1, other)
        return ConvolutionBinary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=[
                *constant_args,
                binary_attr,
                binary_alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ],
            kernel=kernel,
        )
class ConvolutionBinaryInplace(ExternKernelAlloc):
    """In-place variant of ConvolutionBinary: the result is written into `other`."""

    kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"

    def __init__(
        self,
        kernel_layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
    ):
        super().__init__(kernel_layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")

    def get_mutation_names(self):
        # The mutated buffer is the MutationLayout's target (i.e. `other`).
        assert isinstance(self.layout, MutationLayout)
        return (self.layout.target.get_name(),)

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        other: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups: int,
        binary_attr: str,
        binary_alpha: Optional[float],
        unary_attr: Optional[str],
        unary_scalars: Optional[List],
        unary_algorithm: Optional[str],
    ):
        kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
        inputs, constant_args, _, _ = _prepare_convolution_fusion_create(
            cls, x, weight, bias, padding_, stride_, dilation_, groups
        )
        # `other` is mutated in place: realize it and force its current users
        # to materialize before the write happens.
        other = cls.realize_input(other)
        V.graph.realize_users_of(other.get_name())
        inputs.insert(1, other)
        return ConvolutionBinaryInplace(
            kernel_layout=MutationLayout(inputs[1]),
            inputs=inputs,
            constant_args=[
                *constant_args,
                binary_attr,
                binary_alpha,
                unary_attr,
                unary_scalars,
                unary_algorithm,
            ],
            kernel=kernel,
        )
class MKLPackedLinear(ExternKernelAlloc):
    """Linear layer using an MKL pre-packed weight."""

    kernel = "torch.ops.mkl._mkl_linear"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkl._mkl_linear",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")

    @classmethod
    def create(cls, x, packed_w, orig_w, batch_size):
        kernel = "torch.ops.mkl._mkl_linear"
        x = cls.require_stride1(cls.realize_input(x))
        orig_w = cls.require_stride1(cls.realize_input(orig_w))
        # Output keeps x's leading dims, with the last replaced by out-channels.
        *batch_dims, _ = x.get_size()
        out_channels, _ = orig_w.get_size()
        output_size = [*batch_dims, out_channels]
        output_stride = make_contiguous_strides_for(output_size)
        bias = None
        return MKLPackedLinear(
            layout=FixedLayout(
                x.get_device(), x.get_dtype(), output_size, output_stride
            ),
            inputs=[x, packed_w, orig_w],
            constant_args=[bias, batch_size],
            kernel=kernel,
        )
class LinearUnary(ExternKernelAlloc):
    """mkldnn linear fused with a unary pointwise post-op."""

    kernel = "torch.ops.mkldnn._linear_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._linear_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")

    @classmethod
    def create(cls, x, w, b, attr, scalars, algorithm):
        kernel = "torch.ops.mkldnn._linear_pointwise"
        x = cls.require_stride1(cls.realize_input(x))
        w = cls.require_stride1(cls.realize_input(w))
        *m, ic = x.get_size()
        oc, ic = w.get_size()
        inputs = [x, w]
        constant_args = [attr, scalars, algorithm]
        # Bias rides in `inputs` when present; a None bias leads constant_args.
        if b is None:
            constant_args.insert(0, b)
        else:
            inputs.append(cls.require_stride1(cls.realize_input(b)))
        return LinearUnary(
            layout=FlexibleLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=[*m, oc],
            ),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

    def apply_constraint(self):
        # Output layout stays flexible; nothing to constrain.
        pass
class LinearBinary(ExternKernelAlloc):
    """mkldnn linear fused with a binary pointwise post-op (`y` is operand 2)."""

    kernel = "torch.ops.mkldnn._linear_pointwise.binary"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._linear_pointwise.binary",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")

    @classmethod
    def create(cls, x, y, w, b, attr):
        kernel = "torch.ops.mkldnn._linear_pointwise.binary"
        x = cls.require_stride1(cls.realize_input(x))
        y = cls.require_stride1(cls.realize_input(y))
        w = cls.require_stride1(cls.realize_input(w))
        *m, ic = x.get_size()
        oc, ic = w.get_size()
        inputs = [x, y, w]
        constant_args = [attr]
        # Bias rides in `inputs` when present; a None bias leads constant_args.
        if b is None:
            constant_args.insert(0, b)
        else:
            inputs.append(cls.require_stride1(cls.realize_input(b)))
        return LinearBinary(
            layout=FlexibleLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=[*m, oc],
            ),
            inputs=inputs,
            constant_args=constant_args,
            kernel=kernel,
        )

    def apply_constraint(self):
        # Output layout stays flexible; nothing to constrain.
        pass
class ConvolutionTransposeUnary(ExternKernelAlloc):
    """mkldnn transposed convolution fused with a unary pointwise post-op."""

    kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
    ):
        super().__init__(layout, inputs, constant_args)
        self.kernel = kernel

    def codegen(self, wrapper):
        rendered = ", ".join(self.codegen_args())
        wrapper.writeline(f"{self.get_name()} = {self.kernel}({rendered})")

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        weight: "TensorBox",
        bias: "TensorBox",
        padding_: List[int],
        output_padding_: List[int],
        stride_: List[int],
        dilation_: List[int],
        groups_: int,
        attr,
        scalars,
        algorithm,
    ):
        kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
        transposed = True  # switches the shared helper into deconv mode
        inputs, constant_args, kernel_layout, _ = _prepare_convolution_fusion_create(
            cls,
            x,
            weight,
            bias,
            padding_,
            stride_,
            dilation_,
            groups_,
            transposed,
            output_padding_,
        )
        return ConvolutionTransposeUnary(
            layout=kernel_layout,
            inputs=inputs,
            constant_args=[*constant_args, attr, scalars, algorithm],
            kernel=kernel,
        )
@dataclasses.dataclass
class MutableBox(IRNode):
    """
    TensorBox / StorageBox allow in-place mutation of Tensors
    """

    data: IRNode

    def __getattr__(self, name):
        # Forward callable attributes to the wrapped node; everything else
        # is an error so state access never silently bypasses the box.
        attr = getattr(self.data, name)
        if not callable(attr):
            raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
        return attr

    def __str__(self):
        inner = self.data
        if isinstance(inner, MutableBox):
            # Collapse one level of nesting into a single header line.
            head = f"{type(self).__name__}({type(inner).__name__}("
            tail = "))"
            inner = inner.data
        else:
            head = f"{type(self).__name__}("
            tail = ")"
        return "\n".join([head, indent(str(inner)), tail])

    __repr__ = __str__
class TensorBox(MutableBox):
    @staticmethod
    def create(data):
        # Standard wrapping: tensor-level box around a storage-level box.
        return TensorBox(StorageBox(data))
class StorageBox(MutableBox):
    """MutableBox wrapping the storage-level IR node of a tensor value."""

    def is_input_buffer(self):
        """Return True if this box wraps a graph input (possibly reinterpreted)."""
        if isinstance(self.data, (InputBuffer, ReinterpretView)):
            return self.data.get_name() in V.graph.graph_inputs
        return False

    def realize(self):
        """Force the wrapped value to materialize as a named buffer.

        Returns the buffer name. No-op (other than the name lookup) when the
        value is already buffer-backed.
        """
        if isinstance(
            self.data,
            (
                ComputedBuffer,
                InputsKernel,
                InputBuffer,
                ReinterpretView,
                TemplateBuffer,
            ),
        ):
            return self.data.get_name()
        assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
        self.data = ComputedBuffer(
            name=None,
            layout=FlexibleLayout(
                device=self.data.get_device(),
                dtype=self.data.get_dtype(),
                size=self.data.get_size(),
            ),
            data=self.data,
        )
        self.data.name = V.graph.register_buffer(self.data)
        self.data.origins = self.origins
        return self.data.name

    def realize_hint(self):
        """
        Called on buffers we expect to be forced to realize later.
        """
        if isinstance(self.data, (Pointwise, Reduction)) and self.num_reads() > 1:
            self.realize()

    def has_exceeded_max_reads(self):
        # Only Pointwise nodes are subject to the accumulated-reads limits.
        return isinstance(self.data, Pointwise) and (
            self.num_reads() > config.realize_acc_reads_threshold
            or len(self.inner_fn_str()) > config.realize_bytes_threshold
        )

    def mark_reuse(self, users):
        """
        A heuristic to decide if we should realize a tensor
        that is used multiple times.
        """

        def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
            """
            The heuristic for realizing reused result of heavy ops on cpu
            """
            heavy_ops = ["exp"]  # a list of heavy ops
            fn_str = loops.inner_fn_str()
            # Generator (not a temporary list) lets any() short-circuit.
            return any((op + "(") in fn_str for op in heavy_ops)

        if (
            users > 1
            and isinstance(self.data, (Pointwise, Reduction))
            and (
                self.num_reads() > config.realize_reads_threshold
                or len(self.inner_fn_str()) > config.realize_bytes_threshold
                or (is_cpu(self.data) and should_realize_on_cpu(self.data))
            )
        ):
            self.realize()

    @cache_on_self
    def num_reads(self):
        """Number of reads performed by the wrapped node (cached per instance)."""
        data = self.data
        if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
            return 1
        if isinstance(data, ComputedBuffer):
            read_writes = data.get_read_writes()
        else:
            assert isinstance(data, (Pointwise, Reduction)), type(data)
            # Wrap in a throwaway ComputedBuffer solely to extract read info.
            read_writes = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=data.get_device(),
                    dtype=data.get_dtype(),
                    size=data.get_size(),
                ),
                data=data,
            ).get_read_writes()
        return len(read_writes.reads)
class InterpreterShim(torch.fx.Interpreter):
    """A lightweight fx.Interpreter that skips GraphModule construction."""

    def __init__(self, graph, submodules):
        # Deliberately NOT calling super().__init__(): that would build a
        # GraphModule, which is very expensive (it does codegen). Instead we
        # hand-populate the attributes fx.Interpreter relies on.
        self.module = self
        self.graph = graph
        self.submodules = submodules
        self.fetch_attr = submodules.__getitem__
        self.name = "InterpreterShim"
        self.garbage_collect_values = False
        self.env = {}
        self.current_node = None

    def run_node(self, n: torch.fx.Node) -> Any:
        # Remember which node is executing so handlers can inspect it.
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        # Make this interpreter visible to ops handlers for the duration.
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)
class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph. Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    def __init__(self, fn, args, var_ranges):
        super().__init__()
        self.var_ranges = var_ranges
        # name -> sympy expr ("index0", ...) and the reverse mapping
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        # flat lists of index expressions by category
        self.reads = []
        self.writes = []
        self.reads_name2expr = {}
        self.writes_name2expr = {}
        self.other = []
        # submodules become call_module targets in the traced FX graph
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        # tracing fn populates all of the structures above
        self.root_block = LoopBodyBlock(self, fn, args)
        # substituted indexing for the current call; set only inside __call__
        self.indexing = None

    def debug_str(self):
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
        # category is "reads", "writes", or "other"; identical expressions
        # are deduplicated and share one generated name.
        getattr(self, category).append(expr)
        if buf_name is not None:
            getattr(self, f"{category}_name2expr")[buf_name] = expr
        if expr not in self.indexing_exprs_name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        return self.indexing_exprs_name[expr]

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self):
        # Fresh placeholder symbol for indirect (tensor-valued) indexing.
        name = f"indirect{len(self.indirect_vars)}"
        var = sympy_symbol(name)
        self.indirect_vars.append(var)
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        return self.indexing[name]

    def __call__(self, *indices):
        index = list(itertools.chain(*indices))
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        assert all(v not in self.var_ranges for v in index)
        # Substitute caller-supplied index variables into the stored exprs,
        # run the traced body, then clear the per-call state.
        replacements = dict(zip(self.var_ranges.keys(), index))
        self.indexing = {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }
        result = self.root_block()
        self.indexing = None
        return result
class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, hower in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable, args: List[Any]):
        self.body = body

        def add_index(expr, category, buf_name=None):
            # Record the expression on the LoopBody, then emit a call_module
            # node that looks the expression up by name at interpretation time.
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (self.body.add_index_expr(expr, category, buf_name),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):
            # NOTE(review): this statement runs in the class body, where `self`
            # resolves to the enclosing LoopBodyBlock instance (class bodies can
            # read enclosing-function locals) — so it sets an attribute on the
            # LoopBodyBlock, not on CaptureIndexing. Confirm this is intended.
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, "reads", name)
                return self._inner.load(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(index, "writes", name)
                return self._inner.store(name, index, value, mode)

            def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
                # Reductions write their result, so the index counts as a write.
                index = add_index(index, "writes", name)
                return self._inner.reduction(
                    name, dtype, src_dtype, reduction_type, index, value
                )

            def index_expr(self, index, dtype):
                # Constant indices fold directly into constants.
                if isinstance(index, (int, sympy.Integer)):
                    return ops.constant(int(index), dtype)
                index = add_index(index, "other")
                return self._inner.index_expr(index, dtype)

            @staticmethod
            def masked(mask_proxy, masked_body: Callable, other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """

                def shim(mask, other):
                    # `subblock` is bound below before shim can ever run.
                    return V.ops.masked(mask, subblock, other)

                name = self.body.add_submodule(shim, "masked_subblock")
                subblock = LoopBodyBlock(self.body, masked_body, [])
                self.body.subblocks[name] = subblock
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def indirect_indexing(index_proxy):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                def set_indirect(new_var):
                    self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))

                var = self.body.add_indirect()
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
        from .sizevars import SimplifyIndexing

        # Trace fn with the indexing-capturing handler (plus index
        # simplification) installed as the active ops handler.
        with V.set_ops_handler(
            SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
        ):
            tracer.create_proxy("output", "output", (fn(*args),), {})
        self.graph = tracer.graph

    def __call__(self):
        graph = self.graph
        submodules = self.body.submodules
        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )
|