1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562356335643565356635673568356935703571357235733574357535763577357835793580358135823583358435853586358735883589359035913592359335943595359635973598359936003601360236033604360536063607360836093610361136123613361436153616361736183619362036213622362336243625362636273628362936303631363236333634363536363637363836393640364136423643364436453646364736483649365036513652365336543655365636573658365936603661366236633664366536663667366836693670367136723673367436753676367736783679368036813682368336843685368636873688368936903691369236933694369536963697369836993700370137023703370437053706370737083709371037113712371337143715371637173718371937203721372237233724372537263727372837293730373137323733373437353736373737383739374037413742374337443745374637473748374937503751375237533754375537563757375837593760376137623763376437653766376737683769377037713772377337743775377637773778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868 |
- import functools
- import itertools
- import logging
- from collections.abc import Iterable
- from typing import List, Optional, Tuple
- import sympy
- import torch
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._prims_common import (
- canonicalize_dims,
- dtype_to_type,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- is_boolean_dtype,
- is_float_dtype,
- is_integer_dtype,
- Number,
- )
- from torch.fx.experimental.symbolic_shapes import magic_methods, method_to_operator
- from .._dynamo.utils import import_submodule
- from . import config, ir, overrides, test_operators # NOQA: F401
- from .cuda_properties import current_device
- from .decomposition import decompositions, get_decompositions
- from .ir import (
- ExpandView,
- IndexingConstant,
- PermuteView,
- Pointwise,
- Reduction,
- SqueezeView,
- TensorBox,
- validate_ir,
- View,
- )
- from .utils import ceildiv, developer_warning, sympy_product
- from .virtualized import ops, V
- log = logging.getLogger(__name__)
- lowerings = {}
- layout_constraints = {}
- fallbacks = set()
- aten = torch.ops.aten
- prims = torch.ops.prims
- needs_realized_inputs = set()
- def add_needs_realized_inputs(fn):
- if isinstance(fn, (list, tuple, set)):
- return [add_needs_realized_inputs(x) for x in fn]
- needs_realized_inputs.add(fn)
- if isinstance(fn, torch._ops.OpOverloadPacket):
- for overload in fn.overloads():
- needs_realized_inputs.add(getattr(fn, overload))
- def add_layout_constraint(fn, constraint):
- if isinstance(fn, torch._ops.OpOverloadPacket):
- for overload in fn.overloads():
- layout_constraints[getattr(fn, overload)] = constraint
- else:
- layout_constraints[fn] = constraint
- add_needs_realized_inputs(
- [
- aten.as_strided,
- aten.avg_pool2d,
- aten.avg_pool2d_backward,
- aten.bmm,
- aten.convolution,
- aten.convolution_backward,
- aten.max_pool2d_with_indices,
- aten.max_pool2d_with_indices_backward,
- aten.mm,
- aten.upsample_bilinear2d,
- aten.upsample_nearest2d,
- aten.upsample_bicubic2d,
- ]
- )
- # TODO(jansel): ezyang says we won't need this in the future, try removing it
- # based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
- DTYPE_ID_LOOKUP = {
- 0: torch.uint8,
- 1: torch.int8,
- 2: torch.int16,
- 3: torch.int32,
- 4: torch.int64,
- 5: torch.float16,
- 6: torch.float32,
- 7: torch.float64,
- 8: torch.complex32,
- 9: torch.complex64,
- 10: torch.complex32,
- 11: torch.bool,
- 15: torch.bfloat16,
- # TODO(jansel): add quantized types?
- # _(c10::qint8, QInt8) /* 12 */
- # _(c10::quint8, QUInt8) /* 13 */
- # _(c10::qint32, QInt32) /* 14 */
- # _(c10::quint4x2, QUInt4x2) /* 16 */
- # _(c10::quint2x4, QUInt2x4) /* 17 */
- }
- def decode_dtype(dtype: int):
- if not isinstance(dtype, int):
- return dtype
- assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
- dtype = DTYPE_ID_LOOKUP[dtype]
- return dtype
- def is_integer_type(x):
- if isinstance(x, TensorBox):
- return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
- else:
- return isinstance(x, int)
- def is_boolean_type(x):
- if isinstance(x, TensorBox):
- return is_boolean_dtype(x.get_dtype())
- else:
- return isinstance(x, bool)
- def decode_device(device):
- if device is None:
- return torch.tensor(0.0).device # default device
- if isinstance(device, str):
- device = torch.device(device)
- if device.type == "cuda" and device.index is None:
- return torch.device("cuda", index=current_device())
- return device
- def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
- def construct_input(inp):
- if isinstance(inp, Number):
- return inp
- else:
- assert hasattr(inp, "get_dtype")
- dim = len(inp.get_size())
- # construct a tmp tensor to feed into torch.result_type
- return torch.zeros([1] * dim, dtype=inp.get_dtype())
- inps = [construct_input(arg) for arg in args]
- _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
- return dtype
- def _register_lowering(
- aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
- ):
- """
- Add a lowering to lowerings dict
- Arguments:
- aten_fn: torch.ops.aten.* fn we are lowering
- decomp_fn: alternate implementation on our IR
- broadcast: True to apply broadcasting to tensor inputs
- type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
- convert_input_to_bool: some logical ops require inputs are converted to bool
- """
- @functools.wraps(decomp_fn)
- def wrapped(*args, **kwargs):
- args = list(args)
- unpacked = False
- # TODO maybe we need to use pytrees here
- if len(args) == 1 and isinstance(args[0], (list, tuple)):
- unpacked = True
- args = args[0]
- # Only look at args that are Tensors
- indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
- # explicitly assert for "out=" ops for better error messages
- assert not any(
- x == "out" for x in kwargs.keys()
- ), "out= ops aren't yet supported"
- # kwargs tensors not supported yet unless it's a fallback op
- assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
- fn in fallbacks for fn in aten_fn
- )
- if (type_promotion_kind or convert_input_to_bool) and indices:
- if convert_input_to_bool:
- dtype = torch.bool
- else:
- # FIXME that's a crude approximation for promoting args
- promoting_args = [
- a for a in args if isinstance(a, Number) or hasattr(a, "get_dtype")
- ]
- dtype = get_promoted_dtype(
- *promoting_args, type_promotion_kind=type_promotion_kind
- )
- # sometimes args are an immutable list so we can't mutate them
- new_args = []
- for i in range(len(args)):
- if i in indices:
- new_args.append(to_dtype(args[i], dtype))
- elif isinstance(args[i], ir.Constant):
- new_args.append(
- ir.Constant(args[i].value, dtype, args[indices[0]].get_device())
- )
- else:
- new_args.append(args[i])
- args = new_args
- if unpacked:
- args = [args]
- if broadcast and indices:
- for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
- args[i] = x
- for i in range(len(args)):
- if isinstance(args[i], ir.Constant):
- args[i] = ExpandView.create(
- args[i], list(args[indices[0]].get_size())
- )
- out = decomp_fn(*args, **kwargs)
- validate_ir(out)
- return out
- if not isinstance(aten_fn, (list, tuple)):
- aten_fn = [aten_fn]
- else:
- aten_fn = list(aten_fn)
- for fn in list(aten_fn):
- if isinstance(fn, torch._ops.OpOverloadPacket):
- for overload in fn.overloads():
- other_fn = getattr(fn, overload)
- if other_fn not in lowerings:
- aten_fn.append(other_fn)
- lowerings.update({fn: wrapped for fn in aten_fn})
- return wrapped
- def register_lowering(
- aten_fn,
- broadcast=False,
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- convert_input_to_bool=False,
- ):
- """
- Shim to support decorator syntax.
- """
- return functools.partial(
- _register_lowering,
- aten_fn,
- broadcast=broadcast,
- type_promotion_kind=type_promotion_kind,
- convert_input_to_bool=convert_input_to_bool,
- )
- def broadcast_symbolic_shapes(a, b):
- """
- Broadcasting logic based on symbolic shapes.
- We give the shapes 0 and 1 concrete values, while all other shapes
- are symbolic sympy formulas.
- """
- output = []
- for a, b in itertools.zip_longest(
- reversed(a), reversed(b), fillvalue=sympy.Integer(1)
- ):
- if b == 1:
- output.append(a)
- elif a == 1:
- output.append(b)
- else:
- V.graph.sizevars.guard_equals(a, b)
- if len(sympy.expand(b).free_symbols) < len(sympy.expand(a).free_symbols):
- output.append(b) # prefer shorter formula
- else:
- output.append(a)
- return tuple(reversed(output))
- def promote_constants(inputs, override_return_dtype=None):
- if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs):
- return inputs
- if all(isinstance(x, (int, float)) for x in inputs):
- dtype = override_return_dtype or get_promoted_dtype(
- *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- return [ir.Constant(x, dtype, decode_device(None)) for x in inputs]
- ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView)))
- out = []
- for x in inputs:
- if isinstance(x, (int, float)):
- out.append(
- ExpandView.create(
- ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
- )
- )
- elif isinstance(x, sympy.Expr):
- out.append(IndexingConstant(x, ex.get_dtype(), ex.get_device()))
- else:
- out.append(x)
- return out
- def make_pointwise(
- fn,
- override_return_dtype=None,
- override_device=None,
- override_fn_when_input_bool=None,
- override_fn_when_cuda_float64=None,
- allow_alpha=False,
- ):
- def inner(*inputs: List[TensorBox], alpha=None):
- inputs = promote_constants(inputs, override_return_dtype)
- if allow_alpha:
- if alpha is not None and alpha != 1:
- inputs = list(inputs)
- inputs[-1] = mul(inputs[-1], alpha)
- else:
- assert alpha is None
- loaders = [x.make_loader() for x in inputs]
- ranges = inputs[0].get_size()
- dtype = override_return_dtype or inputs[0].get_dtype()
- is_cuda = decode_device(inputs[0].get_device()).type == "cuda"
- for other in inputs[1:]:
- assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
- other.get_size()
- ), f"ndim mismatch {fn} {ranges} {other.get_size()}"
- def inner_fn(index):
- assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
- if dtype == torch.bool and override_fn_when_input_bool is not None:
- return override_fn_when_input_bool(*[load(index) for load in loaders])
- elif override_fn_when_cuda_float64 and is_cuda and dtype == torch.float64:
- return override_fn_when_cuda_float64(*[load(index) for load in loaders])
- else:
- return fn(*[load(index) for load in loaders])
- if not override_device:
- device = None
- for i in inputs:
- if i.get_device().type == "cuda":
- device = i.get_device()
- break
- if not device:
- device = inputs[0].get_device()
- device = override_device or device
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=ranges,
- )
- return inner
- @register_lowering(prims.convert_element_type, type_promotion_kind=None)
- def to_dtype(x: TensorBox, dtype: torch.dtype):
- if x.get_dtype() == dtype:
- return x
- def _to_dtype(x):
- return ops.to_dtype(x, dtype)
- return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
- @register_lowering(prims.device_put, type_promotion_kind=None)
- def to_device(x: TensorBox, device: torch.device):
- device = decode_device(device)
- if x.get_device() == device:
- return x
- return TensorBox.create(ir.DeviceCopy.create(x, device))
- def ops_wrapper(name):
- assert isinstance(name, str)
- def fn(*args, **kwargs):
- return getattr(ops, name)(*args, **kwargs)
- return fn
- def register_pointwise(
- aten_fn,
- name=None,
- broadcast=True,
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- convert_input_to_bool=False,
- override_return_dtype=None,
- override_fn_when_input_bool=None,
- allow_alpha=False,
- use_libdevice_for_f64=False,
- ):
- """A pointwise function that maps ops.{name} to inputs"""
- name = name or aten_fn.__name__
- fn = ops_wrapper(name)
- if use_libdevice_for_f64:
- fn_libdevice = ops_wrapper("libdevice_" + name)
- if override_fn_when_input_bool is not None:
- override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)
- fn = make_pointwise(
- fn,
- override_return_dtype=override_return_dtype,
- override_fn_when_input_bool=override_fn_when_input_bool,
- override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,
- allow_alpha=allow_alpha,
- )
- fn = register_lowering(
- aten_fn,
- broadcast=broadcast,
- type_promotion_kind=type_promotion_kind,
- convert_input_to_bool=convert_input_to_bool,
- )(fn)
- if hasattr(prims, name):
- register_lowering(
- getattr(prims, name),
- type_promotion_kind=None,
- convert_input_to_bool=convert_input_to_bool,
- )(fn)
- return fn
- @register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
- def where(cond, a, b):
- def fn(*args):
- return ops.where(*args)
- if isinstance(a, (float, int)):
- a = constant_like(a)(b)
- if isinstance(b, (float, int)):
- b = constant_like(b)(a)
- args = [cond, a, b]
- dtype = get_promoted_dtype(
- args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
- for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
- args[i] = x
- for i in range(len(args)):
- if isinstance(args[i], ir.Constant):
- args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
- return make_pointwise(fn, override_return_dtype=dtype)(
- args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype)
- )
- @register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
- def broadcast_tensors(*inputs):
- if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
- return broadcast_tensors(*inputs[0])
- target = functools.reduce(
- broadcast_symbolic_shapes, [x.get_size() for x in inputs], ()
- )
- outputs = []
- for x in inputs:
- sizes = x.get_size()
- if len(sizes) != len(target) or any(
- ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target)
- ):
- x = expand(x, target)
- outputs.append(x)
- return outputs
- @register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
- def nop(x):
- return x # AOT autograd handles this for us
- if hasattr(aten, "lift_fresh"):
- register_lowering(aten.lift_fresh)(nop)
- @register_lowering(aten.squeeze, type_promotion_kind=None)
- def squeeze(x, dim=None):
- assert isinstance(x, TensorBox)
- if dim is None:
- return TensorBox(SqueezeView.create(x.data))
- dim = canonicalize_dims(len(x.get_size()), dim)
- dims = set((dim,) if not isinstance(dim, tuple) else dim)
- new_shape = [
- s
- for d, s in enumerate(x.get_size())
- if not (d in dims and V.graph.sizevars.maybe_guard_equals(s, 1))
- ]
- # squeeze does nothing if the size isn't 1
- return view(x, new_shape) if new_shape != x.get_size() else x
- @register_lowering([aten.squeeze_])
- def squeeze_(x, dim=None):
- val = squeeze(x, dim)
- assert isinstance(x, TensorBox)
- assert isinstance(val, TensorBox)
- x.data = val.data
- return x
- @register_lowering(aten.isinf)
- def isinf(x):
- if is_integer_type(x):
- return full_like(x, False, dtype=torch.bool)
- fn = ops_wrapper("isinf")
- return make_pointwise(fn, override_return_dtype=torch.bool)(x)
- @register_lowering(aten.isnan)
- def isnan(x):
- if is_integer_type(x):
- return full_like(x, False, dtype=torch.bool)
- fn = ops_wrapper("isnan")
- return make_pointwise(fn, override_return_dtype=torch.bool)(x)
- @register_lowering(aten.ceil)
- def ceil(x):
- if is_integer_type(x):
- return x
- fn = ops_wrapper("ceil")
- return make_pointwise(fn)(x)
- @register_lowering(aten.floor)
- def floor(x):
- if is_integer_type(x):
- return x
- fn = ops_wrapper("floor")
- return make_pointwise(fn)(x)
- @register_lowering(aten.round)
- def round(x):
- if is_integer_type(x):
- return x
- fn = ops_wrapper("round")
- return make_pointwise(fn)(x)
- @register_lowering(aten.trunc)
- def trunc(x):
- if is_integer_type(x):
- return x
- fn = ops_wrapper("trunc")
- return make_pointwise(fn)(x)
- @register_lowering(aten.expand, type_promotion_kind=None)
- def expand(x, sizes):
- (x,) = promote_constants([x])
- if isinstance(x, ir.BaseConstant):
- return ExpandView.create(x, tuple(sizes))
- assert isinstance(x, TensorBox)
- assert isinstance(sizes, (list, tuple))
- if tuple(x.get_size()) == tuple(sizes):
- return x
- x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size()))
- if x_size_product > 0:
- # maybe realize input before broadcasting it
- x.mark_reuse(V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product)
- return TensorBox(ExpandView.create(x.data, tuple(sizes)))
- @register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
- def broadcast_in_dim(a, shape, broadcast_dimensions):
- s = list(shape)
- for broadcast_dimension in broadcast_dimensions:
- s[broadcast_dimension] = -1
- v = a
- for idx, x in enumerate(s):
- if x != -1:
- v = unsqueeze(v, idx)
- return expand(v, shape)
- @register_lowering(aten.expand_as, type_promotion_kind=None)
- def expand_as(x, y):
- return expand(x, y.get_size())
- @register_lowering(aten.repeat)
- def repeat(x, repeats):
- old_size = list(x.get_size())
- if len(repeats) > len(old_size):
- old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size
- x = view(x, list(old_size))
- assert len(repeats) == len(x.get_size())
- new_size = list(x.get_size())
- for i in range(len(repeats)):
- assert repeats[i] != 0
- if repeats[i] != 1:
- new_size[i] = new_size[i] * repeats[i]
- if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
- return expand(x, new_size)
- def inner_fn(index):
- assert len(index) == len(repeats)
- index = list(index)
- for i in range(len(repeats)):
- if repeats[i] != 1:
- if old_size[i] == 1:
- index[i] = sympy.Integer(0)
- else:
- index[i] = ir.ModularIndexing(index[i], 1, old_size[i])
- return x_loader(index)
- old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size))
- if old_size_product > 0:
- # maybe realize the input
- x.mark_reuse(
- V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product
- )
- x_loader = x.make_loader()
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=inner_fn,
- ranges=list(new_size),
- )
- @register_lowering(aten._unsafe_view, type_promotion_kind=None)
- @register_lowering(aten.view, type_promotion_kind=None)
- @register_lowering(aten.reshape, type_promotion_kind=None)
- def view(x, sizes):
- assert isinstance(x, TensorBox)
- assert isinstance(sizes, (list, tuple))
- return TensorBox(View.create(x.data, sizes))
- @register_lowering(aten.permute, type_promotion_kind=None)
- def permute(x, dims):
- assert isinstance(x, TensorBox)
- assert isinstance(dims, (list, tuple))
- return TensorBox(PermuteView.create(x.data, tuple(dims)))
- @register_lowering(aten.slice, type_promotion_kind=None)
- def slice_(x, dim=0, start=0, end=2**63, step=1):
- assert isinstance(x, TensorBox)
- dim = _validate_dim(x, dim, 0)
- return TensorBox(ir.SliceView.create(x.data, dim, start, end, step))
- @register_lowering(aten.roll, type_promotion_kind=None)
- def roll(a, shifts, dims=tuple()):
- """
- This is based on torch._refs.roll(), but uses ir.ModularIndexing().
- We can't use the ref here because it is based on multiple calls to
- torch.cat() that this will result in terrible code.
- """
- # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
- if not isinstance(shifts, Iterable):
- shifts = (shifts,)
- if not isinstance(dims, Iterable):
- dims = (dims,)
- dims = [_validate_dim(a, d) for d in dims]
- if sympy_product(a.get_size()) == 0:
- return clone(a)
- len_shifts = len(shifts)
- len_dims = len(dims)
- if len_shifts != 1 or len_dims != 1:
- if len_shifts == 0:
- raise RuntimeError("`shifts` required")
- # Takes care of the case when dims is not specified (default)
- # By default, the tensor is flattened before shifting, after which the original shape is restored
- if len_dims == 0 and len_shifts == 1:
- flat = view(a, [sympy_product(a.get_size())])
- rolled = roll(flat, shifts, 0)
- return view(rolled, list(a.get_size()))
- if len_shifts != len_dims:
- raise RuntimeError(
- f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
- )
- tail_shifts = shifts[1:]
- tail_dims = dims[1:]
- first_dim_rolled = roll(a, shifts[0], dims[0])
- return roll(first_dim_rolled, tail_shifts, tail_dims)
- (dim,) = dims
- size = V.graph.sizevars.guard_static_shape(a.get_size()[dim])
- start = (size - shifts[0]) % size
- a_loader = a.make_loader()
- def fn(index):
- index = list(index)
- index[dim] = ir.ModularIndexing(
- index[dim] + start, sympy.Integer(1), sympy.expand(size)
- )
- return a_loader(index)
- return Pointwise.create(
- device=a.get_device(),
- dtype=a.get_dtype(),
- inner_fn=fn,
- ranges=a.get_size(),
- )
- @register_lowering(aten.as_strided, type_promotion_kind=None)
- def as_strided(x, size, stride, storage_offset=None):
- if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
- # as_strided ignores views
- x = x.data.unwrap_view()
- x.realize()
- if not ir.is_storage_and_layout(x):
- raise NotImplementedError(f"unrealized as_strided({x}, ...)")
- storage, old_layout = ir.as_storage_and_layout(x)
- new_layout = ir.FixedLayout(
- old_layout.device,
- old_layout.dtype,
- [sympy.expand(s) for s in size],
- [sympy.expand(s) for s in stride],
- sympy.expand(storage_offset or 0),
- )
- return TensorBox(ir.ReinterpretView(storage, new_layout))
- @register_lowering(aten.as_strided_)
- def as_strided_(x, size, stride, storage_offset=None):
- assert isinstance(x, TensorBox)
- x.data = as_strided(x, size, stride, storage_offset).data
- return x
- @register_lowering(aten.cat)
- def cat(inputs, dim=0):
- if len(inputs) == 1:
- return clone(inputs[0])
- dim = _validate_dim(inputs[0], dim, 0)
- dtype = get_promoted_dtype(
- *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- inputs = [to_dtype(inp, dtype) for inp in inputs]
- return TensorBox(ir.ConcatKernel.create(inputs, dim))
- @register_lowering(aten.select, type_promotion_kind=None)
- def select(x, dim, idx):
- idx = View.handle_negative_index(idx, x.get_size()[dim])
- return squeeze(slice_(x, dim, idx, idx + 1), dim)
- @register_lowering(aten.split, type_promotion_kind=None)
- def split(x, sizes, dim=0):
- dim = _validate_dim(x, dim, 0)
- x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
- if isinstance(sizes, sympy.Expr):
- sizes = V.graph.sizevars.guard_static_shape(sizes)
- if isinstance(sizes, (int, sympy.Integer)):
- sizes = [sizes] * ((x_size + sizes - 1) // sizes)
- result = []
- start = 0
- for size in sizes:
- end = start + size
- result.append(slice_(x, dim, start, end))
- start = end
- return result
- @register_lowering(aten.split_with_sizes, type_promotion_kind=None)
- def split_with_sizes(x, sizes, dim=0):
- return split(x, sizes, dim)
- @register_lowering(aten.unbind, type_promotion_kind=None)
- def unbind(x, dim=0):
- dim = _validate_dim(x, dim, 0)
- x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
- result = []
- for i in range(x_size):
- result.append(select(x, dim, i))
- return result
- @register_lowering(aten.unsqueeze, type_promotion_kind=None)
- def unsqueeze(x, dim):
- dim = _validate_dim(x, dim, 1)
- new_shape = list(x.get_size())
- new_shape.insert(dim, sympy.Integer(1))
- return view(x, new_shape)
- @register_lowering(aten.unsqueeze_, type_promotion_kind=None)
- def unsqueeze_(x, dim):
- val = unsqueeze(x, dim)
- assert isinstance(x, TensorBox)
- assert isinstance(val, TensorBox)
- x.data = val.data
- return x
- def _validate_dim(x, dim, offset=0):
- assert isinstance(dim, int)
- ndim = len(x.get_size())
- if dim < 0:
- dim += ndim + offset
- assert 0 <= dim < ndim + offset
- return dim
- @register_lowering(aten.glu)
- def glu(x, dim=-1):
- dim = _validate_dim(x, dim, 0)
- new_len = V.graph.sizevars.guard_static_shape(x.get_size()[dim]) // 2
- a = slice_(x, dim, 0, new_len)
- b = slice_(x, dim, new_len, new_len * 2)
- return mul(a, sigmoid(b))
- def register_onednn_fusion_ops():
- if torch._C.has_mkldnn:
- @register_lowering(torch.ops.mkldnn._convolution_pointwise)
- def convolution_unary(
- x: TensorBox,
- weight: TensorBox,
- bias: TensorBox,
- padding,
- stride,
- dilation,
- groups,
- attr,
- scalars,
- algorithm,
- ):
- return TensorBox.create(
- ir.ConvolutionUnary.create(
- x,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- attr,
- scalars,
- algorithm,
- )
- )
- @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
- def convolution_binary(
- x: TensorBox,
- other: TensorBox,
- weight: TensorBox,
- bias: TensorBox,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- binary_alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- ):
- return TensorBox.create(
- ir.ConvolutionBinary.create(
- x,
- other,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- binary_alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- )
- )
- @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
- def convolution_binary_inplace(
- x: TensorBox,
- other: TensorBox,
- weight: TensorBox,
- bias: TensorBox,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- binary_alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- ):
- return TensorBox.create(
- ir.ConvolutionBinaryInplace.create(
- x,
- other,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- binary_alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- )
- )
- @register_lowering(torch.ops.mkldnn._linear_pointwise)
- def linear_unary(
- x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm
- ):
- return TensorBox.create(
- ir.LinearUnary.create(x, w, b, attr, scalars, algorithm)
- )
- @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
- def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
- return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))
- @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
- def convolution_transpose_unary(
- x: TensorBox,
- weight: TensorBox,
- bias: TensorBox,
- padding,
- output_padding,
- stride,
- dilation,
- groups,
- attr,
- scalars,
- algorithm,
- ):
- return TensorBox.create(
- ir.ConvolutionTransposeUnary.create(
- x,
- weight,
- bias,
- padding,
- output_padding,
- stride,
- dilation,
- groups,
- attr,
- scalars,
- algorithm,
- )
- )
- if torch._C.has_mkl:
- @register_lowering(torch.ops.mkl._mkl_linear)
- def mkl_packed_linear(
- x: TensorBox,
- packed_w: TensorBox,
- orig_w: TensorBox,
- b: TensorBox,
- batch_size,
- ):
- result = TensorBox.create(
- ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size)
- )
- if b is not None:
- result = add(result, b)
- return result
- else:
- pass
- register_onednn_fusion_ops()
- def fallback_handler(kernel):
- fallbacks.add(kernel)
- def handler(*args, **kwargs):
- return pytree.tree_map(
- TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs)
- )
- return handler
- def make_fallback(kernel, layout_constraint=None, warn=True):
- assert (
- kernel not in decompositions
- ), f"both a fallback and a decomp for same kernel: {kernel}"
- if get_decompositions([kernel]) and warn:
- developer_warning(
- f"make_fallback({kernel}): a decomposition exists, we should switch to it"
- )
- add_needs_realized_inputs(kernel)
- if layout_constraint is not None:
- add_layout_constraint(kernel, layout_constraint)
- return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
- @register_lowering(aten.native_dropout, type_promotion_kind=None)
- def native_dropout(x, p, train):
- assert (
- config.fallback_random
- ), "this should be handled in decomps unless config.fallback_random"
- if train:
- return pytree.tree_map(
- TensorBox.create, ir.FallbackKernel.create(aten.native_dropout, x, p, train)
- )
- return x, ones_like(x, dtype=torch.bool)
- @register_lowering(aten.bernoulli_, type_promotion_kind=None)
- def bernoulli_(x, *args):
- assert (
- config.fallback_random
- ), "this should be handled in decomps unless config.fallback_random"
- x.realize()
- V.graph.realize_users_of(x.get_name())
- ir.InplaceBernoulliFallback(x, *args)
- return x
- @register_lowering(aten.bernoulli.p, type_promotion_kind=None)
- def bernoulli_p(x, *args):
- assert (
- config.fallback_random
- ), "this should be handled in decomps unless config.fallback_random"
- return bernoulli_(clone(x), *args)
- # This shouldn't be called in general
- @register_lowering(aten._foobar)
- def _foobar(_):
- raise AssertionError()
- @functools.lru_cache(1)
- def _warn_triton_random(salt):
- developer_warning("using triton random, expect difference from eager")
- def warn_triton_random():
- # only warn once per graph
- _warn_triton_random(V.graph.creation_time)
- def make_rand(fn_name):
- def rand_or_randn(
- *size,
- dtype=None,
- layout=0,
- device=None,
- pin_memory=False,
- memory_format=None,
- ):
- warn_triton_random()
- assert not pin_memory
- assert layout in (0, torch.strided)
- assert memory_format in (None, torch.contiguous_format)
- device = decode_device(device)
- dtype = dtype or torch.get_default_dtype()
- if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
- size = tuple(size[0])
- size = [sympy.expand(s) for s in size]
- offset = V.graph.increment_randomness_offset(sympy_product(size))
- random_pos = ir.FixedLayout(
- device,
- dtype,
- size,
- ir.FlexibleLayout.contiguous_strides(size),
- offset=offset,
- ).make_indexer()
- seed_buffer = V.graph.random_seed_buffer(device).make_loader()
- def inner_fn(index):
- seed = seed_buffer([])
- # change seed so that we don't collide with philox_rand_like()
- # TODO(jansel): migrate everything to philox_rand_like()
- seed = ops.bitwise_xor(seed, ops.constant(0xFFFF, torch.int32))
- return getattr(ops, fn_name)(
- seed,
- ops.index_expr(random_pos(index), torch.int32),
- dtype,
- )
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(size),
- )
- return rand_or_randn
- fallback_rand = fallback_handler(aten.rand)
- fallback_randn = fallback_handler(aten.randn)
- fast_rand = make_rand("rand")
- fast_randn = make_rand("randn")
- @register_lowering([aten.rand, torch.rand])
- def rand(*args, **kwargs):
- if config.fallback_random or kwargs.get("generator", None) is not None:
- return fallback_rand(*args, **kwargs)
- else:
- kwargs.pop("generator", None)
- return fast_rand(*args, **kwargs)
- @register_lowering([aten.randn, torch.randn])
- def randn(*args, **kwargs):
- if config.fallback_random or kwargs.get("generator", None) is not None:
- return fallback_randn(*args, **kwargs)
- else:
- kwargs.pop("generator", None)
- return fast_randn(*args, **kwargs)
- @register_lowering(overrides.philox_seed_like._overloadpacket)
- def philox_seed_like(x):
- warn_triton_random()
- return V.graph.random_seed_buffer(x.get_device())
- @register_lowering(overrides.philox_rand_like._overloadpacket, type_promotion_kind=None)
- def philox_rand_like(x, seed, offset):
- device = x.get_device()
- dtype = x.get_dtype()
- size = x.get_size()
- random_pos = ir.FixedLayout(
- device,
- dtype,
- size,
- ir.FlexibleLayout.contiguous_strides(size),
- offset=sympy.expand(offset),
- ).make_indexer()
- seed_loader = seed.make_loader()
- def inner_fn(index):
- return ops.rand(
- seed_loader([]),
- ops.index_expr(random_pos(index), torch.int32),
- dtype,
- )
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(size),
- )
- def require_dense(_, *args, **kwargs):
- args, kwargs = pytree.tree_map_only(
- ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
- )
- return args, kwargs
- def require_contiguous(_, *args, **kwargs):
- args, kwargs = pytree.tree_map_only(
- ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs)
- )
- return args, kwargs
- def constrain_to_fx_strides(fx_node, *args, **kwargs):
- def apply_constraint(arg, fx_arg):
- if isinstance(arg, ir.IRNode):
- stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
- return ir.ExternKernel.require_stride_order(arg, stride_order)
- return arg
- args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)]
- kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
- return args, kwargs
- # TODO(jansel): we should implement decomps or lowerings for these
- # https://github.com/pytorch/torchdynamo/issues/327
- FALLBACK_ALLOW_LIST = {
- "torchvision::roi_align",
- }
- make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
- make_fallback(aten.convolution_backward, constrain_to_fx_strides)
- make_fallback(aten._cudnn_rnn, require_dense)
- make_fallback(aten._cudnn_rnn_backward, require_contiguous)
- make_fallback(aten.cumsum, require_dense, warn=False)
- make_fallback(aten._embedding_bag, require_contiguous)
- make_fallback(aten._embedding_bag_forward_only, require_contiguous)
- make_fallback(aten._fused_moving_avg_obs_fq_helper)
- make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
- make_fallback(aten.grid_sampler_2d_backward, require_dense)
- make_fallback(aten.randperm)
- make_fallback(aten.sort)
- make_fallback(aten.sort.stable)
- make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
- make_fallback(aten._thnn_fused_lstm_cell, require_dense)
- make_fallback(aten.topk)
- make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
- make_fallback(aten.upsample_bilinear2d_backward, require_dense)
- # The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
- # It's not necessarily a priority to implement these
- make_fallback(aten.upsample_linear1d)
- make_fallback(aten.upsample_trilinear3d)
- make_fallback(aten.upsample_linear1d_backward)
- make_fallback(aten.upsample_trilinear3d_backward)
- make_fallback(aten._adaptive_avg_pool3d)
- make_fallback(aten.adaptive_max_pool2d)
- make_fallback(aten.adaptive_max_pool3d)
- make_fallback(aten.addbmm)
- make_fallback(aten.addmv)
- make_fallback(aten.aminmax)
- make_fallback(aten.avg_pool3d)
- make_fallback(aten.block_diag)
- make_fallback(aten._cdist_forward)
- make_fallback(aten.count_nonzero)
- make_fallback(aten.cummax)
- make_fallback(aten.cummin)
- make_fallback(aten.cumprod)
- make_fallback(aten.deg2rad)
- make_fallback(aten.diagonal_copy, warn=False)
- make_fallback(aten.diagonal_scatter, warn=False)
- make_fallback(aten.digamma, warn=False)
- make_fallback(aten.dist)
- make_fallback(aten._efficientzerotensor)
- make_fallback(aten._embedding_bag_per_sample_weights_backward)
- make_fallback(aten.erfc, warn=False)
- make_fallback(aten.erfinv, warn=False)
- make_fallback(aten.fmax, warn=False)
- make_fallback(aten.fmin, warn=False)
- make_fallback(aten.dist)
- make_fallback(aten._efficientzerotensor)
- make_fallback(aten._embedding_bag_per_sample_weights_backward)
- make_fallback(aten.fractional_max_pool2d)
- make_fallback(aten.fractional_max_pool3d)
- make_fallback(aten.frexp)
- make_fallback(aten.geqrf)
- make_fallback(aten.histc)
- make_fallback(aten.i0)
- make_fallback(aten.igamma, warn=False)
- make_fallback(aten.igammac, warn=False)
- make_fallback(aten.isin)
- make_fallback(aten.isneginf, warn=False)
- make_fallback(aten.isposinf, warn=False)
- make_fallback(aten.kthvalue)
- make_fallback(aten.linalg_cholesky_ex)
- make_fallback(aten.linalg_cross)
- make_fallback(aten._linalg_det)
- make_fallback(aten.linalg_householder_product)
- make_fallback(aten.linalg_inv_ex)
- make_fallback(aten.linalg_ldl_factor_ex)
- make_fallback(aten.linalg_ldl_solve)
- make_fallback(aten.linalg_lu)
- make_fallback(aten.linalg_lu_factor_ex)
- make_fallback(aten.linalg_lu_solve)
- make_fallback(aten.linalg_matrix_exp)
- make_fallback(aten.linalg_qr)
- make_fallback(aten._linalg_slogdet)
- make_fallback(aten._linalg_solve_ex)
- make_fallback(aten.linalg_solve_triangular)
- make_fallback(aten._linalg_svd)
- make_fallback(aten.logaddexp2)
- make_fallback(aten.logcumsumexp)
- make_fallback(aten.log_sigmoid_forward, warn=False)
- make_fallback(aten.logspace, warn=False)
- make_fallback(aten.lu_unpack)
- make_fallback(aten.max_pool3d_with_indices)
- make_fallback(aten.max_unpool2d)
- make_fallback(aten.max_unpool3d)
- make_fallback(aten.median)
- make_fallback(aten.mode)
- make_fallback(aten.multilabel_margin_loss_forward)
- make_fallback(aten.multi_margin_loss)
- make_fallback(aten.nanmedian)
- make_fallback(aten.nansum)
- make_fallback(aten.narrow_copy, warn=False)
- make_fallback(aten.ormqr)
- make_fallback(aten._pdist_forward)
- make_fallback(aten.pixel_shuffle)
- make_fallback(aten.pixel_unshuffle)
- make_fallback(aten.polygamma)
- make_fallback(aten.prod, warn=False)
- make_fallback(aten.put)
- make_fallback(aten.rad2deg)
- make_fallback(aten.reflection_pad1d)
- make_fallback(aten.renorm)
- make_fallback(aten.replication_pad1d)
- make_fallback(aten.resize)
- make_fallback(aten.resize_)
- make_fallback(aten.resize_as)
- make_fallback(aten.resize_as_)
- make_fallback(aten.searchsorted)
- make_fallback(aten.smooth_l1_loss)
- make_fallback(aten.special_airy_ai)
- make_fallback(aten.special_bessel_j0, warn=False)
- make_fallback(aten.special_bessel_j1, warn=False)
- make_fallback(aten.special_bessel_y0, warn=False)
- make_fallback(aten.special_bessel_y1)
- make_fallback(aten.special_chebyshev_polynomial_t)
- make_fallback(aten.special_chebyshev_polynomial_u)
- make_fallback(aten.special_erfcx, warn=False)
- make_fallback(aten.special_hermite_polynomial_h)
- make_fallback(aten.special_hermite_polynomial_he)
- make_fallback(aten.special_i0e, warn=False)
- make_fallback(aten.special_i1, warn=False)
- make_fallback(aten.special_i1e, warn=False)
- make_fallback(aten.special_laguerre_polynomial_l)
- make_fallback(aten.special_modified_bessel_i0)
- make_fallback(aten.special_modified_bessel_i1)
- make_fallback(aten.special_modified_bessel_k0)
- make_fallback(aten.special_modified_bessel_k1)
- make_fallback(aten.special_ndtri, warn=False)
- make_fallback(aten.special_scaled_modified_bessel_k0)
- make_fallback(aten.special_scaled_modified_bessel_k1)
- make_fallback(aten.special_spherical_bessel_j0, warn=False)
- make_fallback(aten.special_zeta, warn=False)
- make_fallback(aten.take)
- make_fallback(aten.threshold, warn=False)
- make_fallback(aten.trace, warn=False)
- make_fallback(aten._trilinear)
- make_fallback(aten.unfold_copy, warn=False)
- make_fallback(aten.uniform, warn=False)
- make_fallback(aten.unsafe_split, warn=False)
- make_fallback(aten.vdot)
- make_fallback(aten.view_as_complex)
- make_fallback(aten.view_copy)
- make_fallback(aten._adaptive_avg_pool3d_backward)
- make_fallback(aten.adaptive_max_pool2d_backward)
- make_fallback(aten.adaptive_max_pool3d_backward)
- make_fallback(aten.avg_pool3d_backward)
- make_fallback(aten.bitwise_or_, warn=False)
- make_fallback(aten._cdist_backward)
- make_fallback(aten.diagonal_backward, warn=False)
- make_fallback(aten._embedding_bag_dense_backward)
- make_fallback(aten.fractional_max_pool2d_backward)
- make_fallback(aten.fractional_max_pool3d_backward)
- make_fallback(aten._linalg_check_errors)
- make_fallback(aten.max_pool3d_with_indices_backward)
- make_fallback(aten.multilabel_margin_loss_backward)
- make_fallback(aten.multi_margin_loss_backward)
- make_fallback(aten._pdist_backward)
- make_fallback(aten.reflection_pad1d_backward)
- make_fallback(aten.replication_pad1d_backward)
- make_fallback(aten.smooth_l1_loss_backward)
- make_fallback(aten.soft_margin_loss_backward, warn=False)
- make_fallback(aten.softshrink_backward, warn=False)
- make_fallback(aten.squeeze_copy)
- make_fallback(aten.linalg_pinv.atol_rtol_tensor)
- make_fallback(aten.segment_reduce.default)
- make_fallback(aten._segment_reduce_backward.default)
- make_fallback(aten.angle)
- make_fallback(aten.cholesky_inverse)
- make_fallback(aten.cholesky_solve)
- make_fallback(aten._fft_r2c)
- make_fallback(aten.histogram.bin_ct)
- make_fallback(aten._histogramdd_bin_edges.default)
- make_fallback(aten._histogramdd_from_bin_cts.default)
- make_fallback(aten.index_reduce)
- make_fallback(aten.masked_scatter)
- make_fallback(aten.to_sparse)
- make_fallback(aten.triangular_solve)
- make_fallback(aten.expand_copy)
- make_fallback(aten.gcd.default, warn=False)
- make_fallback(aten._linalg_eigh)
- make_fallback(aten.zeros.names)
- # TODO(fdrocha): this should be removed once the register_pointwise(aten.bitwise_right_shift) below is uncommented
- make_fallback(aten.bitwise_right_shift, warn=False)
- add_layout_constraint(aten.convolution, constrain_to_fx_strides)
- @register_lowering(aten.convolution)
- def convolution(
- x: TensorBox,
- weight: TensorBox,
- bias: TensorBox,
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- transposed: bool,
- output_padding: List[int],
- groups: int,
- ):
- is_cpu = all(
- input.get_device().type == "cpu"
- for input in (x, weight, bias)
- if input is not None
- )
- result = TensorBox.create(
- ir.Convolution.create(
- x,
- weight,
- bias if is_cpu else None, # For cpu path, bias can always be fused
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- )
- )
- if not is_cpu and bias is not None:
- kernel_dims = len(weight.get_size()) - 2
- out_chan = result.get_size()[-1 - kernel_dims]
- bias = view(bias, [out_chan] + kernel_dims * [1])
- result = add(result, bias)
- return result
- @register_lowering(aten._convolution)
- def _convolution(
- x,
- weight,
- bias,
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- benchmark,
- deterministic,
- cudnn_enabled,
- allow_tf32,
- ):
- return convolution(
- x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
- )
- @register_lowering(aten.clone)
- def clone(x, *, memory_format=0):
- # TODO(jansel): memory format
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=x.make_loader(),
- ranges=list(x.get_size()),
- )
- if hasattr(aten, "lift_fresh_copy"):
- register_lowering(aten.lift_fresh_copy)(clone)
- @register_lowering(prims.iota)
- def iota(
- length,
- *,
- start,
- step,
- dtype,
- device,
- requires_grad,
- ):
- def fn(index):
- return ops.index_expr(step * index[0] + start, dtype=dtype)
- return Pointwise.create(
- device=decode_device(device),
- dtype=dtype,
- inner_fn=fn,
- ranges=[length],
- )
- @register_lowering(aten.select_scatter, type_promotion_kind=None)
- def select_scatter(x, src, dim: int, index: int):
- assert x.get_dtype() == src.get_dtype()
- x_loader = x.make_loader()
- dim = _validate_dim(x, dim, 0)
- if index < 0:
- index = index + x.get_size()[dim]
- V.graph.sizevars.guard_leq(0, index)
- V.graph.sizevars.guard_lt(index, x.get_size()[dim])
- src = expand(unsqueeze(src, dim), x.get_size())
- src_loader = src.make_loader()
- def inner_fn(idx):
- return ops.where(
- ops.eq(
- ops.index_expr(idx[dim], torch.int32),
- ops.index_expr(index, torch.int32),
- ),
- src_loader(idx),
- x_loader(idx),
- )
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=inner_fn,
- ranges=list(x.get_size()),
- )
- @register_lowering(aten.slice_scatter, type_promotion_kind=None)
- def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
- assert x.get_dtype() == src.get_dtype()
- x_loader = x.make_loader()
- dim = _validate_dim(x, dim, 0)
- dim_size = x.get_size()[dim]
- if start is not None and start < 0:
- start = start + dim_size
- if end is not None and end < 0:
- end = end + dim_size
- if start is None:
- start = 0
- if end is None or V.graph.sizevars.maybe_guard_leq(x.get_size()[dim], end):
- end = dim_size
- src_size = list(x.get_size())
- src_size[dim] = ir.FloorDiv(sympy.expand(end - start), sympy.expand(step))
- src = expand(src, src_size)
- src_loader = src.make_loader()
- def inner_fn(idx):
- if start == 0 and end == dim_size and step == 1:
- # selecting every element is the same as just src.clone()
- return src_loader(idx)
- idx_dim = ops.index_expr(idx[dim], torch.int32)
- src_idx = list(idx)
- src_idx[dim] = ir.FloorDiv(idx[dim] - start, step)
- mask = []
- if start != 0:
- mask.append(
- ops.ge(
- idx_dim,
- ops.index_expr(sympy.expand(start), torch.int32),
- )
- )
- if end != dim_size:
- mask.append(
- ops.lt(
- idx_dim,
- ops.index_expr(sympy.expand(end), torch.int32),
- )
- )
- if step != 1:
- mask.append(
- ops.eq(
- ops.index_expr(
- ir.ModularIndexing(idx[dim] - start, 1, step), torch.int32
- ),
- ops.constant(0, torch.int32),
- )
- )
- assert mask
- mask = functools.reduce(ops.and_, mask)
- src_val = ops.masked(
- mask,
- lambda: src_loader(src_idx),
- 0 if is_integer_type(x) else 0.0,
- )
- return ops.where(
- mask,
- src_val,
- x_loader(idx),
- )
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=inner_fn,
- ranges=list(x.get_size()),
- )
- def _unwrap(x):
- if isinstance(x, (list, tuple)) and len(x) > 0:
- return _unwrap(x[0])
- return x
- @register_lowering([torch.tensor, aten.scalar_tensor])
- def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
- assert layout in (None, torch.strided)
- assert pin_memory is False
- if isinstance(_unwrap(data), int):
- dtype = dtype or torch.int64
- else:
- dtype = dtype or torch.get_default_dtype()
- if isinstance(data, (float, int)):
- ranges = []
- def inner_fn(index):
- return ops.constant(data, dtype)
- elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
- # inline small tensors
- ranges = [sympy.Integer(len(data))]
- def inner_fn(index):
- def binary_search(start, end):
- assert start < end
- if end - start == 1:
- return ops.constant(data[start], dtype)
- mid = (end - start) // 2 + start
- return ops.where(
- ops.lt(
- ops.index_expr(index[0], torch.int64),
- ops.constant(mid, torch.int64),
- ),
- binary_search(start, mid),
- binary_search(mid, end),
- )
- if len(data) == 0:
- return ops.constant(0, dtype)
- return binary_search(0, len(data))
- else:
- return V.graph.add_tensor_constant(
- torch.tensor(data, dtype=dtype, device=device)
- )
- return Pointwise.create(
- device=decode_device(device),
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=ranges,
- )
- @register_lowering(torch.as_tensor)
- def as_tensor(data, dtype=None, device=None):
- if isinstance(data, TensorBox):
- if dtype is not None:
- data = to_dtype(data, dtype)
- if device is not None:
- data = to_device(data, device)
- return data
- return tensor(data, dtype=dtype, device=device)
- @register_lowering(torch.LongTensor)
- def long_tensor(data):
- return tensor(data, dtype=torch.int64)
- @register_lowering(aten._local_scalar_dense)
- def _local_scalar_dense(data):
- return ir.DynamicScalar()
- def _full(fill_value, device, dtype, size):
- value = fill_value
- if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
- value = value.value
- if isinstance(value, (int, float, sympy.Expr)):
- def inner_fn(index):
- return ops.constant(value, dtype)
- else:
- assert len(value.get_size()) == 0
- value_loader = value.make_loader()
- def inner_fn(index):
- return value_loader([])
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(size),
- )
- @register_lowering(aten.full_like, type_promotion_kind=None)
- def full_like(x, fill_value, **kwargs):
- return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
- def tensor_constructor(fill_value):
- # torch.zeros, torch.ones, etc
- def inner(
- *size,
- names=None,
- dtype=None,
- device=None,
- layout=0,
- pin_memory=False,
- memory_format=None,
- ):
- assert names is None
- assert not pin_memory
- assert layout in (0, torch.strided)
- assert memory_format in (None, torch.contiguous_format)
- device = decode_device(device)
- dtype = dtype or torch.get_default_dtype()
- if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
- size = tuple(size[0])
- size = [sympy.expand(s) for s in size]
- return _full(fill_value, device, dtype, size)
- return inner
- @register_lowering([torch.empty, aten.empty])
- def empty(
- *size,
- names=None,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=None,
- memory_format=None,
- ):
- assert names is None
- assert memory_format in (None, torch.contiguous_format)
- device = decode_device(device)
- if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
- size = list(size[0])
- return empty_strided(
- size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- def create_tensor_like(creation_fn):
- """
- Shim to convert X_like(...) into X(...). For example zeros_like() into zeros().
- """
- def _constant_like(
- x, *, dtype=None, device=None, layout=0, pin_memory=False, memory_format=None
- ):
- assert not pin_memory
- assert layout in (0, torch.strided)
- if dtype is None:
- dtype = x.get_dtype()
- else:
- dtype = decode_dtype(dtype)
- device = device or x.get_device()
- size = list(x.get_size())
- return creation_fn(
- size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
- )
- return _constant_like
- def constant_like(fill_value):
- return create_tensor_like(tensor_constructor(fill_value))
- empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
- ones_like = create_tensor_like(tensor_constructor(1))
- if not config.fallback_random:
- rand_like = register_lowering(aten.rand_like)(create_tensor_like(rand))
- randn_like = register_lowering(aten.randn_like)(create_tensor_like(randn))
- def new_constant(fill_value):
- def _new_constant(
- x, size, *, dtype=None, layout=None, device=None, pin_memory=None
- ):
- assert isinstance(size, (list, type))
- assert not pin_memory
- assert not layout or layout == torch.strided
- dtype = decode_dtype(dtype) or x.get_dtype()
- device = device or x.get_device()
- size = [sympy.Integer(s) for s in size]
- return _full(fill_value, device, dtype, size)
- return _new_constant
- @register_lowering(aten.new_empty)
- def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
- if dtype is None:
- dtype = x.get_dtype()
- if device is None:
- device = x.get_device()
- return empty_strided(
- size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- @register_lowering(aten.empty_strided)
- def empty_strided(
- size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
- ):
- assert isinstance(size, (list, type))
- assert isinstance(stride, (list, type, type(None)))
- assert not pin_memory
- assert not layout or layout == torch.strided
- dtype = decode_dtype(dtype) or torch.get_default_dtype()
- device = device or torch.tensor(0.0).device
- pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
- pointwise.realize()
- buffer = pointwise.data.data
- # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
- buffer.data.ranges = [0] * len(size)
- assert isinstance(buffer, ir.ComputedBuffer)
- size = [sympy.expand(s) for s in size]
- stride = (
- [sympy.expand(s) for s in stride]
- if stride
- else ir.FlexibleLayout.contiguous_strides(size)
- )
- buffer.layout = ir.FixedLayout(
- device=device,
- dtype=dtype,
- size=size,
- stride=stride,
- )
- return pointwise
- @register_lowering(aten.new_empty_strided)
- def new_empty_strided(
- x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
- ):
- if dtype is None:
- dtype = x.get_dtype()
- if device is None:
- device = x.get_device()
- return empty_strided(
- size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- @register_lowering(prims.copy_strided.default)
- def copy_strided(x, stride):
- stride = [V.graph.sizevars.size_hint(s) for s in stride]
- stride_order = sorted(range(len(stride)), key=stride.__getitem__)
- return ir.ExternKernel.require_stride_order(x, stride_order)
- @register_lowering([torch.full, aten.full])
- def full(size, fill_value, **kwargs):
- return tensor_constructor(fill_value)(size, **kwargs)
- @register_lowering(aten.gather, type_promotion_kind=None)
- def gather(x, dim, index):
- assert isinstance(x, TensorBox)
- assert index.get_dtype() == torch.int64
- offset = len(x.get_size()) == 0
- dim = _validate_dim(x, dim, offset)
- x_loader = x.make_loader()
- index_loader = index.make_loader()
- def fn(idx):
- idx = list(idx)
- if len(idx) != 0:
- idx[dim] = ops.indirect_indexing(index_loader(idx))
- return x_loader(idx)
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=index.get_size(),
- )
- @register_lowering(aten.embedding, type_promotion_kind=None)
- def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
- assert not sparse
- assert isinstance(weight, TensorBox)
- assert isinstance(indices, TensorBox)
- assert "int" in str(indices.get_dtype())
- weight_loader = weight.make_loader()
- indices_loader = indices.make_loader()
- indices_ndim = len(indices.get_size())
- new_size = [*indices.get_size(), *weight.get_size()[1:]]
- def fn(idx):
- assert len(idx) == len(new_size), f"{idx} != {new_size}"
- var_index = indices_loader(idx[:indices_ndim])
- weight_idx = [ops.indirect_indexing(var_index)] + [*idx[indices_ndim:]]
- return weight_loader(weight_idx)
- return Pointwise.create(
- device=weight.get_device(),
- dtype=weight.get_dtype(),
- inner_fn=fn,
- ranges=new_size,
- )
- def check_and_broadcast_indices(indices, device):
- assert all(
- i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
- for i in indices
- if i is not None
- ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
- if any(
- i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
- ):
- raise NotImplementedError("Fallback for bool indices")
- valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
- assert len(valid_idxs) > 0, "requires at least 1 non-None index"
- new_indices = [None] * len(indices)
- for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
- # Eager allows indices to be CPU tensor when running on CUDA
- # FIXME: Calling to_device(x, device) should work but
- # test_advancedindex_mixed_cpu_devices still fails
- if x.get_device() != device:
- raise NotImplementedError("Fallback when indices is on a different device")
- new_indices[i] = x
- output_dim = len(x.get_size())
- start_offset = 0
- # only support None at start or end for now
- tmp = list(new_indices)
- while tmp and tmp[-1] is None:
- tmp.pop()
- while tmp and tmp[0] is None:
- tmp.pop(0)
- start_offset += 1
- if any((i is None) for i in tmp):
- raise NotImplementedError("Fallback when None is in the middle of indices")
- end_offset = output_dim + start_offset
- return new_indices, start_offset, end_offset
- @register_lowering(aten.index, type_promotion_kind=None)
- def index(x, indices):
- assert isinstance(indices, (list, tuple))
- x_loader = x.make_loader()
- try:
- indices, start_offset, end_offset = check_and_broadcast_indices(
- indices, x.get_device()
- )
- except NotImplementedError:
- x.realize()
- return fallback_handler(aten.index)(x, indices)
- indices_sizes = [i.get_size() for i in indices if i is not None]
- indices_loaders = [i.make_loader() for i in indices if i is not None]
- # no guards on output size, all the guards are set in broadcast_tensors
- output_size = list(indices_sizes[0])
- x_size = x.get_size()
- output_size = [
- *x_size[:start_offset],
- *output_size,
- *x_size[start_offset + len(indices_loaders) :],
- ]
- def fn(idx):
- assert len(idx) == len(output_size)
- new_index = [
- ops.indirect_indexing(loader(idx[start_offset:end_offset]))
- for loader in indices_loaders
- ]
- new_index = [*idx[:start_offset], *new_index, *idx[end_offset:]]
- return x_loader(new_index)
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=output_size,
- )
- # All the indexing decompositions are written in terms of index, index_put, and index_put_
- # We cannot have this lowering as a decomposition as it introduces
- # mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
- # code elimination and common subexpression elimination optimizations, which
- # assume graphs to be side-effect free. More details at
- # https://github.com/pytorch/torchdynamo/issues/1235
- # and
- # https://github.com/pytorch/torchdynamo/issues/1863
- @register_lowering([aten.index_put])
- def index_put(x, indices, values, accumulate=False):
- return index_put_(clone(x), indices, values, accumulate)
- def index_put_as_masked_fill(self, indices, value, accumulate):
- if value.get_device() != self.get_device():
- value = to_device(value, self.get_device())
- if accumulate:
- value = add(self, value)
- return mutate_to(self, where(indices[0], value, self))
- def index_put_fallback(self, indices, values, accumulate):
- ir.IndexPutFallback(self, indices, values, accumulate)
- return self
- @register_lowering(aten.index_put_, type_promotion_kind=None)
- def index_put_(self, indices, values, accumulate=False):
- # Dispatch to masked fill for single boolean index with single value
- if (
- values.get_numel() == 1
- and len(indices) == 1
- and indices[0].get_dtype() in {torch.bool, torch.uint8}
- ):
- return index_put_as_masked_fill(self, indices, values, accumulate)
- # Fallback if there is a boolean index
- for index in indices:
- if index is not None and index.get_dtype() in {torch.bool, torch.uint8}:
- return index_put_fallback(self, indices, values, accumulate)
- x_size = self.get_size()
- x_ndim = len(x_size)
- # fallback to aten.index_put_, as tl.atomic_add does NOT support int64 or bool
- if self.get_dtype() in {torch.int64, torch.bool}:
- # self is an scalar Tensor
- if x_ndim == 0:
- self = view(self, [1])
- self = index_put_fallback(self, indices, values, accumulate)
- if x_ndim == 0:
- self = view(self, [])
- return self
- values = to_dtype(values, self.get_dtype())
- try:
- indices, start_offset, end_offset = check_and_broadcast_indices(
- indices, self.get_device()
- )
- except NotImplementedError:
- return index_put_fallback(self, indices, values, accumulate)
- indices_sizes = [i.get_size() for i in indices if i is not None]
- indices_loaders = [i.make_loader() for i in indices if i is not None]
- assert isinstance(self, TensorBox)
- self.realize()
- V.graph.realize_users_of(self.get_name())
- # self is an scalar Tensor
- if x_ndim == 0:
- self = view(self, [1])
- output_size = list(indices_sizes[0])
- expected_vals_size = [
- *x_size[:start_offset],
- *output_size,
- *x_size[start_offset + len(indices_sizes) :],
- ]
- values = expand(values, expected_vals_size)
- # all guards are set above during broadcast_tensors and expand
- def output_indexer(index):
- assert len(index) == len(expected_vals_size)
- new_index = [
- ops.indirect_indexing(loader(index[start_offset:end_offset]))
- for loader in indices_loaders
- ]
- new_index = [*index[:start_offset], *new_index, *index[end_offset:]]
- return new_index
- scatter = ir.Scatter(
- device=self.get_device(),
- dtype=self.get_dtype(),
- inner_fn=values.make_loader(),
- ranges=expected_vals_size, # iter_ranges,
- output_indexer=output_indexer,
- scatter_mode="atomic_add" if accumulate else None,
- )
- buffer = ir.ComputedBuffer(
- None,
- ir.MutationLayout(self),
- scatter,
- )
- buffer.name = V.graph.register_buffer(buffer)
- if x_ndim == 0:
- self = view(self, [])
- return self
- @register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
- def as_strided_scatter(self, src, size, stride, storage_offset=None):
- output = clone(self)
- output_view = as_strided(output, size, stride, storage_offset)
- copy_(output_view, src)
- return output
- @register_lowering(aten.scatter, type_promotion_kind=None)
- def scatter(x, dim: int, index, src, **kwargs):
- return scatter_(clone(x), dim, index, src, **kwargs)
- def scatter_fallback(
- fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
- ):
- if reduce not in {None, "sum"} or (
- reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
- ):
- self.realize()
- return fallback_handler(fn)(
- self, dim, index, src, reduce=reduce, include_self=include_self
- )
- return None
- @register_lowering(aten.scatter_, type_promotion_kind=None)
- def scatter_(self, dim: int, index, src, *, reduce: str = None):
- if reduce == "add":
- reduce = "sum"
- elif reduce == "multiply":
- reduce = "prod"
- else:
- assert reduce is None
- fallback_result = scatter_fallback(
- aten.scatter_, self, dim, index, src, reduce=reduce
- )
- if fallback_result:
- return fallback_result
- return scatter_reduce_(self, dim, index, src, reduce)
- @register_lowering(aten.scatter_add, type_promotion_kind=None)
- def scatter_add(x, dim: int, index, src):
- return scatter_add_(clone(x), dim, index, src)
- @register_lowering(aten.scatter_add_, type_promotion_kind=None)
- def scatter_add_(x, dim: int, index, src):
- return scatter_reduce_(clone(x), dim, index, src, "sum")
- @register_lowering(aten.scatter_reduce, type_promotion_kind=None)
- def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
- return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
- fallback_scatter_reduce_ = fallback_handler(aten.scatter_reduce_)
- @register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
- def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
- assert reduce in {None, "sum", "prod", "mean", "amax", "amin"}
- fallback_result = scatter_fallback(
- aten.scatter_reduce_,
- self,
- dim,
- index,
- src,
- reduce=reduce,
- include_self=include_self,
- )
- if fallback_result:
- return fallback_result
- assert isinstance(self, TensorBox)
- assert "int" in str(index.get_dtype())
- ndim = len(self.get_size())
- if ndim == 0:
- self = view(self, [1])
- if isinstance(src, TensorBox) and len(src.get_size()) == 0:
- src = view(src, [1])
- if isinstance(index, TensorBox) and len(index.get_size()) == 0:
- index = view(index, [1])
- assert -len(self.get_size()) <= dim < len(self.get_size())
- self.realize()
- V.graph.realize_users_of(self.get_name())
- index_loader = index.make_loader()
- src_loader = src.make_loader() if isinstance(src, TensorBox) else None
- def output_indexer(idx):
- indirect_idx = list(idx)
- indirect_idx[dim] = ops.indirect_indexing(index_loader(idx))
- return indirect_idx
- def fn(idx):
- if src_loader:
- return src_loader(idx)
- else:
- # src is a scalar
- return ops.constant(src, self.get_dtype())
- def backend_reduce_str(reduce):
- if reduce == "sum":
- return "atomic_add"
- else:
- # TODO: Need to support more reduction type
- assert reduce is None
- return None
- if not include_self:
- # zero out the corresponding elements first
- zero_out = ir.Scatter(
- device=self.get_device(),
- dtype=self.get_dtype(),
- inner_fn=lambda index: ops.constant(0, self.get_dtype()),
- ranges=index.get_size(),
- output_indexer=output_indexer,
- scatter_mode=None,
- )
- buffer = ir.ComputedBuffer(
- None,
- ir.MutationLayout(self),
- zero_out,
- )
- buffer.name = V.graph.register_buffer(buffer)
- # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
- # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
- # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
- scatter = ir.Scatter(
- device=self.get_device(),
- dtype=self.get_dtype(),
- inner_fn=fn,
- ranges=index.get_size(),
- output_indexer=output_indexer,
- scatter_mode=backend_reduce_str(reduce),
- )
- buffer = ir.ComputedBuffer(
- None,
- ir.MutationLayout(self),
- scatter,
- )
- buffer.name = V.graph.register_buffer(buffer)
- if ndim == 0:
- self = view(self, [])
- return self
- def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2):
- x.realize_hint() # elements are reused
- x_loader = x.make_loader()
- i_sizes = x.get_size()[-n:]
- batch = x.get_size()[:-n]
- i_sizes = [V.graph.sizevars.guard_static_shape(i) for i in i_sizes]
- assert len(scales_x) == n
- o_sizes = output_size
- scales = [i / o for i, o in zip(i_sizes, o_sizes)]
- for i, scale in enumerate(scales):
- if scale:
- scales[i] = scale
- def scale(x, scale):
- x = ops.index_expr(x, torch.float32)
- x = ops.mul(x, ops.constant(scale, torch.float32))
- x = ops.to_dtype(x, torch.int32)
- return ops.indirect_indexing(x)
- def fn(idx):
- x = idx[-n:]
- b = idx[:-n]
- return x_loader([*b, *[scale(i, s) for i, s in zip(x, scales)]])
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=[*batch, *o_sizes],
- )
- @register_lowering(aten.upsample_nearest1d.default)
- def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
- return upsample_nearestnd(x, output_size, (scales,), n=1)
- @register_lowering(aten.upsample_nearest2d.default)
- def upsample_nearest2d(
- x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
- ):
- return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
- @register_lowering(aten.upsample_nearest3d.default)
- def upsample_nearest3d(
- x,
- output_size,
- scales_d: Optional[float] = None,
- scales_h: Optional[float] = None,
- scales_w: Optional[float] = None,
- ):
- return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
- @register_lowering(aten.upsample_bicubic2d.default)
- def upsample_bicubic2d_default(
- x,
- output_size,
- align_corners: bool,
- scales_h: Optional[float] = None,
- scales_w: Optional[float] = None,
- ):
- x.realize_hint()
- x_loader = x.make_loader()
- N, C, iH, iW = x.get_size()
- oH, oW = output_size
- iH = V.graph.sizevars.guard_static_shape(iH)
- iW = V.graph.sizevars.guard_static_shape(iW)
- def get_int_dtype(maxval):
- if maxval > torch.iinfo(torch.int32).max:
- return torch.int64
- return torch.int32
- def compute_scale(in_size, out_size, align_corners, scale=None):
- if align_corners:
- return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
- else:
- return 1 / scale if scale is not None and scale > 0 else in_size / out_size
- def compute_source_index(scale, dst_index, align_corners):
- dst_index_ie = ops.index_expr(dst_index, torch.float32)
- if align_corners:
- return ops.mul(scale, dst_index_ie)
- else:
- return ops.sub(
- ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
- ) # scale * (dst_index + 0.5) - 0.5
- def cubic_convolution1(x, A):
- # ((A + 2) * x - (A+3)) * x * x + 1
- return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
- def cubic_convolution2(x, A):
- # ((A * x - 5 * A) * x + 8 * A) * x - 4*A
- return ops.sub(
- ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
- )
- def get_cubic_upsample_coefficients(t):
- A = -0.75
- c0 = cubic_convolution2(ops.add(t, 1.0), A)
- c1 = cubic_convolution1(t, A)
- x2 = ops.sub(1.0, t)
- c2 = cubic_convolution1(x2, A)
- c3 = cubic_convolution2(ops.add(x2, 1.0), A)
- return (
- c0,
- c1,
- c2,
- c3,
- )
- def cubic_interp1d(xs, t):
- cs = get_cubic_upsample_coefficients(t)
- # dot product between xs and cs
- return ops.add(
- ops.mul(xs[0], cs[0]),
- ops.add(
- ops.mul(xs[1], cs[1]),
- ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
- ),
- )
- height_scale = compute_scale(iH, oH, align_corners, scales_h)
- width_scale = compute_scale(iW, oW, align_corners, scales_h)
- def clamp(v, min, max):
- return ops.maximum(min, ops.minimum(max, v))
- def fn(idx):
- n, c, oy, ox = idx
- real_x = compute_source_index(width_scale, ox, align_corners)
- in_x = ops.floor(real_x)
- t_x = ops.sub(real_x, in_x)
- real_y = compute_source_index(height_scale, oy, align_corners)
- in_y = ops.floor(real_y)
- t_y = ops.sub(real_y, in_y)
- def load_bounded(fy, fx):
- iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
- ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
- return x_loader([n, c, iy, ix])
- iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
- ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
- iys_ofs = tuple((ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)))
- ixs_ofs = tuple((ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)))
- def get_x_interp(y):
- coeffs_x = tuple((load_bounded(y, x) for x in ixs_ofs))
- return cubic_interp1d(coeffs_x, t_x)
- coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
- return cubic_interp1d(coeffs_y, t_y)
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
- )
- @register_lowering(aten.reflection_pad2d)
- def reflection_pad2d(x, padding):
- assert len(padding) == 4
- left, right, top, bot = padding
- x_loader = x.make_loader()
- *batch, h, w = x.get_size()
- h = V.graph.sizevars.guard_static_shape(h)
- w = V.graph.sizevars.guard_static_shape(w)
- def reflect(x, size, offset):
- size = ops.constant(size - 1, torch.int32)
- x = ops.index_expr(x, torch.int32)
- x = ops.sub(x, ops.constant(offset, torch.int32))
- x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x))))
- return ops.indirect_indexing(x)
- def fn(idx):
- *b, x, y = idx
- x = reflect(x, h, top)
- y = reflect(y, w, left)
- return x_loader([*b, x, y])
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=[*batch, sympy.Integer(h + top + bot), sympy.Integer(w + left + right)],
- )
- @register_lowering(aten.reflection_pad2d_backward)
- def reflection_pad2d_backward(grad_output, x, padding):
- assert len(padding) == 4
- left, right, top, bot = padding
- *_, h, w = x.get_size()
- h = V.graph.sizevars.guard_static_shape(h) - 1
- w = V.graph.sizevars.guard_static_shape(w) - 1
- grad_loader = grad_output.make_loader()
- def fn(idx):
- *b, x, y = idx
- def load_from_output(x, y):
- x = ops.indirect_indexing(ops.index_expr(x, torch.int32))
- y = ops.indirect_indexing(ops.index_expr(y, torch.int32))
- return grad_loader([*b, x, y])
- def index_range_condition(index_range):
- i, lb, ub = index_range
- i = ops.index_expr(i, torch.int32)
- return ops.and_(ops.ge(i, lb), ops.le(i, ub))
- def accumulate(out_x, out_y, index_range1, index_range2=None):
- nonlocal grad
- # If the upper bound is less than the lower bound, we can get rid of one accumulation.
- # This happens when the padding size is zero.
- if index_range1[2] < index_range1[1]:
- return
- cond = index_range_condition(index_range1)
- if index_range2 is not None:
- if index_range2[2] < index_range2[1]:
- return
- cond = ops.and_(cond, index_range_condition(index_range2))
- g = ops.masked(cond, lambda: load_from_output(out_x, out_y), 0.0)
- grad = ops.add(grad, g)
- # Areas after reflection:
- #
- # top-left | top | top-right
- # -----------------------------------------
- # left | center | right
- # -----------------------------------------
- # bottom-left | bottom | bottom-right
- #
- # The center area is the orignial matrix. Other areas are reflections.
- center_x, center_y = x + top, y + left
- top_reflect_x, left_reflect_y = top - x, left - y
- bot_reflect_x, right_reflect_y = 2 * h + top - x, 2 * w + left - y
- # Accumulate gradients from different areas
- grad = load_from_output(center_x, center_y)
- accumulate(center_x, left_reflect_y, (y, 1, left))
- accumulate(center_x, right_reflect_y, (y, w - right, w - 1))
- accumulate(top_reflect_x, center_y, (x, 1, top))
- accumulate(bot_reflect_x, center_y, (x, h - bot, h - 1))
- accumulate(top_reflect_x, left_reflect_y, (x, 1, top), (y, 1, left))
- accumulate(top_reflect_x, right_reflect_y, (x, 1, top), (y, w - right, w - 1))
- accumulate(bot_reflect_x, left_reflect_y, (x, h - bot, h - 1), (y, 1, left))
- accumulate(
- bot_reflect_x, right_reflect_y, (x, h - bot, h - 1), (y, w - right, w - 1)
- )
- return grad
- return Pointwise.create(
- device=grad_output.get_device(),
- dtype=grad_output.get_dtype(),
- inner_fn=fn,
- ranges=list(x.get_size()),
- )
- @register_lowering(prims.rev.default)
- def rev(x, dims):
- # note - dims pre-canoncalized
- x_loader = x.make_loader()
- sizes = x.get_size()
- def loader(idx):
- idx = list(idx)
- assert len(idx) == len(sizes)
- for dim in dims:
- idx[dim] = (sizes[dim] - 1) - idx[dim]
- return x_loader(idx)
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=loader,
- ranges=sizes,
- )
- @register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
- def constant_pad_nd(x, padding, fill_value=0):
- assert (len(padding) % 2) == 0
- if all(p == 0 for p in padding):
- return x
- sizes = x.get_size()
- bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
- n = len(sizes) - len(bounds)
- output_size = list(sizes[:n])
- mask_sizes = []
- for (low, high), size in zip(bounds, sizes[n:]):
- size = V.graph.sizevars.guard_static_shape(size)
- mask_sizes.append(size)
- output_size.append(sympy.expand(size + low + high))
- assert len(output_size) == len(sizes)
- fill_value = dtype_to_type(x.get_dtype())(fill_value)
- def mask(index):
- mask = []
- for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
- if low != 0:
- mask.append(range_mask_low(idx))
- if high != 0:
- mask.append(range_mask_high(idx, length))
- mask = functools.reduce(ops.and_, mask)
- return ops.masked(mask, lambda: x_loader(index), fill_value)
- def offset_fn(index):
- new_index = list(index[:n])
- for idx, (low, high) in zip(index[n:], bounds):
- new_index.append(idx - low)
- assert len(new_index) == len(index)
- return mask(new_index)
- x_loader = x.make_loader()
- return Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=offset_fn,
- ranges=output_size,
- )
- def range_mask_low(i: sympy.Expr):
- return ops.ge(
- ops.index_expr(i, torch.int64),
- ops.index_expr(sympy.Integer(0), torch.int64),
- )
- def range_mask_high(i: sympy.Expr, length: sympy.Expr):
- return ops.lt(
- ops.index_expr(i, torch.int64),
- ops.index_expr(length, torch.int64),
- )
- def range_mask(i: sympy.Expr, length: sympy.Expr):
- return ops.and_(
- range_mask_low(i),
- range_mask_high(i, length),
- )
- def constant_boundary_condition_2d(x, fill_value, padding):
- *_, h, w = x.get_size()
- x_loader = x.make_loader()
- def load(index):
- *prefix, ih, iw = index
- mask = ops.and_(
- range_mask(ih, h),
- range_mask(iw, w),
- )
- return ops.masked(mask, lambda: x_loader([*prefix, ih, iw]), fill_value)
- return load
- def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
- x_out = ir.FloorDiv(
- x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
- )
- if ceil_mode:
- x_alt = ir.FloorDiv(
- x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
- )
- if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
- # ceil mode is actually a no-op, lets guard on that
- V.graph.sizevars.guard_equals(x_out, x_alt)
- ceil_mode = False
- else:
- x_out = x_alt
- return x_out, ceil_mode
- fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices)
- @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
- def max_pool2d_with_indices(
- x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
- ):
- if padding == 0:
- padding = [0, 0]
- if not stride:
- stride = kernel_size
- assert dilation == 1 or all(d == 1 for d in dilation)
- assert isinstance(x, TensorBox)
- assert len(kernel_size) == 2
- assert len(stride) == 2
- assert len(padding) == 2
- assert len(x.get_size()) in (3, 4)
- x.realize_hint()
- *batch, h, w = x.get_size()
- h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
- w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
- if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
- x_loader = constant_boundary_condition_2d(x, float("-inf"), padding)
- else:
- x_loader = x.make_loader()
- new_size = list(batch) + [h_out, w_out]
- window_size = kernel_size[0] * kernel_size[1]
- if window_size > 25:
- # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
- return fallback_max_pool2d_with_indices(
- x, kernel_size, stride, padding, dilation, ceil_mode
- )
- def fn(idx, return_index):
- *prefix, bh, bw = idx
- maxval = None
- maxindex = None
- for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
- ih = bh * stride[0] + ih - padding[0]
- iw = bw * stride[1] + iw - padding[1]
- val = x_loader([*prefix, ih, iw])
- if return_index:
- index = ops.index_expr(ih * w + iw, torch.int64)
- if maxindex is None:
- maxindex = index
- else:
- maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
- if maxval is None:
- maxval = val
- else:
- maxval = ops.maximum(val, maxval)
- if return_index:
- return maxindex
- else:
- return maxval
- r1 = Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=functools.partial(fn, return_index=False),
- ranges=new_size,
- )
- r2 = Pointwise.create(
- device=x.get_device(),
- dtype=torch.int64,
- inner_fn=functools.partial(fn, return_index=True),
- ranges=new_size,
- )
- # TODO(jansel): should we force these to be realized?
- return r1, r2
- fallback_max_pool2d_with_indices_backward = fallback_handler(
- aten.max_pool2d_with_indices_backward
- )
- @register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
- def max_pool2d_with_indices_backward(
- grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
- ):
- if padding == 0:
- padding = [0, 0]
- if not stride:
- stride = kernel_size
- assert dilation == 1 or all(d == 1 for d in dilation)
- assert isinstance(x, TensorBox)
- assert len(kernel_size) == 2
- assert len(stride) == 2
- assert len(padding) == 2
- assert len(x.get_size()) in (3, 4)
- # we will read this many times, so make sure it is computed
- grad_output.realize_hint()
- try:
- gO_stride = grad_output.get_stride()
- except AttributeError:
- # some classes don't have `get_stride`
- # TODO will need a better way of determining if inputs are channels-last
- gO_stride = None
- if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):
- data = x.data.data
- x_buffer = ir.ComputedBuffer(
- name=None,
- layout=ir.FlexibleLayout(
- device=data.get_device(),
- dtype=data.get_dtype(),
- size=data.get_size(),
- ),
- data=data,
- )
- x_buffer.decide_layout()
- x_stride = x_buffer.get_stride()
- else:
- try:
- x_stride = x.get_stride()
- except AttributeError:
- x_stride = None
- if (
- (x_stride is not None and x_stride[1] == 1)
- or gO_stride is not None
- and gO_stride[1] == 1
- ):
- # don't codegen channels-last, it's very slow
- return fallback_max_pool2d_with_indices_backward(
- grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
- )
- indices.realize_hint()
- *batch, height, width = x.get_size()
- *_, pooled_height, pooled_width = grad_output.get_size()
- indices_loader = indices.make_loader()
- grad_loader = grad_output.make_loader()
- new_size = list(x.get_size())
- h_window_size = max(
- [
- max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
- for h in range(kernel_size[0] * 2)
- ]
- )
- w_window_size = max(
- [
- max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
- for w in range(kernel_size[1] * 2)
- ]
- )
- window_size = h_window_size * w_window_size
- if window_size > 25:
- # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
- return fallback_max_pool2d_with_indices_backward(
- grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
- )
- def fn(idx):
- *prefix, h, w = idx
- index_test = ops.index_expr(h * width + w, torch.int32)
- h = h + padding[0]
- w = w + padding[1]
- phstart = ops.index_expr(
- ir.FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
- )
- pwstart = ops.index_expr(
- ir.FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
- )
- phend = ops.index_expr(ir.FloorDiv(h, stride[0]) + 1, torch.int32)
- pwend = ops.index_expr(ir.FloorDiv(w, stride[1]) + 1, torch.int32)
- phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
- pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
- phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
- pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
- gradient = None
- for ph_ in range(h_window_size):
- for pw_ in range(w_window_size):
- ph = ops.add(phstart, ops.constant(ph_, torch.int32))
- pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
- grad_index = [
- *prefix,
- ops.indirect_indexing(
- ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32)))
- ),
- ops.indirect_indexing(
- ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32)))
- ),
- ]
- index_actual = indices_loader(grad_index)
- grad_part = grad_loader(grad_index)
- check = ops.eq(index_actual, index_test)
- if gradient is None:
- # don't need mask for 0, 0
- gradient = ops.where(
- check, grad_part, ops.constant(0.0, torch.float32)
- )
- else:
- mask = ops.and_(
- ops.and_(
- ops.lt(ph, phend),
- ops.lt(pw, pwend),
- ),
- check,
- )
- gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
- assert gradient is not None
- return gradient
- return Pointwise.create(
- device=grad_output.get_device(),
- dtype=grad_output.get_dtype(),
- inner_fn=fn,
- ranges=new_size,
- )
- def pad_adaptive_loader(x):
- *_, h, w = x.get_size()
- x_loader = x.make_loader()
- def load(prefix, increments, start_indices, end_indices):
- ih, iw = increments
- h_start_index, w_start_index = start_indices
- h_end_index, w_end_index = end_indices
- mask = ops.and_(
- ops.lt(
- ops.index_expr(h_start_index + ih, torch.int64),
- ops.index_expr(h_end_index, torch.int64),
- ),
- ops.lt(
- ops.index_expr(w_start_index + iw, torch.int64),
- ops.index_expr(w_end_index, torch.int64),
- ),
- )
- return ops.masked(
- mask,
- lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
- 0.0,
- )
- return load
- def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
- h_start_index_fn, w_start_index_fn = start_index_fns
- h_end_index_fn, w_end_index_fn = end_index_fns
- def fn_sum(idx, loader):
- *prefix, bh, bw = idx
- h_start_index = h_start_index_fn(bh)
- h_end_index = h_end_index_fn(bh)
- w_start_index = w_start_index_fn(bw)
- w_end_index = w_end_index_fn(bw)
- total = None
- for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
- val = loader(
- prefix,
- [ih, iw],
- [h_start_index, w_start_index],
- [h_end_index, w_end_index],
- )
- if total is None:
- total = val
- else:
- total = ops.add(val, total)
- return total
- return fn_sum
- fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d)
- @register_lowering(aten._adaptive_avg_pool2d)
- def _adaptive_avg_pool2d(x, output_size):
- assert isinstance(x, TensorBox)
- assert len(output_size) == 2
- x.realize_hint()
- *batch, h_in, w_in = x.get_size()
- h_in = V.graph.sizevars.guard_static_shape(h_in)
- w_in = V.graph.sizevars.guard_static_shape(w_in)
- h_out, w_out = output_size
- # no-op if the same input and output
- if h_in == h_out and w_in == w_out:
- return clone(x)
- if h_in % h_out == 0 and w_in % w_out == 0:
- kernel_size = [h_in // h_out, w_in // w_out]
- return avg_pool2d(x, kernel_size)
- h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
- w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
- new_size = list(batch) + [h_out, w_out]
- dtype = x.get_dtype()
- def start_index(index, out_dim, inp_dim):
- return ir.FloorDiv((index * inp_dim), out_dim)
- def end_index(index, out_dim, inp_dim):
- return ir.FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
- h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
- h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
- w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
- w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
- window_size = h_kernel_max * w_kernel_max
- if window_size > 25:
- # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
- return fallback_adaptive_avg_pool2d(x, output_size)
- fn_sum = _adaptive_pooling_idx_sum(
- [h_kernel_max, w_kernel_max],
- [h_start_index, w_start_index],
- [h_end_index, w_end_index],
- )
- ones_loader = pad_adaptive_loader(ones_like(x))
- def fn(idx):
- return ops.div(fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader))
- rv = Pointwise.create(
- device=x.get_device(),
- dtype=dtype,
- inner_fn=fn,
- ranges=new_size,
- )
- # TODO: should we force these to be realized?
- return rv
- @register_lowering(aten.upsample_nearest2d_backward.default)
- def upsample_nearest2d_backward(
- x, output_size=None, input_size=None, scales_h=None, scales_w=None
- ):
- x.realize_hint()
- *batch, inp_h, inp_w = x.get_size()
- inp_h = V.graph.sizevars.guard_static_shape(inp_h)
- inp_w = V.graph.sizevars.guard_static_shape(inp_w)
- *batch, out_h, out_w = input_size
- if inp_h % out_h == 0 and inp_w % out_w == 0:
- return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1)
- h_kernel_max = ceildiv(inp_h, out_h)
- w_kernel_max = ceildiv(inp_w, out_w)
- def start_index(index, out_dim, inp_dim):
- return ir.CeilDiv(index * inp_dim, out_dim)
- def end_index(index, out_dim, inp_dim):
- return start_index((index + 1), out_dim, inp_dim)
- h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
- h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
- w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
- w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
- fn_sum = _adaptive_pooling_idx_sum(
- [h_kernel_max, w_kernel_max],
- [h_start_index, w_start_index],
- [h_end_index, w_end_index],
- )
- def fn(idx):
- return fn_sum(idx, pad_adaptive_loader(x))
- rv = Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=fn,
- ranges=list(input_size),
- )
- return rv
- fallback_avg_pool2d = fallback_handler(aten.avg_pool2d)
- @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
- def avg_pool2d(
- x,
- kernel_size,
- stride=(),
- padding=0,
- ceil_mode=False,
- count_include_pad=True,
- divisor_override=None,
- ):
- if not stride:
- stride = kernel_size
- if not padding:
- padding = [0, 0]
- assert isinstance(x, TensorBox)
- assert len(kernel_size) == 2
- assert len(stride) == 2
- assert len(padding) == 2
- assert len(x.get_size()) in (3, 4)
- x.realize_hint()
- *batch, h, w = x.get_size()
- h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
- w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
- if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
- x_loader = constant_boundary_condition_2d(x, 0.0, padding)
- had_padding = True
- else:
- x_loader = x.make_loader()
- had_padding = False
- new_size = list(batch) + [h_out, w_out]
- dtype = x.get_dtype()
- window_size = kernel_size[0] * kernel_size[1]
- if window_size > 25:
- # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
- return fallback_avg_pool2d(
- x,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override,
- )
- def fn_sum(idx, loader):
- *prefix, bh, bw = idx
- total = None
- for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
- ih = bh * stride[0] + ih - padding[0]
- iw = bw * stride[1] + iw - padding[1]
- val = loader([*prefix, ih, iw])
- if total is None:
- total = val
- else:
- total = ops.add(val, total)
- return total
- if count_include_pad or not had_padding or divisor_override:
- if divisor_override:
- scale = 1 / divisor_override
- else:
- scale = 1.0 / (kernel_size[0] * kernel_size[1])
- def fn(idx):
- return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))
- else:
- ones_loader = constant_boundary_condition_2d(ones_like(x), 0.0, padding)
- def fn(idx):
- # TODO(jansel): optimize to do `int(x<h)` rather than `x<h?1:0`
- return ops.div(fn_sum(idx, x_loader), fn_sum(idx, ones_loader))
- rv = Pointwise.create(
- device=x.get_device(),
- dtype=dtype,
- inner_fn=fn,
- ranges=new_size,
- )
- # TODO(jansel): should we force these to be realized?
- return rv
- fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward)
- @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
- def avg_pool2d_backward(
- grad_output,
- x,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override=None,
- ):
- assert not divisor_override
- if not stride:
- stride = kernel_size
- if not padding:
- padding = [0, 0]
- assert isinstance(grad_output, TensorBox)
- assert isinstance(x, TensorBox)
- assert len(kernel_size) == 2
- assert len(stride) == 2
- assert len(padding) == 2
- assert len(x.get_size()) in (3, 4)
- grad_output.realize_hint() # we will read this many times, so make sure it is computed
- *batch, height, width = x.get_size()
- h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode)
- w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)
- grad_loader = grad_output.make_loader()
- had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2
- *_, pooled_height, pooled_width = grad_output.get_size()
- new_size = list(x.get_size())
- dtype = x.get_dtype()
- h_window_size = max(
- [
- max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
- for h in range(kernel_size[0] * 2)
- ]
- )
- w_window_size = max(
- [
- max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
- for w in range(kernel_size[1] * 2)
- ]
- )
- window_size = h_window_size * w_window_size
- if window_size > 25:
- # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
- return fallback_avg_pool2d_backward(
- grad_output,
- x,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override,
- )
- def compute_pool_size_without_padding(ph, pw):
- """
- This computes the scaling factor that we will divide an element
- by when `count_include_pad=False`
- """
- stride_h = ops.constant(stride[0], torch.int32)
- stride_w = ops.constant(stride[1], torch.int32)
- pad_h = ops.constant(padding[0], torch.int32)
- pad_w = ops.constant(padding[1], torch.int32)
- kernel_h = ops.constant(kernel_size[0], torch.int32)
- kernel_w = ops.constant(kernel_size[1], torch.int32)
- hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
- wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
- hend = ops.minimum(
- ops.add(hstart, kernel_h),
- ops.add(ops.index_expr(height, torch.int32), pad_h),
- )
- wend = ops.minimum(
- ops.add(wstart, kernel_w),
- ops.add(ops.index_expr(width, torch.int32), pad_w),
- )
- hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
- wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
- hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
- wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
- divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
- return divide_factor
- def fn(idx):
- *prefix, h, w = idx
- h = h + padding[0]
- w = w + padding[1]
- phstart = ops.index_expr(
- ir.FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
- )
- pwstart = ops.index_expr(
- ir.FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
- )
- phend = ops.index_expr(ir.FloorDiv(h, stride[0]) + 1, torch.int32)
- pwend = ops.index_expr(ir.FloorDiv(w, stride[1]) + 1, torch.int32)
- phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
- pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
- phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
- pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
- gradient = None
- for ph_ in range(h_window_size):
- for pw_ in range(w_window_size):
- ph = ops.add(phstart, ops.constant(ph_, torch.int32))
- pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
- if count_include_pad or not had_padding:
- scale = kernel_size[0] * kernel_size[1]
- else:
- scale = compute_pool_size_without_padding(ph, pw)
- part = ops.truediv(
- grad_loader(
- [
- *prefix,
- ops.indirect_indexing(
- ops.minimum(
- ph, ops.sub(phend, ops.constant(1, torch.int32))
- )
- ),
- ops.indirect_indexing(
- ops.minimum(
- pw, ops.sub(pwend, ops.constant(1, torch.int32))
- )
- ),
- ]
- ),
- scale,
- )
- mask = ops.and_(
- ops.lt(ph, phend),
- ops.lt(pw, pwend),
- )
- if gradient is None:
- gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
- else:
- gradient = ops.where(mask, ops.add(gradient, part), gradient)
- assert gradient is not None
- return gradient
- rv = Pointwise.create(
- device=grad_output.get_device(),
- dtype=dtype,
- inner_fn=fn,
- ranges=new_size,
- )
- return rv
- def _validate_reduction_axis(x, axis):
- size = x.get_size()
- if isinstance(axis, int):
- axis = [axis]
- elif not axis:
- axis = range(len(size))
- axis = list(axis)
- for i in range(len(axis)):
- if axis[i] < 0:
- axis[i] += len(size) if len(size) else 1
- assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
- assert len(set(axis)) == len(axis), "reduction axis not unique"
- return axis
- def make_reduction(reduction_type: str, override_return_dtype=None):
- def inner(x, axis=None, keepdims=False, *, dtype=None):
- if reduction_type == "min" and axis is not None:
- return (
- reduce_amin(x, axis, keepdims, dtype=dtype),
- reduce_argmin(x, axis, keepdims),
- )
- if reduction_type == "max" and axis is not None:
- return (
- reduce_amax(x, axis, keepdims, dtype=dtype),
- reduce_argmax(x, axis, keepdims),
- )
- if dtype is not None:
- x = to_dtype(x, dtype)
- if reduction_type == "any":
- x = to_dtype(x, torch.bool)
- size = x.get_size()
- axis = set(_validate_reduction_axis(x, axis))
- kept_sizes = []
- kept_idx = []
- reduced_sizes = []
- reduced_idx = []
- for i in range(len(size)):
- if i in axis:
- reduced_idx.append(i)
- reduced_sizes.append(size[i])
- else:
- kept_idx.append(i)
- kept_sizes.append(size[i])
- def loader(index, reduction_index):
- assert len(reduction_index) == len(reduced_idx)
- if keepdims:
- assert len(index) == len(size)
- assert all(index[i] == 0 for i in reduced_idx)
- index = [index[i] for i in kept_idx]
- assert len(index) == len(kept_idx)
- new_index = [None] * (len(index) + len(reduction_index))
- for idx, var in itertools.chain(
- zip(kept_idx, index), zip(reduced_idx, reduction_index)
- ):
- new_index[idx] = var
- return inner_loader(new_index)
- if keepdims:
- new_size = list(size)
- for i in reduced_idx:
- new_size[i] = sympy.Integer(1)
- else:
- new_size = kept_sizes
- inner_loader = x.make_loader()
- result = Reduction.create(
- device=x.get_device(),
- dst_dtype=override_return_dtype or x.get_dtype(),
- src_dtype=x.get_dtype(),
- inner_fn=loader,
- ranges=new_size,
- reduction_ranges=reduced_sizes,
- reduction_type={"amax": "max", "amin": "min"}.get(
- reduction_type, reduction_type
- ),
- )
- if isinstance(
- result.data.data, Reduction
- ): # Only realize if reduction isn't unrolled
- result.realize()
- return result
- return inner
- @register_lowering(aten.mean)
- def mean(x, axis=None, keepdim=False, *, dtype=None):
- if dtype is not None:
- x = to_dtype(x, dtype)
- size = x.get_size()
- axis = _validate_reduction_axis(x, axis)
- # compute in higher-precision until end of mean lowering
- output_dtype = x.get_dtype()
- if output_dtype in (torch.float16, torch.bfloat16):
- x = to_dtype(x, torch.float)
- sum_result = sum_(x, axis, keepdim)
- denom = sympy_product(size[i] for i in axis)
- denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
- denom = ExpandView.create(denom, list(sum_result.get_size()))
- return to_dtype(div(sum_result, denom), output_dtype)
- def var_mean_(x, axis, correction, keepdim, return_mean):
- if correction is None:
- correction = 1
- size = x.get_size()
- axis = _validate_reduction_axis(x, axis)
- x_mean = mean(x, axis, keepdim=True)
- if return_mean:
- x_mean.realize()
- diffs = square(sub(x, x_mean))
- sum_result = sum_(diffs, axis, keepdim)
- denom = sympy_product(size[i] for i in axis)
- if correction:
- denom = denom - correction
- denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
- denom = ExpandView.create(denom, list(sum_result.get_size()))
- x_var = div(sum_result, denom)
- if not return_mean:
- return x_var
- x_mean = x_mean if keepdim else squeeze(x_mean, axis)
- return x_var, x_mean
- @register_lowering([aten.var, prims.var])
- def var_(x, axis=None, *, correction=None, keepdim=False):
- return var_mean_(
- x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
- )
- @register_lowering(aten.var_mean)
- def var_mean(x, axis=None, *, correction=None, keepdim=False):
- return var_mean_(
- x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
- )
- def pow_recursive(x, y, dtype):
- if y < 0:
- return pow_recursive(ops.reciprocal(x), -y, dtype)
- if y == 0:
- return ops.constant(1, dtype)
- if y == 1:
- return x
- result = pow_recursive(x, y // 2, dtype)
- result = ops.mul(result, result)
- if (y % 2) == 1:
- result = ops.mul(result, x)
- return result
- @make_pointwise
- def pow_native(a, b):
- return ops.pow(a, b)
- def _is_ir_node_and_cuda(x):
- if isinstance(x, ir.IRNode) and decode_device(x.get_device()).type == "cuda":
- return True
- return False
- @register_lowering(aten.pow, broadcast=True)
- def pow(a, b):
- if _is_ir_node_and_cuda(a) and _is_ir_node_and_cuda(b):
- assert a.get_dtype() in (
- torch.float16,
- torch.float32,
- torch.float64,
- ), "Pow input must be floating point."
- if isinstance(b, float) and b == int(b):
- return pow(a, int(b))
- elif isinstance(b, float) and b == 0.5:
- return sqrt(a)
- elif isinstance(b, int) and b == 1:
- return a
- elif isinstance(b, int) and -32 < b < 32:
- # Optimize away small fixed powers
- loader = a.make_loader()
- def fn(idx):
- return pow_recursive(loader(idx), b, a.get_dtype())
- return Pointwise.create(
- device=a.get_device(),
- dtype=a.get_dtype(),
- inner_fn=fn,
- ranges=a.get_size(),
- )
- if isinstance(a, Number):
- if a == 1:
- return full_like(b, 1)
- if a == 2 and is_float_dtype(b.get_dtype()):
- return exp2(b)
- return pow_native(a, b)
- def mutate_to(changed, val):
- if isinstance(changed, TensorBox):
- changed_data = changed.data
- else:
- changed_data = changed
- if isinstance(val, TensorBox):
- val = val.data
- if not isinstance(val, ir.StorageBox):
- # introduce a copy to handle views
- val = Pointwise.create(
- device=changed.get_device(),
- dtype=changed.get_dtype(),
- inner_fn=val.make_loader(),
- ranges=changed.get_size(),
- ).data
- assert isinstance(val, ir.StorageBox)
- if isinstance(changed_data, ir.StorageBox) and not (
- changed_data.is_input_buffer() or isinstance(changed_data.data, ir.NopKernel)
- ):
- # Fast path, just swing the data pointer
- val.realize()
- changed_data.data = val.data
- return changed
- ir.MutationLayout.realize_into(val, changed_data)
- return changed
- @register_lowering(aten.fill_)
- def fill_(x, fill_value):
- return mutate_to(x, full_like(x, fill_value))
- @register_lowering(aten.copy_, type_promotion_kind=None)
- def copy_(dst, src, non_blocking=False):
- src = to_device(src, dst.get_device())
- src = to_dtype(src, dst.get_dtype())
- src = expand(src, dst.get_size())
- return mutate_to(dst, src)
- @make_pointwise
- def floordiv(a, b):
- return ops.floordiv(a, b)
- @make_pointwise
- def truncdiv(a, b):
- return ops.truncdiv(a, b)
- @register_lowering(aten.div, broadcast=True)
- def div_mode(a, b, rounding_mode=None):
- both_integer = is_integer_type(a) and is_integer_type(b)
- both_boolean = is_boolean_type(a) and is_boolean_type(b)
- # floordiv and truncdiv need special handling for integer tensors on Triton,
- # see the discussion at https://github.com/openai/triton/issues/605
- if rounding_mode == "floor":
- assert not both_boolean, "floordiv operands can not be boolean at the same time"
- return floordiv(a, b) if both_integer else floor(div(a, b))
- if rounding_mode == "trunc":
- assert not both_boolean, "truncdiv operands can not be boolean at the same time"
- return truncdiv(a, b) if both_integer else trunc(div(a, b))
- return div(a, b)
- @register_lowering([aten.mul], broadcast=True)
- def mul(a, b):
- both_bool = is_boolean_type(a) and is_boolean_type(b)
- if both_bool:
- return logical_and(a, b)
- else:
- fn = ops_wrapper(aten.mul.__name__)
- return make_pointwise(fn)(a, b)
- # NOTE: prims.div maps to a / b in C, so performs truncation division on
- # integer inputs and true division for floating and complex inputs.
- @register_lowering([prims.div], broadcast=True)
- def div_prim(a, b):
- is_integral = is_boolean_type(a) or is_integer_type(a)
- if is_integral:
- return truncdiv(a, b)
- def fn(*args):
- return ops.div(*args)
- return make_pointwise(fn)(a, b)
- div = register_lowering(
- [aten.true_divide, aten.div.Tensor],
- broadcast=True,
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- )(div_prim)
- @register_lowering([aten.fmod, prims.fmod], broadcast=True)
- def fmod(a, b):
- is_integral = is_boolean_type(a) or is_integer_type(a)
- if is_integral:
- def fn(a, b):
- return ops.mod(a, b)
- else:
- def fn(a, b):
- return ops.fmod(a, b)
- return make_pointwise(fn)(a, b)
- @register_lowering(aten.rsqrt)
- def rsqrt(x):
- dtype = x.get_dtype()
- if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
- x = to_dtype(x, torch.get_default_dtype())
- def _rsqrt(x):
- return ops.rsqrt(x)
- return make_pointwise(_rsqrt)(x)
- @register_lowering([aten.sum, prims.sum])
- def sum_(x, axis=None, keepdims=False, *, dtype=None):
- if (
- is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
- ) and dtype is None:
- dtype = torch.int64
- fn = make_reduction("sum", override_return_dtype=dtype)
- return fn(x, axis, keepdims, dtype=dtype)
- register_lowering(aten.max)(make_reduction("max"))
- register_lowering(aten.min)(make_reduction("min"))
- reduce_amax = register_lowering(aten.amax)(make_reduction("amax"))
- reduce_amin = register_lowering(aten.amin)(make_reduction("amin"))
- register_lowering(aten.any)(make_reduction("any", override_return_dtype=torch.bool))
- reduce_argmax = register_lowering(aten.argmax)(
- make_reduction("argmax", override_return_dtype=torch.int64)
- )
- reduce_argmin = register_lowering(aten.argmin)(
- make_reduction("argmin", override_return_dtype=torch.int64)
- )
- add = register_pointwise(
- aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
- )
- def register_pointwise_numeric(op):
- return register_pointwise(
- op, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
- )
- def register_pointwise_numeric_ldf64(op):
- return register_pointwise(
- op,
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
- use_libdevice_for_f64=True,
- )
- exp = register_pointwise_numeric_ldf64(aten.exp)
- exp2 = register_pointwise_numeric(aten.exp2)
- expm1 = register_pointwise_numeric(aten.expm1)
- relu = register_pointwise(aten.relu)
- sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
- sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
- square = register_pointwise(aten.square)
- sub = register_pointwise(aten.sub, allow_alpha=True)
- register_pointwise_numeric_ldf64(aten.cos)
- register_pointwise_numeric_ldf64(aten.sin)
- register_pointwise(aten.abs)
- register_pointwise(aten.bitwise_and)
- register_pointwise(aten.bitwise_not, override_fn_when_input_bool="logical_not")
- register_pointwise(aten.bitwise_or)
- register_pointwise(aten.bitwise_xor)
- register_pointwise(aten.bitwise_left_shift)
- # TODO(fdrocha): once https://github.com/openai/triton/pull/1153 is merged and we advance the triton pin past it
- # this should be uncommented
- # register_pointwise(aten.bitwise_right_shift)
- register_pointwise_numeric(aten.lgamma)
- erf = register_pointwise_numeric(aten.erf)
- register_lowering(
- aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
- )(erf)
- register_pointwise_numeric(aten.log1p)
- register_pointwise_numeric(aten.tan)
- register_pointwise_numeric(aten.tanh)
- register_pointwise_numeric_ldf64(aten.log)
- register_pointwise(aten.logical_not, convert_input_to_bool=True)
- maximum = register_pointwise(aten.maximum)
- minimum = register_pointwise(aten.minimum)
- register_lowering(aten.clamp_min)(maximum)
- register_lowering(aten.clamp_max)(minimum)
- register_pointwise(aten.neg)
- register_pointwise_numeric(aten.reciprocal)
- register_pointwise(aten.remainder)
- register_pointwise(aten.sign, override_fn_when_input_bool="identity")
- register_pointwise(aten.ceil)
- register_pointwise(aten.signbit, override_return_dtype=torch.bool)
- register_pointwise(aten.le, type_promotion_kind=None, override_return_dtype=torch.bool)
- register_pointwise(aten.lt, type_promotion_kind=None, override_return_dtype=torch.bool)
- register_pointwise(aten.ge, type_promotion_kind=None, override_return_dtype=torch.bool)
- register_pointwise(aten.gt, type_promotion_kind=None, override_return_dtype=torch.bool)
- register_pointwise(aten.eq, type_promotion_kind=None, override_return_dtype=torch.bool)
- register_pointwise(aten.ne, type_promotion_kind=None, override_return_dtype=torch.bool)
- logical_and = register_pointwise(
- aten.logical_and,
- type_promotion_kind=None,
- convert_input_to_bool=True,
- override_return_dtype=torch.bool,
- )
- register_lowering(aten.__and__, type_promotion_kind=None)(logical_and)
- register_lowering(aten.__or__, type_promotion_kind=None)(
- register_pointwise(
- aten.logical_or,
- type_promotion_kind=None,
- convert_input_to_bool=True,
- override_return_dtype=torch.bool,
- )
- )
- logical_xor = register_pointwise(
- aten.logical_xor,
- name="bitwise_xor",
- type_promotion_kind=None,
- convert_input_to_bool=True,
- override_return_dtype=torch.bool,
- )
- register_lowering(aten.__xor__, type_promotion_kind=None)(logical_xor)
- register_pointwise_numeric(aten.cosh)
- register_pointwise_numeric(aten.sinh)
- register_pointwise_numeric(aten.acos)
- register_pointwise_numeric(aten.acosh)
- register_pointwise_numeric(aten.asin)
- register_pointwise_numeric(aten.asinh)
- register_pointwise_numeric(aten.atan2)
- register_pointwise_numeric(aten.atan)
- register_pointwise_numeric(aten.atanh)
- register_pointwise_numeric(aten.copysign)
- register_pointwise_numeric(aten.erfc)
- register_pointwise_numeric(aten.hypot)
- register_pointwise_numeric(aten.log10)
- register_pointwise_numeric(aten.nextafter)
- def register_inplace(aten_op, outplace_op):
- @register_lowering(aten_op, type_promotion_kind=None)
- def fn(*args, **kwargs):
- result = outplace_op(*args, **kwargs)
- result = to_dtype(result, args[0].get_dtype())
- return mutate_to(args[0], result)
- return fn
- register_inplace(aten.add_, add)
- register_inplace(aten.mul_, mul)
- register_inplace(aten.div_.Tensor, div)
- register_inplace(aten.div_.Tensor_mode, div_mode)
- register_inplace(aten.sub_, sub)
- register_inplace(aten.relu_, relu)
- register_inplace(aten.sigmoid_, sigmoid)
- @register_lowering(aten.sym_size)
- def sym_size(a, dim):
- return a.get_size()[dim]
- @register_lowering(aten.sym_stride)
- def sym_stride(a, dim):
- return a.get_stride()[dim]
- @register_lowering(aten.sym_numel)
- def sym_numel(a):
- return a.get_numel()
- for method, func in magic_methods.items():
- register_lowering(method_to_operator(method))(func)
- @register_lowering(aten._foobar)
- def foobar(self, *args, **kwargs):
- raise NotImplementedError("Helpful for debugging")
- @register_lowering(torch.ops._inductor_test.realize)
- def _realize(x):
- x.realize()
- return clone(x)
- # populate lowerings defined in kernel/*
- from . import kernel
- import_submodule(kernel)
|