# lowering.py

import functools
import itertools
import logging
from collections.abc import Iterable
from typing import List, Optional, Tuple

import sympy

import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._prims_common import (
    canonicalize_dims,
    dtype_to_type,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
    Number,
)
from torch.fx.experimental.symbolic_shapes import magic_methods, method_to_operator

from .._dynamo.utils import import_submodule
from . import config, ir, overrides, test_operators  # NOQA: F401
from .cuda_properties import current_device
from .decomposition import decompositions, get_decompositions
from .ir import (
    ExpandView,
    IndexingConstant,
    PermuteView,
    Pointwise,
    Reduction,
    SqueezeView,
    TensorBox,
    validate_ir,
    View,
)
from .utils import ceildiv, developer_warning, sympy_product
from .virtualized import ops, V

log = logging.getLogger(__name__)
lowerings = {}
layout_constraints = {}
fallbacks = set()
aten = torch.ops.aten
prims = torch.ops.prims
needs_realized_inputs = set()


def add_needs_realized_inputs(fn):
    if isinstance(fn, (list, tuple, set)):
        return [add_needs_realized_inputs(x) for x in fn]
    needs_realized_inputs.add(fn)
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload in fn.overloads():
            needs_realized_inputs.add(getattr(fn, overload))


def add_layout_constraint(fn, constraint):
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload in fn.overloads():
            layout_constraints[getattr(fn, overload)] = constraint
    else:
        layout_constraints[fn] = constraint


add_needs_realized_inputs(
    [
        aten.as_strided,
        aten.avg_pool2d,
        aten.avg_pool2d_backward,
        aten.bmm,
        aten.convolution,
        aten.convolution_backward,
        aten.max_pool2d_with_indices,
        aten.max_pool2d_with_indices_backward,
        aten.mm,
        aten.upsample_bilinear2d,
        aten.upsample_nearest2d,
        aten.upsample_bicubic2d,
    ]
)

# TODO(jansel): ezyang says we won't need this in the future, try removing it
# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
DTYPE_ID_LOOKUP = {
    0: torch.uint8,
    1: torch.int8,
    2: torch.int16,
    3: torch.int32,
    4: torch.int64,
    5: torch.float16,
    6: torch.float32,
    7: torch.float64,
    8: torch.complex32,
    9: torch.complex64,
    10: torch.complex128,
    11: torch.bool,
    15: torch.bfloat16,
    # TODO(jansel): add quantized types?
    #  _(c10::qint8, QInt8) /* 12 */
    #  _(c10::quint8, QUInt8) /* 13 */
    #  _(c10::qint32, QInt32) /* 14 */
    #  _(c10::quint4x2, QUInt4x2) /* 16 */
    #  _(c10::quint2x4, QUInt2x4) /* 17 */
}


def decode_dtype(dtype: int):
    if not isinstance(dtype, int):
        return dtype
    assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
    dtype = DTYPE_ID_LOOKUP[dtype]
    return dtype
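

# Illustrative example (not part of the original module): per the table above,
# decode_dtype(6) returns torch.float32, while anything that is already a
# torch.dtype is passed through unchanged because non-int inputs return early.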


def is_integer_type(x):
    if isinstance(x, TensorBox):
        return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
    else:
        return isinstance(x, int)


def is_boolean_type(x):
    if isinstance(x, TensorBox):
        return is_boolean_dtype(x.get_dtype())
    else:
        return isinstance(x, bool)


def decode_device(device):
    if device is None:
        return torch.tensor(0.0).device  # default device
    if isinstance(device, str):
        device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        return torch.device("cuda", index=current_device())
    return device


def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
    def construct_input(inp):
        if isinstance(inp, Number):
            return inp
        else:
            assert hasattr(inp, "get_dtype")
            dim = len(inp.get_size())
            # construct a tmp tensor to feed into torch.result_type
            return torch.zeros([1] * dim, dtype=inp.get_dtype())

    inps = [construct_input(arg) for arg in args]
    _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
    return dtype


def _register_lowering(
    aten_fn, decomp_fn, broadcast, type_promotion_kind, convert_input_to_bool
):
  136. """
  137. Add a lowering to lowerings dict
  138. Arguments:
  139. aten_fn: torch.ops.aten.* fn we are lowering
  140. decomp_fn: alternate implementation on our IR
  141. broadcast: True to apply broadcasting to tensor inputs
  142. type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
  143. convert_input_to_bool: some logical ops require inputs are converted to bool
  144. """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        args = list(args)
        unpacked = False
        # TODO maybe we need to use pytrees here
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            unpacked = True
            args = args[0]
        # Only look at args that are Tensors
        indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
        # explicitly assert for "out=" ops for better error messages
        assert not any(
            x == "out" for x in kwargs.keys()
        ), "out= ops aren't yet supported"
        # kwargs tensors not supported yet unless it's a fallback op
        assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
            fn in fallbacks for fn in aten_fn
        )

        if (type_promotion_kind or convert_input_to_bool) and indices:
            if convert_input_to_bool:
                dtype = torch.bool
            else:
                # FIXME that's a crude approximation for promoting args
                promoting_args = [
                    a for a in args if isinstance(a, Number) or hasattr(a, "get_dtype")
                ]
                dtype = get_promoted_dtype(
                    *promoting_args, type_promotion_kind=type_promotion_kind
                )
            # sometimes args are an immutable list so we can't mutate them
            new_args = []
            for i in range(len(args)):
                if i in indices:
                    new_args.append(to_dtype(args[i], dtype))
                elif isinstance(args[i], ir.Constant):
                    new_args.append(
                        ir.Constant(args[i].value, dtype, args[indices[0]].get_device())
                    )
                else:
                    new_args.append(args[i])
            args = new_args
        if unpacked:
            args = [args]
        if broadcast and indices:
            for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
                args[i] = x
            for i in range(len(args)):
                if isinstance(args[i], ir.Constant):
                    args[i] = ExpandView.create(
                        args[i], list(args[indices[0]].get_size())
                    )

        out = decomp_fn(*args, **kwargs)
        validate_ir(out)
        return out

    if not isinstance(aten_fn, (list, tuple)):
        aten_fn = [aten_fn]
    else:
        aten_fn = list(aten_fn)

    for fn in list(aten_fn):
        if isinstance(fn, torch._ops.OpOverloadPacket):
            for overload in fn.overloads():
                other_fn = getattr(fn, overload)
                if other_fn not in lowerings:
                    aten_fn.append(other_fn)

    lowerings.update({fn: wrapped for fn in aten_fn})
    return wrapped


def register_lowering(
    aten_fn,
    broadcast=False,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
):
    """
    Shim to support decorator syntax.
    """
    return functools.partial(
        _register_lowering,
        aten_fn,
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        convert_input_to_bool=convert_input_to_bool,
    )
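

# Example (illustrative only, not part of the original module): the decorator
# form used throughout this file. `aten.sign` is merely a stand-in for whatever
# op is being lowered; the body rebuilds the op out of our IR primitives
# (ops_wrapper and make_pointwise are defined further below).
#
#     @register_lowering(aten.sign, broadcast=True)
#     def sign(x):
#         return make_pointwise(ops_wrapper("sign"))(x)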


def broadcast_symbolic_shapes(a, b):
    """
    Broadcasting logic based on symbolic shapes.

    We give the shapes 0 and 1 concrete values, while all other shapes
    are symbolic sympy formulas.
    """
    output = []
    for a, b in itertools.zip_longest(
        reversed(a), reversed(b), fillvalue=sympy.Integer(1)
    ):
        if b == 1:
            output.append(a)
        elif a == 1:
            output.append(b)
        else:
            V.graph.sizevars.guard_equals(a, b)
            if len(sympy.expand(b).free_symbols) < len(sympy.expand(a).free_symbols):
                output.append(b)  # prefer shorter formula
            else:
                output.append(a)
    return tuple(reversed(output))
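

# For example (illustrative), broadcasting (s0, 1) against (1, 8) yields
# (s0, 8): size-1 dims defer to the other operand, and any remaining pair of
# symbolic sizes is guarded equal, preferring the shorter sympy formula.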


def promote_constants(inputs, override_return_dtype=None):
    if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs):
        return inputs
    if all(isinstance(x, (int, float)) for x in inputs):
        dtype = override_return_dtype or get_promoted_dtype(
            *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
        )
        return [ir.Constant(x, dtype, decode_device(None)) for x in inputs]
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView)))
    out = []
    for x in inputs:
        if isinstance(x, (int, float)):
            out.append(
                ExpandView.create(
                    ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
                )
            )
        elif isinstance(x, sympy.Expr):
            out.append(IndexingConstant(x, ex.get_dtype(), ex.get_device()))
        else:
            out.append(x)
    return out


def make_pointwise(
    fn,
    override_return_dtype=None,
    override_device=None,
    override_fn_when_input_bool=None,
    override_fn_when_cuda_float64=None,
    allow_alpha=False,
):
    def inner(*inputs: List[TensorBox], alpha=None):
        inputs = promote_constants(inputs, override_return_dtype)
        if allow_alpha:
            if alpha is not None and alpha != 1:
                inputs = list(inputs)
                inputs[-1] = mul(inputs[-1], alpha)
        else:
            assert alpha is None
        loaders = [x.make_loader() for x in inputs]
        ranges = inputs[0].get_size()
        dtype = override_return_dtype or inputs[0].get_dtype()
        is_cuda = decode_device(inputs[0].get_device()).type == "cuda"

        for other in inputs[1:]:
            assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
                other.get_size()
            ), f"ndim mismatch {fn} {ranges} {other.get_size()}"

        def inner_fn(index):
            assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
            if dtype == torch.bool and override_fn_when_input_bool is not None:
                return override_fn_when_input_bool(*[load(index) for load in loaders])
            elif override_fn_when_cuda_float64 and is_cuda and dtype == torch.float64:
                return override_fn_when_cuda_float64(*[load(index) for load in loaders])
            else:
                return fn(*[load(index) for load in loaders])

        if not override_device:
            device = None
            for i in inputs:
                if i.get_device().type == "cuda":
                    device = i.get_device()
                    break
            if not device:
                device = inputs[0].get_device()

        device = override_device or device

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=ranges,
        )

    return inner


@register_lowering(prims.convert_element_type, type_promotion_kind=None)
def to_dtype(x: TensorBox, dtype: torch.dtype):
    if x.get_dtype() == dtype:
        return x

    def _to_dtype(x):
        return ops.to_dtype(x, dtype)

    return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)


@register_lowering(prims.device_put, type_promotion_kind=None)
def to_device(x: TensorBox, device: torch.device):
    device = decode_device(device)
    if x.get_device() == device:
        return x
    return TensorBox.create(ir.DeviceCopy.create(x, device))


def ops_wrapper(name):
    assert isinstance(name, str)

    def fn(*args, **kwargs):
        return getattr(ops, name)(*args, **kwargs)

    return fn


def register_pointwise(
    aten_fn,
    name=None,
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
    override_return_dtype=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    use_libdevice_for_f64=False,
):
    """A pointwise function that maps ops.{name} to inputs"""
    name = name or aten_fn.__name__
    fn = ops_wrapper(name)
    if use_libdevice_for_f64:
        fn_libdevice = ops_wrapper("libdevice_" + name)
    if override_fn_when_input_bool is not None:
        override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)

    fn = make_pointwise(
        fn,
        override_return_dtype=override_return_dtype,
        override_fn_when_input_bool=override_fn_when_input_bool,
        override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,
        allow_alpha=allow_alpha,
    )
    fn = register_lowering(
        aten_fn,
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        convert_input_to_bool=convert_input_to_bool,
    )(fn)

    if hasattr(prims, name):
        register_lowering(
            getattr(prims, name),
            type_promotion_kind=None,
            convert_input_to_bool=convert_input_to_bool,
        )(fn)
    return fn
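

# Example (illustrative, not from this excerpt): simple math ops can be
# registered with a single call, e.g. `register_pointwise(aten.abs)`, which
# maps aten.abs (and the matching prims op, if present) onto ops.abs with
# default elementwise type promotion and broadcasting.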


@register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
def where(cond, a, b):
    def fn(*args):
        return ops.where(*args)

    if isinstance(a, (float, int)):
        a = constant_like(a)(b)
    if isinstance(b, (float, int)):
        b = constant_like(b)(a)

    args = [cond, a, b]
    dtype = get_promoted_dtype(
        args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
    for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
        args[i] = x
    for i in range(len(args)):
        if isinstance(args[i], ir.Constant):
            args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
    return make_pointwise(fn, override_return_dtype=dtype)(
        args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype)
    )


@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
def broadcast_tensors(*inputs):
    if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
        return broadcast_tensors(*inputs[0])
    target = functools.reduce(
        broadcast_symbolic_shapes, [x.get_size() for x in inputs], ()
    )
    outputs = []
    for x in inputs:
        sizes = x.get_size()
        if len(sizes) != len(target) or any(
            ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target)
        ):
            x = expand(x, target)
        outputs.append(x)
    return outputs


@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
def nop(x):
    return x  # AOT autograd handles this for us


if hasattr(aten, "lift_fresh"):
    register_lowering(aten.lift_fresh)(nop)


@register_lowering(aten.squeeze, type_promotion_kind=None)
def squeeze(x, dim=None):
    assert isinstance(x, TensorBox)
    if dim is None:
        return TensorBox(SqueezeView.create(x.data))

    dim = canonicalize_dims(len(x.get_size()), dim)
    dims = set((dim,) if not isinstance(dim, tuple) else dim)

    new_shape = [
        s
        for d, s in enumerate(x.get_size())
        if not (d in dims and V.graph.sizevars.maybe_guard_equals(s, 1))
    ]

    # squeeze does nothing if the size isn't 1
    return view(x, new_shape) if new_shape != x.get_size() else x


@register_lowering([aten.squeeze_])
def squeeze_(x, dim=None):
    val = squeeze(x, dim)
    assert isinstance(x, TensorBox)
    assert isinstance(val, TensorBox)
    x.data = val.data
    return x


@register_lowering(aten.isinf)
def isinf(x):
    if is_integer_type(x):
        return full_like(x, False, dtype=torch.bool)
    fn = ops_wrapper("isinf")
    return make_pointwise(fn, override_return_dtype=torch.bool)(x)


@register_lowering(aten.isnan)
def isnan(x):
    if is_integer_type(x):
        return full_like(x, False, dtype=torch.bool)
    fn = ops_wrapper("isnan")
    return make_pointwise(fn, override_return_dtype=torch.bool)(x)


@register_lowering(aten.ceil)
def ceil(x):
    if is_integer_type(x):
        return x
    fn = ops_wrapper("ceil")
    return make_pointwise(fn)(x)


@register_lowering(aten.floor)
def floor(x):
    if is_integer_type(x):
        return x
    fn = ops_wrapper("floor")
    return make_pointwise(fn)(x)


@register_lowering(aten.round)
def round(x):
    if is_integer_type(x):
        return x
    fn = ops_wrapper("round")
    return make_pointwise(fn)(x)


@register_lowering(aten.trunc)
def trunc(x):
    if is_integer_type(x):
        return x
    fn = ops_wrapper("trunc")
    return make_pointwise(fn)(x)


@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
    (x,) = promote_constants([x])
    if isinstance(x, ir.BaseConstant):
        return ExpandView.create(x, tuple(sizes))
    assert isinstance(x, TensorBox)
    assert isinstance(sizes, (list, tuple))
    if tuple(x.get_size()) == tuple(sizes):
        return x

    x_size_product = V.graph.sizevars.size_hint(sympy_product(x.get_size()))
    if x_size_product > 0:
        # maybe realize input before broadcasting it
        x.mark_reuse(V.graph.sizevars.size_hint(sympy_product(sizes)) // x_size_product)
    return TensorBox(ExpandView.create(x.data, tuple(sizes)))


@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
def broadcast_in_dim(a, shape, broadcast_dimensions):
    s = list(shape)
    for broadcast_dimension in broadcast_dimensions:
        s[broadcast_dimension] = -1

    v = a
    for idx, x in enumerate(s):
        if x != -1:
            v = unsqueeze(v, idx)

    return expand(v, shape)
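

# Illustrative example: broadcast_in_dim(a, shape=(3, 4, 5),
# broadcast_dimensions=(0, 2)) with `a` of shape (3, 5) unsqueezes dim 1 (the
# only dimension not listed) and then expands the result to (3, 4, 5).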


@register_lowering(aten.expand_as, type_promotion_kind=None)
def expand_as(x, y):
    return expand(x, y.get_size())


@register_lowering(aten.repeat)
def repeat(x, repeats):
    old_size = list(x.get_size())
    if len(repeats) > len(old_size):
        old_size = [sympy.Integer(1)] * (len(repeats) - len(old_size)) + old_size
        x = view(x, list(old_size))
    assert len(repeats) == len(x.get_size())

    new_size = list(x.get_size())

    for i in range(len(repeats)):
        assert repeats[i] != 0
        if repeats[i] != 1:
            new_size[i] = new_size[i] * repeats[i]

    if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
        return expand(x, new_size)

    def inner_fn(index):
        assert len(index) == len(repeats)
        index = list(index)
        for i in range(len(repeats)):
            if repeats[i] != 1:
                if old_size[i] == 1:
                    index[i] = sympy.Integer(0)
                else:
                    index[i] = ir.ModularIndexing(index[i], 1, old_size[i])
        return x_loader(index)

    old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size))
    if old_size_product > 0:
        # maybe realize the input
        x.mark_reuse(
            V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product
        )

    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(new_size),
    )


@register_lowering(aten._unsafe_view, type_promotion_kind=None)
@register_lowering(aten.view, type_promotion_kind=None)
@register_lowering(aten.reshape, type_promotion_kind=None)
def view(x, sizes):
    assert isinstance(x, TensorBox)
    assert isinstance(sizes, (list, tuple))
    return TensorBox(View.create(x.data, sizes))


@register_lowering(aten.permute, type_promotion_kind=None)
def permute(x, dims):
    assert isinstance(x, TensorBox)
    assert isinstance(dims, (list, tuple))
    return TensorBox(PermuteView.create(x.data, tuple(dims)))


@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1):
    assert isinstance(x, TensorBox)
    dim = _validate_dim(x, dim, 0)
    return TensorBox(ir.SliceView.create(x.data, dim, start, end, step))


@register_lowering(aten.roll, type_promotion_kind=None)
def roll(a, shifts, dims=tuple()):
    """
    This is based on torch._refs.roll(), but uses ir.ModularIndexing().

    We can't use the ref here because it is based on multiple calls to
    torch.cat(), which would result in terrible generated code.
    """
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)
    dims = [_validate_dim(a, d) for d in dims]

    if sympy_product(a.get_size()) == 0:
        return clone(a)

    len_shifts = len(shifts)
    len_dims = len(dims)
    if len_shifts != 1 or len_dims != 1:
        if len_shifts == 0:
            raise RuntimeError("`shifts` required")
        # Takes care of the case when dims is not specified (default)
        # By default, the tensor is flattened before shifting, after which the original shape is restored
        if len_dims == 0 and len_shifts == 1:
            flat = view(a, [sympy_product(a.get_size())])
            rolled = roll(flat, shifts, 0)
            return view(rolled, list(a.get_size()))
        if len_shifts != len_dims:
            raise RuntimeError(
                f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
            )
        tail_shifts = shifts[1:]
        tail_dims = dims[1:]
        first_dim_rolled = roll(a, shifts[0], dims[0])
        return roll(first_dim_rolled, tail_shifts, tail_dims)

    (dim,) = dims
    size = V.graph.sizevars.guard_static_shape(a.get_size()[dim])
    start = (size - shifts[0]) % size
    a_loader = a.make_loader()

    def fn(index):
        index = list(index)
        index[dim] = ir.ModularIndexing(
            index[dim] + start, sympy.Integer(1), sympy.expand(size)
        )
        return a_loader(index)

    return Pointwise.create(
        device=a.get_device(),
        dtype=a.get_dtype(),
        inner_fn=fn,
        ranges=a.get_size(),
    )
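

# Concretely (illustrative): rolling a length-4 dim by shifts=1 gives
# start = (4 - 1) % 4 = 3, so output index i reads input index (i + 3) % 4,
# i.e. [a, b, c, d] -> [d, a, b, c], matching torch.roll semantics.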


@register_lowering(aten.as_strided, type_promotion_kind=None)
def as_strided(x, size, stride, storage_offset=None):
    if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
        # as_strided ignores views
        x = x.data.unwrap_view()
    x.realize()
    if not ir.is_storage_and_layout(x):
        raise NotImplementedError(f"unrealized as_strided({x}, ...)")
    storage, old_layout = ir.as_storage_and_layout(x)
    new_layout = ir.FixedLayout(
        old_layout.device,
        old_layout.dtype,
        [sympy.expand(s) for s in size],
        [sympy.expand(s) for s in stride],
        sympy.expand(storage_offset or 0),
    )
    return TensorBox(ir.ReinterpretView(storage, new_layout))


@register_lowering(aten.as_strided_)
def as_strided_(x, size, stride, storage_offset=None):
    assert isinstance(x, TensorBox)
    x.data = as_strided(x, size, stride, storage_offset).data
    return x


@register_lowering(aten.cat)
def cat(inputs, dim=0):
    if len(inputs) == 1:
        return clone(inputs[0])

    dim = _validate_dim(inputs[0], dim, 0)
    dtype = get_promoted_dtype(
        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    inputs = [to_dtype(inp, dtype) for inp in inputs]
    return TensorBox(ir.ConcatKernel.create(inputs, dim))


@register_lowering(aten.select, type_promotion_kind=None)
def select(x, dim, idx):
    idx = View.handle_negative_index(idx, x.get_size()[dim])
    return squeeze(slice_(x, dim, idx, idx + 1), dim)


@register_lowering(aten.split, type_promotion_kind=None)
def split(x, sizes, dim=0):
    dim = _validate_dim(x, dim, 0)
    x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
    if isinstance(sizes, sympy.Expr):
        sizes = V.graph.sizevars.guard_static_shape(sizes)
    if isinstance(sizes, (int, sympy.Integer)):
        sizes = [sizes] * ((x_size + sizes - 1) // sizes)
    result = []
    start = 0
    for size in sizes:
        end = start + size
        result.append(slice_(x, dim, start, end))
        start = end
    return result


@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
def split_with_sizes(x, sizes, dim=0):
    return split(x, sizes, dim)


@register_lowering(aten.unbind, type_promotion_kind=None)
def unbind(x, dim=0):
    dim = _validate_dim(x, dim, 0)
    x_size = V.graph.sizevars.guard_static_shape(x.get_size()[dim])
    result = []
    for i in range(x_size):
        result.append(select(x, dim, i))
    return result


@register_lowering(aten.unsqueeze, type_promotion_kind=None)
def unsqueeze(x, dim):
    dim = _validate_dim(x, dim, 1)
    new_shape = list(x.get_size())
    new_shape.insert(dim, sympy.Integer(1))
    return view(x, new_shape)


@register_lowering(aten.unsqueeze_, type_promotion_kind=None)
def unsqueeze_(x, dim):
    val = unsqueeze(x, dim)
    assert isinstance(x, TensorBox)
    assert isinstance(val, TensorBox)
    x.data = val.data
    return x


def _validate_dim(x, dim, offset=0):
    assert isinstance(dim, int)
    ndim = len(x.get_size())
    if dim < 0:
        dim += ndim + offset
    assert 0 <= dim < ndim + offset
    return dim
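

# Illustrative: for a 3-D input, _validate_dim(x, -1) returns 2, while
# unsqueeze passes offset=1 so that dim=-1 maps to 3 (a new trailing dim).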


@register_lowering(aten.glu)
def glu(x, dim=-1):
    dim = _validate_dim(x, dim, 0)
    new_len = V.graph.sizevars.guard_static_shape(x.get_size()[dim]) // 2
    a = slice_(x, dim, 0, new_len)
    b = slice_(x, dim, new_len, new_len * 2)
    return mul(a, sigmoid(b))


def register_onednn_fusion_ops():
    if torch._C.has_mkldnn:

        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
        def convolution_binary(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                ir.ConvolutionBinary.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
        def convolution_binary_inplace(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                ir.ConvolutionBinaryInplace.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._linear_pointwise)
        def linear_unary(
            x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm
        ):
            return TensorBox.create(
                ir.LinearUnary.create(x, w, b, attr, scalars, algorithm)
            )

        @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
        def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
            return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))

        @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
        def convolution_transpose_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            output_padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                ir.ConvolutionTransposeUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    output_padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        if torch._C.has_mkl:

            @register_lowering(torch.ops.mkl._mkl_linear)
            def mkl_packed_linear(
                x: TensorBox,
                packed_w: TensorBox,
                orig_w: TensorBox,
                b: TensorBox,
                batch_size,
            ):
                result = TensorBox.create(
                    ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size)
                )
                if b is not None:
                    result = add(result, b)
                return result

    else:
        pass


register_onednn_fusion_ops()


def fallback_handler(kernel):
    fallbacks.add(kernel)

    def handler(*args, **kwargs):
        return pytree.tree_map(
            TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs)
        )

    return handler


def make_fallback(kernel, layout_constraint=None, warn=True):
    assert (
        kernel not in decompositions
    ), f"both a fallback and a decomp for same kernel: {kernel}"
    if get_decompositions([kernel]) and warn:
        developer_warning(
            f"make_fallback({kernel}): a decomposition exists, we should switch to it"
        )

    add_needs_realized_inputs(kernel)
    if layout_constraint is not None:
        add_layout_constraint(kernel, layout_constraint)
    return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
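

# Example (illustrative): a call such as `make_fallback(aten.sort)` below
# registers a lowering that simply runs the eager aten kernel through
# ir.FallbackKernel, optionally applying a layout constraint (e.g.
# require_dense) to its inputs first.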


@register_lowering(aten.native_dropout, type_promotion_kind=None)
def native_dropout(x, p, train):
    assert (
        config.fallback_random
    ), "this should be handled in decomps unless config.fallback_random"
    if train:
        return pytree.tree_map(
            TensorBox.create, ir.FallbackKernel.create(aten.native_dropout, x, p, train)
        )
    return x, ones_like(x, dtype=torch.bool)


@register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args):
    assert (
        config.fallback_random
    ), "this should be handled in decomps unless config.fallback_random"
    x.realize()
    V.graph.realize_users_of(x.get_name())
    ir.InplaceBernoulliFallback(x, *args)
    return x


@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args):
    assert (
        config.fallback_random
    ), "this should be handled in decomps unless config.fallback_random"
    return bernoulli_(clone(x), *args)


# This shouldn't be called in general
@register_lowering(aten._foobar)
def _foobar(_):
    raise AssertionError()


@functools.lru_cache(1)
def _warn_triton_random(salt):
    developer_warning("using triton random, expect difference from eager")


def warn_triton_random():
    # only warn once per graph
    _warn_triton_random(V.graph.creation_time)


def make_rand(fn_name):
    def rand_or_randn(
        *size,
        dtype=None,
        layout=0,
        device=None,
        pin_memory=False,
        memory_format=None,
    ):
        warn_triton_random()
        assert not pin_memory
        assert layout in (0, torch.strided)
        assert memory_format in (None, torch.contiguous_format)
        device = decode_device(device)
        dtype = dtype or torch.get_default_dtype()
        if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
            size = tuple(size[0])
        size = [sympy.expand(s) for s in size]
        offset = V.graph.increment_randomness_offset(sympy_product(size))

        random_pos = ir.FixedLayout(
            device,
            dtype,
            size,
            ir.FlexibleLayout.contiguous_strides(size),
            offset=offset,
        ).make_indexer()

        seed_buffer = V.graph.random_seed_buffer(device).make_loader()

        def inner_fn(index):
            seed = seed_buffer([])
            # change seed so that we don't collide with philox_rand_like()
            # TODO(jansel): migrate everything to philox_rand_like()
            seed = ops.bitwise_xor(seed, ops.constant(0xFFFF, torch.int32))
            return getattr(ops, fn_name)(
                seed,
                ops.index_expr(random_pos(index), torch.int32),
                dtype,
            )

        return Pointwise.create(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=list(size),
        )

    return rand_or_randn


fallback_rand = fallback_handler(aten.rand)
fallback_randn = fallback_handler(aten.randn)
fast_rand = make_rand("rand")
fast_randn = make_rand("randn")


@register_lowering([aten.rand, torch.rand])
def rand(*args, **kwargs):
    if config.fallback_random or kwargs.get("generator", None) is not None:
        return fallback_rand(*args, **kwargs)
    else:
        kwargs.pop("generator", None)
        return fast_rand(*args, **kwargs)


@register_lowering([aten.randn, torch.randn])
def randn(*args, **kwargs):
    if config.fallback_random or kwargs.get("generator", None) is not None:
        return fallback_randn(*args, **kwargs)
    else:
        kwargs.pop("generator", None)
        return fast_randn(*args, **kwargs)


@register_lowering(overrides.philox_seed_like._overloadpacket)
def philox_seed_like(x):
    warn_triton_random()
    return V.graph.random_seed_buffer(x.get_device())


@register_lowering(overrides.philox_rand_like._overloadpacket, type_promotion_kind=None)
def philox_rand_like(x, seed, offset):
    device = x.get_device()
    dtype = x.get_dtype()
    size = x.get_size()
    random_pos = ir.FixedLayout(
        device,
        dtype,
        size,
        ir.FlexibleLayout.contiguous_strides(size),
        offset=sympy.expand(offset),
    ).make_indexer()
    seed_loader = seed.make_loader()

    def inner_fn(index):
        return ops.rand(
            seed_loader([]),
            ops.index_expr(random_pos(index), torch.int32),
            dtype,
        )

    return Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )


def require_dense(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
    )
    return args, kwargs


def require_contiguous(_, *args, **kwargs):
    args, kwargs = pytree.tree_map_only(
        ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs)
    )
    return args, kwargs


def constrain_to_fx_strides(fx_node, *args, **kwargs):
    def apply_constraint(arg, fx_arg):
        if isinstance(arg, ir.IRNode):
            stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
            return ir.ExternKernel.require_stride_order(arg, stride_order)
        return arg

    args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)]
    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
    return args, kwargs


# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = {
    "torchvision::roi_align",
}
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten.convolution_backward, constrain_to_fx_strides)
make_fallback(aten._cudnn_rnn, require_dense)
make_fallback(aten._cudnn_rnn_backward, require_contiguous)
make_fallback(aten.cumsum, require_dense, warn=False)
make_fallback(aten._embedding_bag, require_contiguous)
make_fallback(aten._embedding_bag_forward_only, require_contiguous)
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten.randperm)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(aten.topk)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten.upsample_bilinear2d_backward, require_dense)

# The following were added as a result of https://github.com/pytorch/pytorch/pull/94039 to pass tests
# It's not necessarily a priority to implement these
make_fallback(aten.upsample_linear1d)
make_fallback(aten.upsample_trilinear3d)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_trilinear3d_backward)
make_fallback(aten._adaptive_avg_pool3d)
make_fallback(aten.adaptive_max_pool2d)
make_fallback(aten.adaptive_max_pool3d)
make_fallback(aten.addbmm)
make_fallback(aten.addmv)
make_fallback(aten.aminmax)
make_fallback(aten.avg_pool3d)
make_fallback(aten.block_diag)
make_fallback(aten._cdist_forward)
make_fallback(aten.count_nonzero)
make_fallback(aten.cummax)
make_fallback(aten.cummin)
make_fallback(aten.cumprod)
make_fallback(aten.deg2rad)
make_fallback(aten.diagonal_copy, warn=False)
make_fallback(aten.diagonal_scatter, warn=False)
make_fallback(aten.digamma, warn=False)
make_fallback(aten.dist)
make_fallback(aten._efficientzerotensor)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
make_fallback(aten.erfc, warn=False)
make_fallback(aten.erfinv, warn=False)
make_fallback(aten.fmax, warn=False)
make_fallback(aten.fmin, warn=False)
  1065. make_fallback(aten.fractional_max_pool2d)
  1066. make_fallback(aten.fractional_max_pool3d)
  1067. make_fallback(aten.frexp)
  1068. make_fallback(aten.geqrf)
  1069. make_fallback(aten.histc)
  1070. make_fallback(aten.i0)
  1071. make_fallback(aten.igamma, warn=False)
  1072. make_fallback(aten.igammac, warn=False)
  1073. make_fallback(aten.isin)
  1074. make_fallback(aten.isneginf, warn=False)
  1075. make_fallback(aten.isposinf, warn=False)
  1076. make_fallback(aten.kthvalue)
  1077. make_fallback(aten.linalg_cholesky_ex)
  1078. make_fallback(aten.linalg_cross)
  1079. make_fallback(aten._linalg_det)
  1080. make_fallback(aten.linalg_householder_product)
  1081. make_fallback(aten.linalg_inv_ex)
  1082. make_fallback(aten.linalg_ldl_factor_ex)
  1083. make_fallback(aten.linalg_ldl_solve)
  1084. make_fallback(aten.linalg_lu)
  1085. make_fallback(aten.linalg_lu_factor_ex)
  1086. make_fallback(aten.linalg_lu_solve)
  1087. make_fallback(aten.linalg_matrix_exp)
  1088. make_fallback(aten.linalg_qr)
  1089. make_fallback(aten._linalg_slogdet)
  1090. make_fallback(aten._linalg_solve_ex)
  1091. make_fallback(aten.linalg_solve_triangular)
  1092. make_fallback(aten._linalg_svd)
  1093. make_fallback(aten.logaddexp2)
  1094. make_fallback(aten.logcumsumexp)
  1095. make_fallback(aten.log_sigmoid_forward, warn=False)
  1096. make_fallback(aten.logspace, warn=False)
  1097. make_fallback(aten.lu_unpack)
  1098. make_fallback(aten.max_pool3d_with_indices)
  1099. make_fallback(aten.max_unpool2d)
  1100. make_fallback(aten.max_unpool3d)
  1101. make_fallback(aten.median)
  1102. make_fallback(aten.mode)
  1103. make_fallback(aten.multilabel_margin_loss_forward)
  1104. make_fallback(aten.multi_margin_loss)
  1105. make_fallback(aten.nanmedian)
  1106. make_fallback(aten.nansum)
  1107. make_fallback(aten.narrow_copy, warn=False)
  1108. make_fallback(aten.ormqr)
  1109. make_fallback(aten._pdist_forward)
  1110. make_fallback(aten.pixel_shuffle)
  1111. make_fallback(aten.pixel_unshuffle)
  1112. make_fallback(aten.polygamma)
  1113. make_fallback(aten.prod, warn=False)
  1114. make_fallback(aten.put)
  1115. make_fallback(aten.rad2deg)
  1116. make_fallback(aten.reflection_pad1d)
  1117. make_fallback(aten.renorm)
  1118. make_fallback(aten.replication_pad1d)
  1119. make_fallback(aten.resize)
  1120. make_fallback(aten.resize_)
  1121. make_fallback(aten.resize_as)
  1122. make_fallback(aten.resize_as_)
  1123. make_fallback(aten.searchsorted)
  1124. make_fallback(aten.smooth_l1_loss)
  1125. make_fallback(aten.special_airy_ai)
  1126. make_fallback(aten.special_bessel_j0, warn=False)
  1127. make_fallback(aten.special_bessel_j1, warn=False)
  1128. make_fallback(aten.special_bessel_y0, warn=False)
  1129. make_fallback(aten.special_bessel_y1)
  1130. make_fallback(aten.special_chebyshev_polynomial_t)
  1131. make_fallback(aten.special_chebyshev_polynomial_u)
  1132. make_fallback(aten.special_erfcx, warn=False)
  1133. make_fallback(aten.special_hermite_polynomial_h)
  1134. make_fallback(aten.special_hermite_polynomial_he)
  1135. make_fallback(aten.special_i0e, warn=False)
  1136. make_fallback(aten.special_i1, warn=False)
  1137. make_fallback(aten.special_i1e, warn=False)
  1138. make_fallback(aten.special_laguerre_polynomial_l)
  1139. make_fallback(aten.special_modified_bessel_i0)
  1140. make_fallback(aten.special_modified_bessel_i1)
  1141. make_fallback(aten.special_modified_bessel_k0)
  1142. make_fallback(aten.special_modified_bessel_k1)
  1143. make_fallback(aten.special_ndtri, warn=False)
  1144. make_fallback(aten.special_scaled_modified_bessel_k0)
  1145. make_fallback(aten.special_scaled_modified_bessel_k1)
  1146. make_fallback(aten.special_spherical_bessel_j0, warn=False)
  1147. make_fallback(aten.special_zeta, warn=False)
  1148. make_fallback(aten.take)
  1149. make_fallback(aten.threshold, warn=False)
  1150. make_fallback(aten.trace, warn=False)
  1151. make_fallback(aten._trilinear)
  1152. make_fallback(aten.unfold_copy, warn=False)
  1153. make_fallback(aten.uniform, warn=False)
  1154. make_fallback(aten.unsafe_split, warn=False)
  1155. make_fallback(aten.vdot)
  1156. make_fallback(aten.view_as_complex)
  1157. make_fallback(aten.view_copy)
  1158. make_fallback(aten._adaptive_avg_pool3d_backward)
  1159. make_fallback(aten.adaptive_max_pool2d_backward)
  1160. make_fallback(aten.adaptive_max_pool3d_backward)
  1161. make_fallback(aten.avg_pool3d_backward)
  1162. make_fallback(aten.bitwise_or_, warn=False)
  1163. make_fallback(aten._cdist_backward)
  1164. make_fallback(aten.diagonal_backward, warn=False)
  1165. make_fallback(aten._embedding_bag_dense_backward)
  1166. make_fallback(aten.fractional_max_pool2d_backward)
  1167. make_fallback(aten.fractional_max_pool3d_backward)
  1168. make_fallback(aten._linalg_check_errors)
  1169. make_fallback(aten.max_pool3d_with_indices_backward)
  1170. make_fallback(aten.multilabel_margin_loss_backward)
  1171. make_fallback(aten.multi_margin_loss_backward)
  1172. make_fallback(aten._pdist_backward)
  1173. make_fallback(aten.reflection_pad1d_backward)
  1174. make_fallback(aten.replication_pad1d_backward)
  1175. make_fallback(aten.smooth_l1_loss_backward)
  1176. make_fallback(aten.soft_margin_loss_backward, warn=False)
  1177. make_fallback(aten.softshrink_backward, warn=False)
  1178. make_fallback(aten.squeeze_copy)
  1179. make_fallback(aten.linalg_pinv.atol_rtol_tensor)
  1180. make_fallback(aten.segment_reduce.default)
  1181. make_fallback(aten._segment_reduce_backward.default)
  1182. make_fallback(aten.angle)
  1183. make_fallback(aten.cholesky_inverse)
  1184. make_fallback(aten.cholesky_solve)
  1185. make_fallback(aten._fft_r2c)
  1186. make_fallback(aten.histogram.bin_ct)
  1187. make_fallback(aten._histogramdd_bin_edges.default)
  1188. make_fallback(aten._histogramdd_from_bin_cts.default)
  1189. make_fallback(aten.index_reduce)
  1190. make_fallback(aten.masked_scatter)
  1191. make_fallback(aten.to_sparse)
  1192. make_fallback(aten.triangular_solve)
  1193. make_fallback(aten.expand_copy)
  1194. make_fallback(aten.gcd.default, warn=False)
  1195. make_fallback(aten._linalg_eigh)
  1196. make_fallback(aten.zeros.names)
  1197. # TODO(fdrocha): this should be removed once the register_pointwise(aten.bitwise_right_shift) below is uncommented
  1198. make_fallback(aten.bitwise_right_shift, warn=False)
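# Illustrative sketch (not part of the registry above): ops registered with
# make_fallback() have no native Inductor lowering, so they are routed back to
# the eager ATen kernel through a fallback extern-kernel call. A hypothetical
# out-of-tree op would be wired up the same way, e.g.
#     make_fallback(aten.my_custom_op)              # hypothetical op; warns when hit
#     make_fallback(aten.my_custom_op, warn=False)  # same, without the warning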
  1199. add_layout_constraint(aten.convolution, constrain_to_fx_strides)
  1200. @register_lowering(aten.convolution)
  1201. def convolution(
  1202. x: TensorBox,
  1203. weight: TensorBox,
  1204. bias: TensorBox,
  1205. stride: List[int],
  1206. padding: List[int],
  1207. dilation: List[int],
  1208. transposed: bool,
  1209. output_padding: List[int],
  1210. groups: int,
  1211. ):
  1212. is_cpu = all(
  1213. input.get_device().type == "cpu"
  1214. for input in (x, weight, bias)
  1215. if input is not None
  1216. )
  1217. result = TensorBox.create(
  1218. ir.Convolution.create(
  1219. x,
  1220. weight,
1221. bias if is_cpu else None, # for the CPU path the bias can always be fused
  1222. stride,
  1223. padding,
  1224. dilation,
  1225. transposed,
  1226. output_padding,
  1227. groups,
  1228. )
  1229. )
  1230. if not is_cpu and bias is not None:
  1231. kernel_dims = len(weight.get_size()) - 2
  1232. out_chan = result.get_size()[-1 - kernel_dims]
  1233. bias = view(bias, [out_chan] + kernel_dims * [1])
  1234. result = add(result, bias)
  1235. return result
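# Worked example (sketch): for a non-CPU convolution with weight shape
# [C_out, C_in, kH, kW], kernel_dims == 2 and out_chan == C_out, so the bias is
# viewed as [C_out, 1, 1] and broadcast-added to the [N, C_out, H_out, W_out]
# result over the spatial dims. On the CPU path the bias is instead passed into
# ir.Convolution.create above and fused into the conv kernel itself.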
  1236. @register_lowering(aten._convolution)
  1237. def _convolution(
  1238. x,
  1239. weight,
  1240. bias,
  1241. stride,
  1242. padding,
  1243. dilation,
  1244. transposed,
  1245. output_padding,
  1246. groups,
  1247. benchmark,
  1248. deterministic,
  1249. cudnn_enabled,
  1250. allow_tf32,
  1251. ):
  1252. return convolution(
  1253. x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
  1254. )
  1255. @register_lowering(aten.clone)
  1256. def clone(x, *, memory_format=0):
  1257. # TODO(jansel): memory format
  1258. return Pointwise.create(
  1259. device=x.get_device(),
  1260. dtype=x.get_dtype(),
  1261. inner_fn=x.make_loader(),
  1262. ranges=list(x.get_size()),
  1263. )
  1264. if hasattr(aten, "lift_fresh_copy"):
  1265. register_lowering(aten.lift_fresh_copy)(clone)
  1266. @register_lowering(prims.iota)
  1267. def iota(
  1268. length,
  1269. *,
  1270. start,
  1271. step,
  1272. dtype,
  1273. device,
  1274. requires_grad,
  1275. ):
  1276. def fn(index):
  1277. return ops.index_expr(step * index[0] + start, dtype=dtype)
  1278. return Pointwise.create(
  1279. device=decode_device(device),
  1280. dtype=dtype,
  1281. inner_fn=fn,
  1282. ranges=[length],
  1283. )
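# Worked example (sketch): iota lowers to a pointwise kernel that evaluates
# step * i + start for each output index i in [0, length). For length=5,
# start=2, step=3 the produced values are [2, 5, 8, 11, 14], the same sequence
# torch.arange(2, 17, 3) would give.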
  1284. @register_lowering(aten.select_scatter, type_promotion_kind=None)
  1285. def select_scatter(x, src, dim: int, index: int):
  1286. assert x.get_dtype() == src.get_dtype()
  1287. x_loader = x.make_loader()
  1288. dim = _validate_dim(x, dim, 0)
  1289. if index < 0:
  1290. index = index + x.get_size()[dim]
  1291. V.graph.sizevars.guard_leq(0, index)
  1292. V.graph.sizevars.guard_lt(index, x.get_size()[dim])
  1293. src = expand(unsqueeze(src, dim), x.get_size())
  1294. src_loader = src.make_loader()
  1295. def inner_fn(idx):
  1296. return ops.where(
  1297. ops.eq(
  1298. ops.index_expr(idx[dim], torch.int32),
  1299. ops.index_expr(index, torch.int32),
  1300. ),
  1301. src_loader(idx),
  1302. x_loader(idx),
  1303. )
  1304. return Pointwise.create(
  1305. device=x.get_device(),
  1306. dtype=x.get_dtype(),
  1307. inner_fn=inner_fn,
  1308. ranges=list(x.get_size()),
  1309. )
  1310. @register_lowering(aten.slice_scatter, type_promotion_kind=None)
  1311. def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
  1312. assert x.get_dtype() == src.get_dtype()
  1313. x_loader = x.make_loader()
  1314. dim = _validate_dim(x, dim, 0)
  1315. dim_size = x.get_size()[dim]
  1316. if start is not None and start < 0:
  1317. start = start + dim_size
  1318. if end is not None and end < 0:
  1319. end = end + dim_size
  1320. if start is None:
  1321. start = 0
  1322. if end is None or V.graph.sizevars.maybe_guard_leq(x.get_size()[dim], end):
  1323. end = dim_size
  1324. src_size = list(x.get_size())
  1325. src_size[dim] = ir.FloorDiv(sympy.expand(end - start), sympy.expand(step))
  1326. src = expand(src, src_size)
  1327. src_loader = src.make_loader()
  1328. def inner_fn(idx):
  1329. if start == 0 and end == dim_size and step == 1:
  1330. # selecting every element is the same as just src.clone()
  1331. return src_loader(idx)
  1332. idx_dim = ops.index_expr(idx[dim], torch.int32)
  1333. src_idx = list(idx)
  1334. src_idx[dim] = ir.FloorDiv(idx[dim] - start, step)
  1335. mask = []
  1336. if start != 0:
  1337. mask.append(
  1338. ops.ge(
  1339. idx_dim,
  1340. ops.index_expr(sympy.expand(start), torch.int32),
  1341. )
  1342. )
  1343. if end != dim_size:
  1344. mask.append(
  1345. ops.lt(
  1346. idx_dim,
  1347. ops.index_expr(sympy.expand(end), torch.int32),
  1348. )
  1349. )
  1350. if step != 1:
  1351. mask.append(
  1352. ops.eq(
  1353. ops.index_expr(
  1354. ir.ModularIndexing(idx[dim] - start, 1, step), torch.int32
  1355. ),
  1356. ops.constant(0, torch.int32),
  1357. )
  1358. )
  1359. assert mask
  1360. mask = functools.reduce(ops.and_, mask)
  1361. src_val = ops.masked(
  1362. mask,
  1363. lambda: src_loader(src_idx),
  1364. 0 if is_integer_type(x) else 0.0,
  1365. )
  1366. return ops.where(
  1367. mask,
  1368. src_val,
  1369. x_loader(idx),
  1370. )
  1371. return Pointwise.create(
  1372. device=x.get_device(),
  1373. dtype=x.get_dtype(),
  1374. inner_fn=inner_fn,
  1375. ranges=list(x.get_size()),
  1376. )
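# Worked example (sketch): for slice_scatter along a dim of size 10 with
# start=2, end=8, step=2, an output position i takes the src value exactly when
#     (i >= 2) & (i < 8) & ((i - 2) % 2 == 0)    ->  i in {2, 4, 6}
# and the element read is src[(i - 2) // 2]; every other position keeps x[i].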
  1377. def _unwrap(x):
  1378. if isinstance(x, (list, tuple)) and len(x) > 0:
  1379. return _unwrap(x[0])
  1380. return x
  1381. @register_lowering([torch.tensor, aten.scalar_tensor])
  1382. def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
  1383. assert layout in (None, torch.strided)
  1384. assert pin_memory is False
  1385. if isinstance(_unwrap(data), int):
  1386. dtype = dtype or torch.int64
  1387. else:
  1388. dtype = dtype or torch.get_default_dtype()
  1389. if isinstance(data, (float, int)):
  1390. ranges = []
  1391. def inner_fn(index):
  1392. return ops.constant(data, dtype)
1393. elif len(data) == 0 or (isinstance(data[0], (float, int)) and len(data) <= 8):
  1394. # inline small tensors
  1395. ranges = [sympy.Integer(len(data))]
  1396. def inner_fn(index):
  1397. def binary_search(start, end):
  1398. assert start < end
  1399. if end - start == 1:
  1400. return ops.constant(data[start], dtype)
  1401. mid = (end - start) // 2 + start
  1402. return ops.where(
  1403. ops.lt(
  1404. ops.index_expr(index[0], torch.int64),
  1405. ops.constant(mid, torch.int64),
  1406. ),
  1407. binary_search(start, mid),
  1408. binary_search(mid, end),
  1409. )
  1410. if len(data) == 0:
  1411. return ops.constant(0, dtype)
  1412. return binary_search(0, len(data))
  1413. else:
  1414. return V.graph.add_tensor_constant(
  1415. torch.tensor(data, dtype=dtype, device=device)
  1416. )
  1417. return Pointwise.create(
  1418. device=decode_device(device),
  1419. dtype=dtype,
  1420. inner_fn=inner_fn,
  1421. ranges=ranges,
  1422. )
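# Worked example (sketch): a small literal like torch.tensor([10, 20, 30, 40])
# is inlined as a balanced tree of ops.where comparisons rather than stored as
# a constant buffer; binary_search(0, 4) emits roughly
#     where(i < 2, where(i < 1, 10, 20), where(i < 3, 30, 40))
# Inputs longer than 8 elements (or nested lists) go through
# V.graph.add_tensor_constant instead.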
  1423. @register_lowering(torch.as_tensor)
  1424. def as_tensor(data, dtype=None, device=None):
  1425. if isinstance(data, TensorBox):
  1426. if dtype is not None:
  1427. data = to_dtype(data, dtype)
  1428. if device is not None:
  1429. data = to_device(data, device)
  1430. return data
  1431. return tensor(data, dtype=dtype, device=device)
  1432. @register_lowering(torch.LongTensor)
  1433. def long_tensor(data):
  1434. return tensor(data, dtype=torch.int64)
  1435. @register_lowering(aten._local_scalar_dense)
  1436. def _local_scalar_dense(data):
  1437. return ir.DynamicScalar()
  1438. def _full(fill_value, device, dtype, size):
  1439. value = fill_value
  1440. if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
  1441. value = value.value
  1442. if isinstance(value, (int, float, sympy.Expr)):
  1443. def inner_fn(index):
  1444. return ops.constant(value, dtype)
  1445. else:
  1446. assert len(value.get_size()) == 0
  1447. value_loader = value.make_loader()
  1448. def inner_fn(index):
  1449. return value_loader([])
  1450. return Pointwise.create(
  1451. device=device,
  1452. dtype=dtype,
  1453. inner_fn=inner_fn,
  1454. ranges=list(size),
  1455. )
  1456. @register_lowering(aten.full_like, type_promotion_kind=None)
  1457. def full_like(x, fill_value, **kwargs):
  1458. return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
  1459. def tensor_constructor(fill_value):
  1460. # torch.zeros, torch.ones, etc
  1461. def inner(
  1462. *size,
  1463. names=None,
  1464. dtype=None,
  1465. device=None,
  1466. layout=0,
  1467. pin_memory=False,
  1468. memory_format=None,
  1469. ):
  1470. assert names is None
  1471. assert not pin_memory
  1472. assert layout in (0, torch.strided)
  1473. assert memory_format in (None, torch.contiguous_format)
  1474. device = decode_device(device)
  1475. dtype = dtype or torch.get_default_dtype()
  1476. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  1477. size = tuple(size[0])
  1478. size = [sympy.expand(s) for s in size]
  1479. return _full(fill_value, device, dtype, size)
  1480. return inner
  1481. @register_lowering([torch.empty, aten.empty])
  1482. def empty(
  1483. *size,
  1484. names=None,
  1485. dtype=None,
  1486. layout=None,
  1487. device=None,
  1488. pin_memory=None,
  1489. memory_format=None,
  1490. ):
  1491. assert names is None
  1492. assert memory_format in (None, torch.contiguous_format)
  1493. device = decode_device(device)
  1494. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  1495. size = list(size[0])
  1496. return empty_strided(
  1497. size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  1498. )
  1499. def create_tensor_like(creation_fn):
  1500. """
  1501. Shim to convert X_like(...) into X(...). For example zeros_like() into zeros().
  1502. """
  1503. def _constant_like(
  1504. x, *, dtype=None, device=None, layout=0, pin_memory=False, memory_format=None
  1505. ):
  1506. assert not pin_memory
  1507. assert layout in (0, torch.strided)
  1508. if dtype is None:
  1509. dtype = x.get_dtype()
  1510. else:
  1511. dtype = decode_dtype(dtype)
  1512. device = device or x.get_device()
  1513. size = list(x.get_size())
  1514. return creation_fn(
  1515. size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
  1516. )
  1517. return _constant_like
  1518. def constant_like(fill_value):
  1519. return create_tensor_like(tensor_constructor(fill_value))
  1520. empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
  1521. ones_like = create_tensor_like(tensor_constructor(1))
  1522. if not config.fallback_random:
  1523. rand_like = register_lowering(aten.rand_like)(create_tensor_like(rand))
  1524. randn_like = register_lowering(aten.randn_like)(create_tensor_like(randn))
  1525. def new_constant(fill_value):
  1526. def _new_constant(
  1527. x, size, *, dtype=None, layout=None, device=None, pin_memory=None
  1528. ):
1529. assert isinstance(size, (list, tuple))
  1530. assert not pin_memory
  1531. assert not layout or layout == torch.strided
  1532. dtype = decode_dtype(dtype) or x.get_dtype()
  1533. device = device or x.get_device()
  1534. size = [sympy.Integer(s) for s in size]
  1535. return _full(fill_value, device, dtype, size)
  1536. return _new_constant
  1537. @register_lowering(aten.new_empty)
  1538. def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
  1539. if dtype is None:
  1540. dtype = x.get_dtype()
  1541. if device is None:
  1542. device = x.get_device()
  1543. return empty_strided(
  1544. size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  1545. )
  1546. @register_lowering(aten.empty_strided)
  1547. def empty_strided(
  1548. size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  1549. ):
1550. assert isinstance(size, (list, tuple))
1551. assert isinstance(stride, (list, tuple, type(None)))
  1552. assert not pin_memory
  1553. assert not layout or layout == torch.strided
  1554. dtype = decode_dtype(dtype) or torch.get_default_dtype()
  1555. device = device or torch.tensor(0.0).device
  1556. pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
  1557. pointwise.realize()
  1558. buffer = pointwise.data.data
  1559. # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
  1560. buffer.data.ranges = [0] * len(size)
  1561. assert isinstance(buffer, ir.ComputedBuffer)
  1562. size = [sympy.expand(s) for s in size]
  1563. stride = (
  1564. [sympy.expand(s) for s in stride]
  1565. if stride
  1566. else ir.FlexibleLayout.contiguous_strides(size)
  1567. )
  1568. buffer.layout = ir.FixedLayout(
  1569. device=device,
  1570. dtype=dtype,
  1571. size=size,
  1572. stride=stride,
  1573. )
  1574. return pointwise
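# Worked example (sketch): when no stride is supplied, contiguous strides are
# derived from the size, e.g. size [2, 3, 4] -> stride [12, 4, 1]. The backing
# buffer is given zero ranges on purpose so it becomes a NopKernelSchedulerNode:
# nothing is computed for the "empty" allocation, only its FixedLayout
# (size and stride) matters to downstream users.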
  1575. @register_lowering(aten.new_empty_strided)
  1576. def new_empty_strided(
  1577. x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  1578. ):
  1579. if dtype is None:
  1580. dtype = x.get_dtype()
  1581. if device is None:
  1582. device = x.get_device()
  1583. return empty_strided(
  1584. size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  1585. )
  1586. @register_lowering(prims.copy_strided.default)
  1587. def copy_strided(x, stride):
  1588. stride = [V.graph.sizevars.size_hint(s) for s in stride]
  1589. stride_order = sorted(range(len(stride)), key=stride.__getitem__)
  1590. return ir.ExternKernel.require_stride_order(x, stride_order)
  1591. @register_lowering([torch.full, aten.full])
  1592. def full(size, fill_value, **kwargs):
  1593. return tensor_constructor(fill_value)(size, **kwargs)
  1594. @register_lowering(aten.gather, type_promotion_kind=None)
  1595. def gather(x, dim, index):
  1596. assert isinstance(x, TensorBox)
  1597. assert index.get_dtype() == torch.int64
  1598. offset = len(x.get_size()) == 0
  1599. dim = _validate_dim(x, dim, offset)
  1600. x_loader = x.make_loader()
  1601. index_loader = index.make_loader()
  1602. def fn(idx):
  1603. idx = list(idx)
  1604. if len(idx) != 0:
  1605. idx[dim] = ops.indirect_indexing(index_loader(idx))
  1606. return x_loader(idx)
  1607. return Pointwise.create(
  1608. device=x.get_device(),
  1609. dtype=x.get_dtype(),
  1610. inner_fn=fn,
  1611. ranges=index.get_size(),
  1612. )
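# Worked example (sketch): gather walks the index tensor, so for a 2-D x with
# dim=0 the lowering computes out[i][j] = x[index[i][j]][j]. The output ranges
# come from index.get_size(), and the loaded (int64) index value is fed through
# ops.indirect_indexing before reading x.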
  1613. @register_lowering(aten.embedding, type_promotion_kind=None)
  1614. def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
  1615. assert not sparse
  1616. assert isinstance(weight, TensorBox)
  1617. assert isinstance(indices, TensorBox)
  1618. assert "int" in str(indices.get_dtype())
  1619. weight_loader = weight.make_loader()
  1620. indices_loader = indices.make_loader()
  1621. indices_ndim = len(indices.get_size())
  1622. new_size = [*indices.get_size(), *weight.get_size()[1:]]
  1623. def fn(idx):
  1624. assert len(idx) == len(new_size), f"{idx} != {new_size}"
  1625. var_index = indices_loader(idx[:indices_ndim])
  1626. weight_idx = [ops.indirect_indexing(var_index)] + [*idx[indices_ndim:]]
  1627. return weight_loader(weight_idx)
  1628. return Pointwise.create(
  1629. device=weight.get_device(),
  1630. dtype=weight.get_dtype(),
  1631. inner_fn=fn,
  1632. ranges=new_size,
  1633. )
  1634. def check_and_broadcast_indices(indices, device):
  1635. assert all(
  1636. i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
  1637. for i in indices
  1638. if i is not None
  1639. ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
  1640. if any(
  1641. i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
  1642. ):
  1643. raise NotImplementedError("Fallback for bool indices")
  1644. valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
  1645. assert len(valid_idxs) > 0, "requires at least 1 non-None index"
  1646. new_indices = [None] * len(indices)
  1647. for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
  1648. # Eager allows indices to be CPU tensor when running on CUDA
  1649. # FIXME: Calling to_device(x, device) should work but
  1650. # test_advancedindex_mixed_cpu_devices still fails
  1651. if x.get_device() != device:
  1652. raise NotImplementedError("Fallback when indices is on a different device")
  1653. new_indices[i] = x
  1654. output_dim = len(x.get_size())
  1655. start_offset = 0
  1656. # only support None at start or end for now
  1657. tmp = list(new_indices)
  1658. while tmp and tmp[-1] is None:
  1659. tmp.pop()
  1660. while tmp and tmp[0] is None:
  1661. tmp.pop(0)
  1662. start_offset += 1
  1663. if any((i is None) for i in tmp):
  1664. raise NotImplementedError("Fallback when None is in the middle of indices")
  1665. end_offset = output_dim + start_offset
  1666. return new_indices, start_offset, end_offset
  1667. @register_lowering(aten.index, type_promotion_kind=None)
  1668. def index(x, indices):
  1669. assert isinstance(indices, (list, tuple))
  1670. x_loader = x.make_loader()
  1671. try:
  1672. indices, start_offset, end_offset = check_and_broadcast_indices(
  1673. indices, x.get_device()
  1674. )
  1675. except NotImplementedError:
  1676. x.realize()
  1677. return fallback_handler(aten.index)(x, indices)
  1678. indices_sizes = [i.get_size() for i in indices if i is not None]
  1679. indices_loaders = [i.make_loader() for i in indices if i is not None]
  1680. # no guards on output size, all the guards are set in broadcast_tensors
  1681. output_size = list(indices_sizes[0])
  1682. x_size = x.get_size()
  1683. output_size = [
  1684. *x_size[:start_offset],
  1685. *output_size,
  1686. *x_size[start_offset + len(indices_loaders) :],
  1687. ]
  1688. def fn(idx):
  1689. assert len(idx) == len(output_size)
  1690. new_index = [
  1691. ops.indirect_indexing(loader(idx[start_offset:end_offset]))
  1692. for loader in indices_loaders
  1693. ]
  1694. new_index = [*idx[:start_offset], *new_index, *idx[end_offset:]]
  1695. return x_loader(new_index)
  1696. return Pointwise.create(
  1697. device=x.get_device(),
  1698. dtype=x.get_dtype(),
  1699. inner_fn=fn,
  1700. ranges=output_size,
  1701. )
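# Worked example (sketch): for x of size [A, B, C, D] indexed as
# x[:, idx0, idx1], with idx0/idx1 broadcast to a common shape S, the helpers
# above give start_offset=1 and end_offset=1+len(S), so the output size is
# [A, *S, D]: dims before the indices are kept, the indexed dims are replaced
# by the broadcast index shape, and the trailing dims are appended.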
  1702. # All the indexing decompositions are written in terms of index, index_put, and index_put_
  1703. # We cannot have this lowering as a decomposition as it introduces
  1704. # mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
  1705. # code elimination and common subexpression elimination optimizations, which
  1706. # assume graphs to be side-effect free. More details at
  1707. # https://github.com/pytorch/torchdynamo/issues/1235
  1708. # and
  1709. # https://github.com/pytorch/torchdynamo/issues/1863
  1710. @register_lowering([aten.index_put])
  1711. def index_put(x, indices, values, accumulate=False):
  1712. return index_put_(clone(x), indices, values, accumulate)
  1713. def index_put_as_masked_fill(self, indices, value, accumulate):
  1714. if value.get_device() != self.get_device():
  1715. value = to_device(value, self.get_device())
  1716. if accumulate:
  1717. value = add(self, value)
  1718. return mutate_to(self, where(indices[0], value, self))
  1719. def index_put_fallback(self, indices, values, accumulate):
  1720. ir.IndexPutFallback(self, indices, values, accumulate)
  1721. return self
  1722. @register_lowering(aten.index_put_, type_promotion_kind=None)
  1723. def index_put_(self, indices, values, accumulate=False):
  1724. # Dispatch to masked fill for single boolean index with single value
  1725. if (
  1726. values.get_numel() == 1
  1727. and len(indices) == 1
  1728. and indices[0].get_dtype() in {torch.bool, torch.uint8}
  1729. ):
  1730. return index_put_as_masked_fill(self, indices, values, accumulate)
  1731. # Fallback if there is a boolean index
  1732. for index in indices:
  1733. if index is not None and index.get_dtype() in {torch.bool, torch.uint8}:
  1734. return index_put_fallback(self, indices, values, accumulate)
  1735. x_size = self.get_size()
  1736. x_ndim = len(x_size)
  1737. # fallback to aten.index_put_, as tl.atomic_add does NOT support int64 or bool
  1738. if self.get_dtype() in {torch.int64, torch.bool}:
1739. # self is a scalar Tensor
  1740. if x_ndim == 0:
  1741. self = view(self, [1])
  1742. self = index_put_fallback(self, indices, values, accumulate)
  1743. if x_ndim == 0:
  1744. self = view(self, [])
  1745. return self
  1746. values = to_dtype(values, self.get_dtype())
  1747. try:
  1748. indices, start_offset, end_offset = check_and_broadcast_indices(
  1749. indices, self.get_device()
  1750. )
  1751. except NotImplementedError:
  1752. return index_put_fallback(self, indices, values, accumulate)
  1753. indices_sizes = [i.get_size() for i in indices if i is not None]
  1754. indices_loaders = [i.make_loader() for i in indices if i is not None]
  1755. assert isinstance(self, TensorBox)
  1756. self.realize()
  1757. V.graph.realize_users_of(self.get_name())
1758. # self is a scalar Tensor
  1759. if x_ndim == 0:
  1760. self = view(self, [1])
  1761. output_size = list(indices_sizes[0])
  1762. expected_vals_size = [
  1763. *x_size[:start_offset],
  1764. *output_size,
  1765. *x_size[start_offset + len(indices_sizes) :],
  1766. ]
  1767. values = expand(values, expected_vals_size)
  1768. # all guards are set above during broadcast_tensors and expand
  1769. def output_indexer(index):
  1770. assert len(index) == len(expected_vals_size)
  1771. new_index = [
  1772. ops.indirect_indexing(loader(index[start_offset:end_offset]))
  1773. for loader in indices_loaders
  1774. ]
  1775. new_index = [*index[:start_offset], *new_index, *index[end_offset:]]
  1776. return new_index
  1777. scatter = ir.Scatter(
  1778. device=self.get_device(),
  1779. dtype=self.get_dtype(),
  1780. inner_fn=values.make_loader(),
  1781. ranges=expected_vals_size, # iter_ranges,
  1782. output_indexer=output_indexer,
  1783. scatter_mode="atomic_add" if accumulate else None,
  1784. )
  1785. buffer = ir.ComputedBuffer(
  1786. None,
  1787. ir.MutationLayout(self),
  1788. scatter,
  1789. )
  1790. buffer.name = V.graph.register_buffer(buffer)
  1791. if x_ndim == 0:
  1792. self = view(self, [])
  1793. return self
  1794. @register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
  1795. def as_strided_scatter(self, src, size, stride, storage_offset=None):
  1796. output = clone(self)
  1797. output_view = as_strided(output, size, stride, storage_offset)
  1798. copy_(output_view, src)
  1799. return output
  1800. @register_lowering(aten.scatter, type_promotion_kind=None)
  1801. def scatter(x, dim: int, index, src, **kwargs):
  1802. return scatter_(clone(x), dim, index, src, **kwargs)
  1803. def scatter_fallback(
  1804. fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True
  1805. ):
  1806. if reduce not in {None, "sum"} or (
  1807. reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64}
  1808. ):
  1809. self.realize()
  1810. return fallback_handler(fn)(
  1811. self, dim, index, src, reduce=reduce, include_self=include_self
  1812. )
  1813. return None
  1814. @register_lowering(aten.scatter_, type_promotion_kind=None)
  1815. def scatter_(self, dim: int, index, src, *, reduce: str = None):
  1816. if reduce == "add":
  1817. reduce = "sum"
  1818. elif reduce == "multiply":
  1819. reduce = "prod"
  1820. else:
  1821. assert reduce is None
  1822. fallback_result = scatter_fallback(
  1823. aten.scatter_, self, dim, index, src, reduce=reduce
  1824. )
  1825. if fallback_result:
  1826. return fallback_result
  1827. return scatter_reduce_(self, dim, index, src, reduce)
  1828. @register_lowering(aten.scatter_add, type_promotion_kind=None)
  1829. def scatter_add(x, dim: int, index, src):
  1830. return scatter_add_(clone(x), dim, index, src)
  1831. @register_lowering(aten.scatter_add_, type_promotion_kind=None)
  1832. def scatter_add_(x, dim: int, index, src):
  1833. return scatter_reduce_(clone(x), dim, index, src, "sum")
  1834. @register_lowering(aten.scatter_reduce, type_promotion_kind=None)
  1835. def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
  1836. return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
  1837. fallback_scatter_reduce_ = fallback_handler(aten.scatter_reduce_)
  1838. @register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
  1839. def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
  1840. assert reduce in {None, "sum", "prod", "mean", "amax", "amin"}
  1841. fallback_result = scatter_fallback(
  1842. aten.scatter_reduce_,
  1843. self,
  1844. dim,
  1845. index,
  1846. src,
  1847. reduce=reduce,
  1848. include_self=include_self,
  1849. )
  1850. if fallback_result:
  1851. return fallback_result
  1852. assert isinstance(self, TensorBox)
  1853. assert "int" in str(index.get_dtype())
  1854. ndim = len(self.get_size())
  1855. if ndim == 0:
  1856. self = view(self, [1])
  1857. if isinstance(src, TensorBox) and len(src.get_size()) == 0:
  1858. src = view(src, [1])
  1859. if isinstance(index, TensorBox) and len(index.get_size()) == 0:
  1860. index = view(index, [1])
  1861. assert -len(self.get_size()) <= dim < len(self.get_size())
  1862. self.realize()
  1863. V.graph.realize_users_of(self.get_name())
  1864. index_loader = index.make_loader()
  1865. src_loader = src.make_loader() if isinstance(src, TensorBox) else None
  1866. def output_indexer(idx):
  1867. indirect_idx = list(idx)
  1868. indirect_idx[dim] = ops.indirect_indexing(index_loader(idx))
  1869. return indirect_idx
  1870. def fn(idx):
  1871. if src_loader:
  1872. return src_loader(idx)
  1873. else:
  1874. # src is a scalar
  1875. return ops.constant(src, self.get_dtype())
  1876. def backend_reduce_str(reduce):
  1877. if reduce == "sum":
  1878. return "atomic_add"
  1879. else:
  1880. # TODO: Need to support more reduction type
  1881. assert reduce is None
  1882. return None
  1883. if not include_self:
  1884. # zero out the corresponding elements first
  1885. zero_out = ir.Scatter(
  1886. device=self.get_device(),
  1887. dtype=self.get_dtype(),
  1888. inner_fn=lambda index: ops.constant(0, self.get_dtype()),
  1889. ranges=index.get_size(),
  1890. output_indexer=output_indexer,
  1891. scatter_mode=None,
  1892. )
  1893. buffer = ir.ComputedBuffer(
  1894. None,
  1895. ir.MutationLayout(self),
  1896. zero_out,
  1897. )
  1898. buffer.name = V.graph.register_buffer(buffer)
  1899. # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
  1900. # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
  1901. # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
  1902. scatter = ir.Scatter(
  1903. device=self.get_device(),
  1904. dtype=self.get_dtype(),
  1905. inner_fn=fn,
  1906. ranges=index.get_size(),
  1907. output_indexer=output_indexer,
  1908. scatter_mode=backend_reduce_str(reduce),
  1909. )
  1910. buffer = ir.ComputedBuffer(
  1911. None,
  1912. ir.MutationLayout(self),
  1913. scatter,
  1914. )
  1915. buffer.name = V.graph.register_buffer(buffer)
  1916. if ndim == 0:
  1917. self = view(self, [])
  1918. return self
  1919. def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2):
  1920. x.realize_hint() # elements are reused
  1921. x_loader = x.make_loader()
  1922. i_sizes = x.get_size()[-n:]
  1923. batch = x.get_size()[:-n]
  1924. i_sizes = [V.graph.sizevars.guard_static_shape(i) for i in i_sizes]
  1925. assert len(scales_x) == n
  1926. o_sizes = output_size
  1927. scales = [i / o for i, o in zip(i_sizes, o_sizes)]
1928. for i, scale in enumerate(scales_x):
1929. if scale:
1930. scales[i] = 1.0 / scale # an explicit scale factor is output/input, so invert it
  1931. def scale(x, scale):
  1932. x = ops.index_expr(x, torch.float32)
  1933. x = ops.mul(x, ops.constant(scale, torch.float32))
  1934. x = ops.to_dtype(x, torch.int32)
  1935. return ops.indirect_indexing(x)
  1936. def fn(idx):
  1937. x = idx[-n:]
  1938. b = idx[:-n]
  1939. return x_loader([*b, *[scale(i, s) for i, s in zip(x, scales)]])
  1940. return Pointwise.create(
  1941. device=x.get_device(),
  1942. dtype=x.get_dtype(),
  1943. inner_fn=fn,
  1944. ranges=[*batch, *o_sizes],
  1945. )
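# Worked example (sketch): with no explicit scale factor, an output coordinate
# o maps to the input coordinate int(o * in_size / out_size) (truncated).
# For in_size=4, out_size=8 the multiplier is 0.5, so output columns 0..7 read
# input columns [0, 0, 1, 1, 2, 2, 3, 3].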
  1946. @register_lowering(aten.upsample_nearest1d.default)
  1947. def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
  1948. return upsample_nearestnd(x, output_size, (scales,), n=1)
  1949. @register_lowering(aten.upsample_nearest2d.default)
  1950. def upsample_nearest2d(
  1951. x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
  1952. ):
  1953. return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
  1954. @register_lowering(aten.upsample_nearest3d.default)
  1955. def upsample_nearest3d(
  1956. x,
  1957. output_size,
  1958. scales_d: Optional[float] = None,
  1959. scales_h: Optional[float] = None,
  1960. scales_w: Optional[float] = None,
  1961. ):
  1962. return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
  1963. @register_lowering(aten.upsample_bicubic2d.default)
  1964. def upsample_bicubic2d_default(
  1965. x,
  1966. output_size,
  1967. align_corners: bool,
  1968. scales_h: Optional[float] = None,
  1969. scales_w: Optional[float] = None,
  1970. ):
  1971. x.realize_hint()
  1972. x_loader = x.make_loader()
  1973. N, C, iH, iW = x.get_size()
  1974. oH, oW = output_size
  1975. iH = V.graph.sizevars.guard_static_shape(iH)
  1976. iW = V.graph.sizevars.guard_static_shape(iW)
  1977. def get_int_dtype(maxval):
  1978. if maxval > torch.iinfo(torch.int32).max:
  1979. return torch.int64
  1980. return torch.int32
  1981. def compute_scale(in_size, out_size, align_corners, scale=None):
  1982. if align_corners:
  1983. return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
  1984. else:
  1985. return 1 / scale if scale is not None and scale > 0 else in_size / out_size
  1986. def compute_source_index(scale, dst_index, align_corners):
  1987. dst_index_ie = ops.index_expr(dst_index, torch.float32)
  1988. if align_corners:
  1989. return ops.mul(scale, dst_index_ie)
  1990. else:
  1991. return ops.sub(
  1992. ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
  1993. ) # scale * (dst_index + 0.5) - 0.5
  1994. def cubic_convolution1(x, A):
  1995. # ((A + 2) * x - (A+3)) * x * x + 1
  1996. return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
  1997. def cubic_convolution2(x, A):
  1998. # ((A * x - 5 * A) * x + 8 * A) * x - 4*A
  1999. return ops.sub(
  2000. ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
  2001. )
  2002. def get_cubic_upsample_coefficients(t):
  2003. A = -0.75
  2004. c0 = cubic_convolution2(ops.add(t, 1.0), A)
  2005. c1 = cubic_convolution1(t, A)
  2006. x2 = ops.sub(1.0, t)
  2007. c2 = cubic_convolution1(x2, A)
  2008. c3 = cubic_convolution2(ops.add(x2, 1.0), A)
  2009. return (
  2010. c0,
  2011. c1,
  2012. c2,
  2013. c3,
  2014. )
  2015. def cubic_interp1d(xs, t):
  2016. cs = get_cubic_upsample_coefficients(t)
  2017. # dot product between xs and cs
  2018. return ops.add(
  2019. ops.mul(xs[0], cs[0]),
  2020. ops.add(
  2021. ops.mul(xs[1], cs[1]),
  2022. ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
  2023. ),
  2024. )
  2025. height_scale = compute_scale(iH, oH, align_corners, scales_h)
2026. width_scale = compute_scale(iW, oW, align_corners, scales_w)
  2027. def clamp(v, min, max):
  2028. return ops.maximum(min, ops.minimum(max, v))
  2029. def fn(idx):
  2030. n, c, oy, ox = idx
  2031. real_x = compute_source_index(width_scale, ox, align_corners)
  2032. in_x = ops.floor(real_x)
  2033. t_x = ops.sub(real_x, in_x)
  2034. real_y = compute_source_index(height_scale, oy, align_corners)
  2035. in_y = ops.floor(real_y)
  2036. t_y = ops.sub(real_y, in_y)
  2037. def load_bounded(fy, fx):
  2038. iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
  2039. ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
  2040. return x_loader([n, c, iy, ix])
  2041. iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
  2042. ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
  2043. iys_ofs = tuple((ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)))
  2044. ixs_ofs = tuple((ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)))
  2045. def get_x_interp(y):
  2046. coeffs_x = tuple((load_bounded(y, x) for x in ixs_ofs))
  2047. return cubic_interp1d(coeffs_x, t_x)
  2048. coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
  2049. return cubic_interp1d(coeffs_y, t_y)
  2050. return Pointwise.create(
  2051. device=x.get_device(),
  2052. dtype=x.get_dtype(),
  2053. inner_fn=fn,
  2054. ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
  2055. )
  2056. @register_lowering(aten.reflection_pad2d)
  2057. def reflection_pad2d(x, padding):
  2058. assert len(padding) == 4
  2059. left, right, top, bot = padding
  2060. x_loader = x.make_loader()
  2061. *batch, h, w = x.get_size()
  2062. h = V.graph.sizevars.guard_static_shape(h)
  2063. w = V.graph.sizevars.guard_static_shape(w)
  2064. def reflect(x, size, offset):
  2065. size = ops.constant(size - 1, torch.int32)
  2066. x = ops.index_expr(x, torch.int32)
  2067. x = ops.sub(x, ops.constant(offset, torch.int32))
  2068. x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x))))
  2069. return ops.indirect_indexing(x)
  2070. def fn(idx):
  2071. *b, x, y = idx
  2072. x = reflect(x, h, top)
  2073. y = reflect(y, w, left)
  2074. return x_loader([*b, x, y])
  2075. return Pointwise.create(
  2076. device=x.get_device(),
  2077. dtype=x.get_dtype(),
  2078. inner_fn=fn,
  2079. ranges=[*batch, sympy.Integer(h + top + bot), sympy.Integer(w + left + right)],
  2080. )
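# Worked example (sketch): reflect() maps an output coordinate back into the
# input as (size - 1) - |(size - 1) - |x - offset||. With h=5, top=2, bot=2,
# output rows 0..8 read input rows [2, 1, 0, 1, 2, 3, 4, 3, 2]; the border row
# is not repeated, matching nn.ReflectionPad2d semantics.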
  2081. @register_lowering(aten.reflection_pad2d_backward)
  2082. def reflection_pad2d_backward(grad_output, x, padding):
  2083. assert len(padding) == 4
  2084. left, right, top, bot = padding
  2085. *_, h, w = x.get_size()
  2086. h = V.graph.sizevars.guard_static_shape(h) - 1
  2087. w = V.graph.sizevars.guard_static_shape(w) - 1
  2088. grad_loader = grad_output.make_loader()
  2089. def fn(idx):
  2090. *b, x, y = idx
  2091. def load_from_output(x, y):
  2092. x = ops.indirect_indexing(ops.index_expr(x, torch.int32))
  2093. y = ops.indirect_indexing(ops.index_expr(y, torch.int32))
  2094. return grad_loader([*b, x, y])
  2095. def index_range_condition(index_range):
  2096. i, lb, ub = index_range
  2097. i = ops.index_expr(i, torch.int32)
  2098. return ops.and_(ops.ge(i, lb), ops.le(i, ub))
  2099. def accumulate(out_x, out_y, index_range1, index_range2=None):
  2100. nonlocal grad
  2101. # If the upper bound is less than the lower bound, we can get rid of one accumulation.
  2102. # This happens when the padding size is zero.
  2103. if index_range1[2] < index_range1[1]:
  2104. return
  2105. cond = index_range_condition(index_range1)
  2106. if index_range2 is not None:
  2107. if index_range2[2] < index_range2[1]:
  2108. return
  2109. cond = ops.and_(cond, index_range_condition(index_range2))
  2110. g = ops.masked(cond, lambda: load_from_output(out_x, out_y), 0.0)
  2111. grad = ops.add(grad, g)
  2112. # Areas after reflection:
  2113. #
  2114. # top-left | top | top-right
  2115. # -----------------------------------------
  2116. # left | center | right
  2117. # -----------------------------------------
  2118. # bottom-left | bottom | bottom-right
  2119. #
2120. # The center area is the original matrix. Other areas are reflections.
  2121. center_x, center_y = x + top, y + left
  2122. top_reflect_x, left_reflect_y = top - x, left - y
  2123. bot_reflect_x, right_reflect_y = 2 * h + top - x, 2 * w + left - y
  2124. # Accumulate gradients from different areas
  2125. grad = load_from_output(center_x, center_y)
  2126. accumulate(center_x, left_reflect_y, (y, 1, left))
  2127. accumulate(center_x, right_reflect_y, (y, w - right, w - 1))
  2128. accumulate(top_reflect_x, center_y, (x, 1, top))
  2129. accumulate(bot_reflect_x, center_y, (x, h - bot, h - 1))
  2130. accumulate(top_reflect_x, left_reflect_y, (x, 1, top), (y, 1, left))
  2131. accumulate(top_reflect_x, right_reflect_y, (x, 1, top), (y, w - right, w - 1))
  2132. accumulate(bot_reflect_x, left_reflect_y, (x, h - bot, h - 1), (y, 1, left))
  2133. accumulate(
  2134. bot_reflect_x, right_reflect_y, (x, h - bot, h - 1), (y, w - right, w - 1)
  2135. )
  2136. return grad
  2137. return Pointwise.create(
  2138. device=grad_output.get_device(),
  2139. dtype=grad_output.get_dtype(),
  2140. inner_fn=fn,
  2141. ranges=list(x.get_size()),
  2142. )
  2143. @register_lowering(prims.rev.default)
  2144. def rev(x, dims):
2145. # note - dims are pre-canonicalized
  2146. x_loader = x.make_loader()
  2147. sizes = x.get_size()
  2148. def loader(idx):
  2149. idx = list(idx)
  2150. assert len(idx) == len(sizes)
  2151. for dim in dims:
  2152. idx[dim] = (sizes[dim] - 1) - idx[dim]
  2153. return x_loader(idx)
  2154. return Pointwise.create(
  2155. device=x.get_device(),
  2156. dtype=x.get_dtype(),
  2157. inner_fn=loader,
  2158. ranges=sizes,
  2159. )
  2160. @register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
  2161. def constant_pad_nd(x, padding, fill_value=0):
  2162. assert (len(padding) % 2) == 0
  2163. if all(p == 0 for p in padding):
  2164. return x
  2165. sizes = x.get_size()
  2166. bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
  2167. n = len(sizes) - len(bounds)
  2168. output_size = list(sizes[:n])
  2169. mask_sizes = []
  2170. for (low, high), size in zip(bounds, sizes[n:]):
  2171. size = V.graph.sizevars.guard_static_shape(size)
  2172. mask_sizes.append(size)
  2173. output_size.append(sympy.expand(size + low + high))
  2174. assert len(output_size) == len(sizes)
  2175. fill_value = dtype_to_type(x.get_dtype())(fill_value)
  2176. def mask(index):
  2177. mask = []
  2178. for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
  2179. if low != 0:
  2180. mask.append(range_mask_low(idx))
  2181. if high != 0:
  2182. mask.append(range_mask_high(idx, length))
  2183. mask = functools.reduce(ops.and_, mask)
  2184. return ops.masked(mask, lambda: x_loader(index), fill_value)
  2185. def offset_fn(index):
  2186. new_index = list(index[:n])
  2187. for idx, (low, high) in zip(index[n:], bounds):
  2188. new_index.append(idx - low)
  2189. assert len(new_index) == len(index)
  2190. return mask(new_index)
  2191. x_loader = x.make_loader()
  2192. return Pointwise.create(
  2193. device=x.get_device(),
  2194. dtype=x.get_dtype(),
  2195. inner_fn=offset_fn,
  2196. ranges=output_size,
  2197. )
  2198. def range_mask_low(i: sympy.Expr):
  2199. return ops.ge(
  2200. ops.index_expr(i, torch.int64),
  2201. ops.index_expr(sympy.Integer(0), torch.int64),
  2202. )
  2203. def range_mask_high(i: sympy.Expr, length: sympy.Expr):
  2204. return ops.lt(
  2205. ops.index_expr(i, torch.int64),
  2206. ops.index_expr(length, torch.int64),
  2207. )
  2208. def range_mask(i: sympy.Expr, length: sympy.Expr):
  2209. return ops.and_(
  2210. range_mask_low(i),
  2211. range_mask_high(i, length),
  2212. )
  2213. def constant_boundary_condition_2d(x, fill_value, padding):
  2214. *_, h, w = x.get_size()
  2215. x_loader = x.make_loader()
  2216. def load(index):
  2217. *prefix, ih, iw = index
  2218. mask = ops.and_(
  2219. range_mask(ih, h),
  2220. range_mask(iw, w),
  2221. )
  2222. return ops.masked(mask, lambda: x_loader([*prefix, ih, iw]), fill_value)
  2223. return load
  2224. def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
  2225. x_out = ir.FloorDiv(
  2226. x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
  2227. )
  2228. if ceil_mode:
  2229. x_alt = ir.FloorDiv(
  2230. x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
  2231. )
  2232. if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
2233. # ceil mode is actually a no-op, let's guard on that
  2234. V.graph.sizevars.guard_equals(x_out, x_alt)
  2235. ceil_mode = False
  2236. else:
  2237. x_out = x_alt
  2238. return x_out, ceil_mode
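# Worked example (sketch): with ceil_mode=False the output extent is
#     FloorDiv(x + 2 * padding - kernel_size + stride, stride)
# i.e. the usual floor((x + 2p - k) / s) + 1. For x=7, k=3, s=2, p=0 that is
# floor((7 - 3 + 2) / 2) = 3 positions. The ceil_mode variant adds an extra
# (stride - 1) to the numerator, giving ceil((x + 2p - k) / s) + 1.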
  2239. fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices)
  2240. @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
  2241. def max_pool2d_with_indices(
  2242. x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
  2243. ):
  2244. if padding == 0:
  2245. padding = [0, 0]
  2246. if not stride:
  2247. stride = kernel_size
  2248. assert dilation == 1 or all(d == 1 for d in dilation)
  2249. assert isinstance(x, TensorBox)
  2250. assert len(kernel_size) == 2
  2251. assert len(stride) == 2
  2252. assert len(padding) == 2
  2253. assert len(x.get_size()) in (3, 4)
  2254. x.realize_hint()
  2255. *batch, h, w = x.get_size()
  2256. h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
  2257. w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
  2258. if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
  2259. x_loader = constant_boundary_condition_2d(x, float("-inf"), padding)
  2260. else:
  2261. x_loader = x.make_loader()
  2262. new_size = list(batch) + [h_out, w_out]
  2263. window_size = kernel_size[0] * kernel_size[1]
  2264. if window_size > 25:
  2265. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  2266. return fallback_max_pool2d_with_indices(
  2267. x, kernel_size, stride, padding, dilation, ceil_mode
  2268. )
  2269. def fn(idx, return_index):
  2270. *prefix, bh, bw = idx
  2271. maxval = None
  2272. maxindex = None
  2273. for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
  2274. ih = bh * stride[0] + ih - padding[0]
  2275. iw = bw * stride[1] + iw - padding[1]
  2276. val = x_loader([*prefix, ih, iw])
  2277. if return_index:
  2278. index = ops.index_expr(ih * w + iw, torch.int64)
  2279. if maxindex is None:
  2280. maxindex = index
  2281. else:
  2282. maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
  2283. if maxval is None:
  2284. maxval = val
  2285. else:
  2286. maxval = ops.maximum(val, maxval)
  2287. if return_index:
  2288. return maxindex
  2289. else:
  2290. return maxval
  2291. r1 = Pointwise.create(
  2292. device=x.get_device(),
  2293. dtype=x.get_dtype(),
  2294. inner_fn=functools.partial(fn, return_index=False),
  2295. ranges=new_size,
  2296. )
  2297. r2 = Pointwise.create(
  2298. device=x.get_device(),
  2299. dtype=torch.int64,
  2300. inner_fn=functools.partial(fn, return_index=True),
  2301. ranges=new_size,
  2302. )
  2303. # TODO(jansel): should we force these to be realized?
  2304. return r1, r2
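# Worked example (sketch): the second output stores the argmax as a flattened
# input offset ih * w + iw, matching eager max_pool2d_with_indices. For a
# 1x1x4x4 input with kernel_size=[2, 2] and stride=[2, 2], output position
# (0, 0) scans input offsets {0, 1, 4, 5} and records the offset of the max.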
  2305. fallback_max_pool2d_with_indices_backward = fallback_handler(
  2306. aten.max_pool2d_with_indices_backward
  2307. )
  2308. @register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
  2309. def max_pool2d_with_indices_backward(
  2310. grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
  2311. ):
  2312. if padding == 0:
  2313. padding = [0, 0]
  2314. if not stride:
  2315. stride = kernel_size
  2316. assert dilation == 1 or all(d == 1 for d in dilation)
  2317. assert isinstance(x, TensorBox)
  2318. assert len(kernel_size) == 2
  2319. assert len(stride) == 2
  2320. assert len(padding) == 2
  2321. assert len(x.get_size()) in (3, 4)
  2322. # we will read this many times, so make sure it is computed
  2323. grad_output.realize_hint()
  2324. try:
  2325. gO_stride = grad_output.get_stride()
  2326. except AttributeError:
  2327. # some classes don't have `get_stride`
  2328. # TODO will need a better way of determining if inputs are channels-last
  2329. gO_stride = None
  2330. if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):
  2331. data = x.data.data
  2332. x_buffer = ir.ComputedBuffer(
  2333. name=None,
  2334. layout=ir.FlexibleLayout(
  2335. device=data.get_device(),
  2336. dtype=data.get_dtype(),
  2337. size=data.get_size(),
  2338. ),
  2339. data=data,
  2340. )
  2341. x_buffer.decide_layout()
  2342. x_stride = x_buffer.get_stride()
  2343. else:
  2344. try:
  2345. x_stride = x.get_stride()
  2346. except AttributeError:
  2347. x_stride = None
2348. if (
2349. (x_stride is not None and x_stride[1] == 1)
2350. or (gO_stride is not None and gO_stride[1] == 1)
2351. ):
  2353. # don't codegen channels-last, it's very slow
  2354. return fallback_max_pool2d_with_indices_backward(
  2355. grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
  2356. )
  2357. indices.realize_hint()
  2358. *batch, height, width = x.get_size()
  2359. *_, pooled_height, pooled_width = grad_output.get_size()
  2360. indices_loader = indices.make_loader()
  2361. grad_loader = grad_output.make_loader()
  2362. new_size = list(x.get_size())
  2363. h_window_size = max(
  2364. [
  2365. max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
  2366. for h in range(kernel_size[0] * 2)
  2367. ]
  2368. )
  2369. w_window_size = max(
  2370. [
  2371. max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
  2372. for w in range(kernel_size[1] * 2)
  2373. ]
  2374. )
  2375. window_size = h_window_size * w_window_size
  2376. if window_size > 25:
  2377. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  2378. return fallback_max_pool2d_with_indices_backward(
  2379. grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
  2380. )
  2381. def fn(idx):
  2382. *prefix, h, w = idx
  2383. index_test = ops.index_expr(h * width + w, torch.int32)
  2384. h = h + padding[0]
  2385. w = w + padding[1]
  2386. phstart = ops.index_expr(
  2387. ir.FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
  2388. )
  2389. pwstart = ops.index_expr(
  2390. ir.FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
  2391. )
  2392. phend = ops.index_expr(ir.FloorDiv(h, stride[0]) + 1, torch.int32)
  2393. pwend = ops.index_expr(ir.FloorDiv(w, stride[1]) + 1, torch.int32)
  2394. phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
  2395. pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
  2396. phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
  2397. pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
  2398. gradient = None
  2399. for ph_ in range(h_window_size):
  2400. for pw_ in range(w_window_size):
  2401. ph = ops.add(phstart, ops.constant(ph_, torch.int32))
  2402. pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
  2403. grad_index = [
  2404. *prefix,
  2405. ops.indirect_indexing(
  2406. ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32)))
  2407. ),
  2408. ops.indirect_indexing(
  2409. ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32)))
  2410. ),
  2411. ]
  2412. index_actual = indices_loader(grad_index)
  2413. grad_part = grad_loader(grad_index)
  2414. check = ops.eq(index_actual, index_test)
  2415. if gradient is None:
  2416. # don't need mask for 0, 0
  2417. gradient = ops.where(
  2418. check, grad_part, ops.constant(0.0, torch.float32)
  2419. )
  2420. else:
  2421. mask = ops.and_(
  2422. ops.and_(
  2423. ops.lt(ph, phend),
  2424. ops.lt(pw, pwend),
  2425. ),
  2426. check,
  2427. )
  2428. gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
  2429. assert gradient is not None
  2430. return gradient
  2431. return Pointwise.create(
  2432. device=grad_output.get_device(),
  2433. dtype=grad_output.get_dtype(),
  2434. inner_fn=fn,
  2435. ranges=new_size,
  2436. )
  2437. def pad_adaptive_loader(x):
  2438. *_, h, w = x.get_size()
  2439. x_loader = x.make_loader()
  2440. def load(prefix, increments, start_indices, end_indices):
  2441. ih, iw = increments
  2442. h_start_index, w_start_index = start_indices
  2443. h_end_index, w_end_index = end_indices
  2444. mask = ops.and_(
  2445. ops.lt(
  2446. ops.index_expr(h_start_index + ih, torch.int64),
  2447. ops.index_expr(h_end_index, torch.int64),
  2448. ),
  2449. ops.lt(
  2450. ops.index_expr(w_start_index + iw, torch.int64),
  2451. ops.index_expr(w_end_index, torch.int64),
  2452. ),
  2453. )
  2454. return ops.masked(
  2455. mask,
  2456. lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
  2457. 0.0,
  2458. )
  2459. return load
  2460. def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
  2461. h_start_index_fn, w_start_index_fn = start_index_fns
  2462. h_end_index_fn, w_end_index_fn = end_index_fns
  2463. def fn_sum(idx, loader):
  2464. *prefix, bh, bw = idx
  2465. h_start_index = h_start_index_fn(bh)
  2466. h_end_index = h_end_index_fn(bh)
  2467. w_start_index = w_start_index_fn(bw)
  2468. w_end_index = w_end_index_fn(bw)
  2469. total = None
  2470. for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
  2471. val = loader(
  2472. prefix,
  2473. [ih, iw],
  2474. [h_start_index, w_start_index],
  2475. [h_end_index, w_end_index],
  2476. )
  2477. if total is None:
  2478. total = val
  2479. else:
  2480. total = ops.add(val, total)
  2481. return total
  2482. return fn_sum
  2483. fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d)
  2484. @register_lowering(aten._adaptive_avg_pool2d)
  2485. def _adaptive_avg_pool2d(x, output_size):
  2486. assert isinstance(x, TensorBox)
  2487. assert len(output_size) == 2
  2488. x.realize_hint()
  2489. *batch, h_in, w_in = x.get_size()
  2490. h_in = V.graph.sizevars.guard_static_shape(h_in)
  2491. w_in = V.graph.sizevars.guard_static_shape(w_in)
  2492. h_out, w_out = output_size
2493. # no-op if input and output sizes are the same
  2494. if h_in == h_out and w_in == w_out:
  2495. return clone(x)
  2496. if h_in % h_out == 0 and w_in % w_out == 0:
  2497. kernel_size = [h_in // h_out, w_in // w_out]
  2498. return avg_pool2d(x, kernel_size)
  2499. h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
  2500. w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
  2501. new_size = list(batch) + [h_out, w_out]
  2502. dtype = x.get_dtype()
  2503. def start_index(index, out_dim, inp_dim):
  2504. return ir.FloorDiv((index * inp_dim), out_dim)
  2505. def end_index(index, out_dim, inp_dim):
  2506. return ir.FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
  2507. h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
  2508. h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
  2509. w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
  2510. w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
  2511. window_size = h_kernel_max * w_kernel_max
  2512. if window_size > 25:
  2513. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  2514. return fallback_adaptive_avg_pool2d(x, output_size)
  2515. fn_sum = _adaptive_pooling_idx_sum(
  2516. [h_kernel_max, w_kernel_max],
  2517. [h_start_index, w_start_index],
  2518. [h_end_index, w_end_index],
  2519. )
  2520. ones_loader = pad_adaptive_loader(ones_like(x))
  2521. def fn(idx):
  2522. return ops.div(fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader))
  2523. rv = Pointwise.create(
  2524. device=x.get_device(),
  2525. dtype=dtype,
  2526. inner_fn=fn,
  2527. ranges=new_size,
  2528. )
  2529. # TODO: should we force these to be realized?
  2530. return rv
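# Worked example (sketch): for output index i, adaptive average pooling uses
# the window [floor(i * in / out), ceil((i + 1) * in / out)). With in=5, out=3
# the windows are [0, 2), [1, 4), [3, 5), so window sizes differ (2, 3, 2) and
# may overlap; dividing by fn_sum(idx, ones_loader) normalizes each output by
# its actual window size.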
  2531. @register_lowering(aten.upsample_nearest2d_backward.default)
  2532. def upsample_nearest2d_backward(
  2533. x, output_size=None, input_size=None, scales_h=None, scales_w=None
  2534. ):
  2535. x.realize_hint()
  2536. *batch, inp_h, inp_w = x.get_size()
  2537. inp_h = V.graph.sizevars.guard_static_shape(inp_h)
  2538. inp_w = V.graph.sizevars.guard_static_shape(inp_w)
  2539. *batch, out_h, out_w = input_size
  2540. if inp_h % out_h == 0 and inp_w % out_w == 0:
  2541. return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1)
  2542. h_kernel_max = ceildiv(inp_h, out_h)
  2543. w_kernel_max = ceildiv(inp_w, out_w)
  2544. def start_index(index, out_dim, inp_dim):
  2545. return ir.CeilDiv(index * inp_dim, out_dim)
  2546. def end_index(index, out_dim, inp_dim):
  2547. return start_index((index + 1), out_dim, inp_dim)
  2548. h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
  2549. h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
  2550. w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
  2551. w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
  2552. fn_sum = _adaptive_pooling_idx_sum(
  2553. [h_kernel_max, w_kernel_max],
  2554. [h_start_index, w_start_index],
  2555. [h_end_index, w_end_index],
  2556. )
  2557. def fn(idx):
  2558. return fn_sum(idx, pad_adaptive_loader(x))
  2559. rv = Pointwise.create(
  2560. device=x.get_device(),
  2561. dtype=x.get_dtype(),
  2562. inner_fn=fn,
  2563. ranges=list(input_size),
  2564. )
  2565. return rv


fallback_avg_pool2d = fallback_handler(aten.avg_pool2d)


@register_lowering(aten.avg_pool2d, type_promotion_kind=None)
def avg_pool2d(
    x,
    kernel_size,
    stride=(),
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0, 0]

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(x.get_size()) in (3, 4)

    x.realize_hint()
    *batch, h, w = x.get_size()

    h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
    w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)

    if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
        x_loader = constant_boundary_condition_2d(x, 0.0, padding)
        had_padding = True
    else:
        x_loader = x.make_loader()
        had_padding = False

    new_size = list(batch) + [h_out, w_out]
    dtype = x.get_dtype()

    window_size = kernel_size[0] * kernel_size[1]
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_avg_pool2d(
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
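    # fn_sum unrolls the kernel_size[0] x kernel_size[1] window into an explicit
    # chain of loads and adds for a single output element.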
    def fn_sum(idx, loader):
        *prefix, bh, bw = idx
        total = None
        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
            ih = bh * stride[0] + ih - padding[0]
            iw = bw * stride[1] + iw - padding[1]
            val = loader([*prefix, ih, iw])
            if total is None:
                total = val
            else:
                total = ops.add(val, total)
        return total

    if count_include_pad or not had_padding or divisor_override:
        if divisor_override:
            scale = 1 / divisor_override
        else:
            scale = 1.0 / (kernel_size[0] * kernel_size[1])

        def fn(idx):
            return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))

    else:
        ones_loader = constant_boundary_condition_2d(ones_like(x), 0.0, padding)

        def fn(idx):
            # TODO(jansel): optimize to do `int(x<h)` rather than `x<h?1:0`
            return ops.div(fn_sum(idx, x_loader), fn_sum(idx, ones_loader))

    rv = Pointwise.create(
        device=x.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
    # TODO(jansel): should we force these to be realized?
    return rv


fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward)


@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
def avg_pool2d_backward(
    grad_output,
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    assert not divisor_override
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0, 0]

    assert isinstance(grad_output, TensorBox)
    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(x.get_size()) in (3, 4)

    grad_output.realize_hint()  # we will read this many times, so make sure it is computed

    *batch, height, width = x.get_size()

    h_out, ceil_mode1 = pooling_size(height, 0, kernel_size, stride, padding, ceil_mode)
    w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)

    grad_loader = grad_output.make_loader()

    had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2

    *_, pooled_height, pooled_width = grad_output.get_size()
    new_size = list(x.get_size())
    dtype = x.get_dtype()
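    # Upper bound on the number of pooling windows that can cover one input
    # element in each dimension; this bounds the unrolled loops in fn below.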
    h_window_size = max(
        [
            max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
            for h in range(kernel_size[0] * 2)
        ]
    )
    w_window_size = max(
        [
            max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
            for w in range(kernel_size[1] * 2)
        ]
    )

    window_size = h_window_size * w_window_size
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_avg_pool2d_backward(
            grad_output,
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    def compute_pool_size_without_padding(ph, pw):
        """
        This computes the scaling factor that we will divide an element
        by when `count_include_pad=False`
        """
        stride_h = ops.constant(stride[0], torch.int32)
        stride_w = ops.constant(stride[1], torch.int32)
        pad_h = ops.constant(padding[0], torch.int32)
        pad_w = ops.constant(padding[1], torch.int32)
        kernel_h = ops.constant(kernel_size[0], torch.int32)
        kernel_w = ops.constant(kernel_size[1], torch.int32)

        hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
        wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
        hend = ops.minimum(
            ops.add(hstart, kernel_h),
            ops.add(ops.index_expr(height, torch.int32), pad_h),
        )
        wend = ops.minimum(
            ops.add(wstart, kernel_w),
            ops.add(ops.index_expr(width, torch.int32), pad_w),
        )
        hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
        wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
        hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
        wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
        divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
        return divide_factor
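    # For each input element, accumulate the gradient from every pooling window
    # that contains it, dividing each contribution by that window's divisor.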
    def fn(idx):
        *prefix, h, w = idx
        h = h + padding[0]
        w = w + padding[1]
        phstart = ops.index_expr(
            ir.FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            ir.FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(ir.FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(ir.FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_ in range(h_window_size):
            for pw_ in range(w_window_size):
                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))

                if count_include_pad or not had_padding:
                    scale = kernel_size[0] * kernel_size[1]
                else:
                    scale = compute_pool_size_without_padding(ph, pw)

                part = ops.truediv(
                    grad_loader(
                        [
                            *prefix,
                            ops.indirect_indexing(
                                ops.minimum(
                                    ph, ops.sub(phend, ops.constant(1, torch.int32))
                                )
                            ),
                            ops.indirect_indexing(
                                ops.minimum(
                                    pw, ops.sub(pwend, ops.constant(1, torch.int32))
                                )
                            ),
                        ]
                    ),
                    scale,
                )

                mask = ops.and_(
                    ops.lt(ph, phend),
                    ops.lt(pw, pwend),
                )
                if gradient is None:
                    gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
                else:
                    gradient = ops.where(mask, ops.add(gradient, part), gradient)
        assert gradient is not None
        return gradient

    rv = Pointwise.create(
        device=grad_output.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
    return rv


def _validate_reduction_axis(x, axis):
    size = x.get_size()
    if isinstance(axis, int):
        axis = [axis]
    elif not axis:
        axis = range(len(size))
    axis = list(axis)
    for i in range(len(axis)):
        if axis[i] < 0:
            axis[i] += len(size) if len(size) else 1
        assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
    assert len(set(axis)) == len(axis), "reduction axis not unique"
    return axis
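# make_reduction builds a lowering for a reduction op: kept dimensions become the
# output ranges and reduced dimensions become the reduction ranges of a Reduction
# IR node; "min"/"max" additionally return the corresponding argmin/argmax.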
def make_reduction(reduction_type: str, override_return_dtype=None):
    def inner(x, axis=None, keepdims=False, *, dtype=None):
        if reduction_type == "min" and axis is not None:
            return (
                reduce_amin(x, axis, keepdims, dtype=dtype),
                reduce_argmin(x, axis, keepdims),
            )
        if reduction_type == "max" and axis is not None:
            return (
                reduce_amax(x, axis, keepdims, dtype=dtype),
                reduce_argmax(x, axis, keepdims),
            )
        if dtype is not None:
            x = to_dtype(x, dtype)
        if reduction_type == "any":
            x = to_dtype(x, torch.bool)
        size = x.get_size()
        axis = set(_validate_reduction_axis(x, axis))

        kept_sizes = []
        kept_idx = []
        reduced_sizes = []
        reduced_idx = []
        for i in range(len(size)):
            if i in axis:
                reduced_idx.append(i)
                reduced_sizes.append(size[i])
            else:
                kept_idx.append(i)
                kept_sizes.append(size[i])

        def loader(index, reduction_index):
            assert len(reduction_index) == len(reduced_idx)
            if keepdims:
                assert len(index) == len(size)
                assert all(index[i] == 0 for i in reduced_idx)
                index = [index[i] for i in kept_idx]
            assert len(index) == len(kept_idx)
            new_index = [None] * (len(index) + len(reduction_index))
            for idx, var in itertools.chain(
                zip(kept_idx, index), zip(reduced_idx, reduction_index)
            ):
                new_index[idx] = var
            return inner_loader(new_index)

        if keepdims:
            new_size = list(size)
            for i in reduced_idx:
                new_size[i] = sympy.Integer(1)
        else:
            new_size = kept_sizes

        inner_loader = x.make_loader()
        result = Reduction.create(
            device=x.get_device(),
            dst_dtype=override_return_dtype or x.get_dtype(),
            src_dtype=x.get_dtype(),
            inner_fn=loader,
            ranges=new_size,
            reduction_ranges=reduced_sizes,
            reduction_type={"amax": "max", "amin": "min"}.get(
                reduction_type, reduction_type
            ),
        )
        if isinstance(
            result.data.data, Reduction
        ):  # Only realize if reduction isn't unrolled
            result.realize()
        return result

    return inner


@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
    if dtype is not None:
        x = to_dtype(x, dtype)
    size = x.get_size()
    axis = _validate_reduction_axis(x, axis)
    # compute in higher-precision until end of mean lowering
    output_dtype = x.get_dtype()
    if output_dtype in (torch.float16, torch.bfloat16):
        x = to_dtype(x, torch.float)
    sum_result = sum_(x, axis, keepdim)
    denom = sympy_product(size[i] for i in axis)
    denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
    denom = ExpandView.create(denom, list(sum_result.get_size()))
    return to_dtype(div(sum_result, denom), output_dtype)
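# Variance is computed as sum((x - mean)**2) / (N - correction); the default
# correction of 1 gives the Bessel-corrected (unbiased) estimator.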
def var_mean_(x, axis, correction, keepdim, return_mean):
    if correction is None:
        correction = 1

    size = x.get_size()
    axis = _validate_reduction_axis(x, axis)
    x_mean = mean(x, axis, keepdim=True)
    if return_mean:
        x_mean.realize()

    diffs = square(sub(x, x_mean))
    sum_result = sum_(diffs, axis, keepdim)

    denom = sympy_product(size[i] for i in axis)
    if correction:
        denom = denom - correction
    denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device())
    denom = ExpandView.create(denom, list(sum_result.get_size()))
    x_var = div(sum_result, denom)
    if not return_mean:
        return x_var

    x_mean = x_mean if keepdim else squeeze(x_mean, axis)
    return x_var, x_mean


@register_lowering([aten.var, prims.var])
def var_(x, axis=None, *, correction=None, keepdim=False):
    return var_mean_(
        x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
    )


@register_lowering(aten.var_mean)
def var_mean(x, axis=None, *, correction=None, keepdim=False):
    return var_mean_(
        x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
    )
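# Exponentiation by squaring: builds x**y out of O(log y) multiplies for small
# integer exponents, e.g. pow_recursive(x, 5, dtype) emits ((x*x) * (x*x)) * x.
# Negative exponents recurse through the reciprocal.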
def pow_recursive(x, y, dtype):
    if y < 0:
        return pow_recursive(ops.reciprocal(x), -y, dtype)
    if y == 0:
        return ops.constant(1, dtype)
    if y == 1:
        return x

    result = pow_recursive(x, y // 2, dtype)
    result = ops.mul(result, result)
    if (y % 2) == 1:
        result = ops.mul(result, x)
    return result


@make_pointwise
def pow_native(a, b):
    return ops.pow(a, b)


def _is_ir_node_and_cuda(x):
    if isinstance(x, ir.IRNode) and decode_device(x.get_device()).type == "cuda":
        return True
    return False


@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
    if _is_ir_node_and_cuda(a) and _is_ir_node_and_cuda(b):
        assert a.get_dtype() in (
            torch.float16,
            torch.float32,
            torch.float64,
        ), "Pow input must be floating point."
    if isinstance(b, float) and b == int(b):
        return pow(a, int(b))
    elif isinstance(b, float) and b == 0.5:
        return sqrt(a)
    elif isinstance(b, int) and b == 1:
        return a
    elif isinstance(b, int) and -32 < b < 32:
        # Optimize away small fixed powers
        loader = a.make_loader()

        def fn(idx):
            return pow_recursive(loader(idx), b, a.get_dtype())

        return Pointwise.create(
            device=a.get_device(),
            dtype=a.get_dtype(),
            inner_fn=fn,
            ranges=a.get_size(),
        )

    if isinstance(a, Number):
        if a == 1:
            return full_like(b, 1)
        if a == 2 and is_float_dtype(b.get_dtype()):
            return exp2(b)

    return pow_native(a, b)
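# mutate_to implements in-place semantics: it makes `changed` take on the
# computed value, either by swinging its StorageBox data pointer (fast path) or
# by realizing the value into it via a MutationLayout.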
def mutate_to(changed, val):
    if isinstance(changed, TensorBox):
        changed_data = changed.data
    else:
        changed_data = changed
    if isinstance(val, TensorBox):
        val = val.data

    if not isinstance(val, ir.StorageBox):
        # introduce a copy to handle views
        val = Pointwise.create(
            device=changed.get_device(),
            dtype=changed.get_dtype(),
            inner_fn=val.make_loader(),
            ranges=changed.get_size(),
        ).data
        assert isinstance(val, ir.StorageBox)

    if isinstance(changed_data, ir.StorageBox) and not (
        changed_data.is_input_buffer() or isinstance(changed_data.data, ir.NopKernel)
    ):
        # Fast path, just swing the data pointer
        val.realize()
        changed_data.data = val.data
        return changed

    ir.MutationLayout.realize_into(val, changed_data)
    return changed


@register_lowering(aten.fill_)
def fill_(x, fill_value):
    return mutate_to(x, full_like(x, fill_value))


@register_lowering(aten.copy_, type_promotion_kind=None)
def copy_(dst, src, non_blocking=False):
    src = to_device(src, dst.get_device())
    src = to_dtype(src, dst.get_dtype())
    src = expand(src, dst.get_size())
    return mutate_to(dst, src)


@make_pointwise
def floordiv(a, b):
    return ops.floordiv(a, b)


@make_pointwise
def truncdiv(a, b):
    return ops.truncdiv(a, b)


@register_lowering(aten.div, broadcast=True)
def div_mode(a, b, rounding_mode=None):
    both_integer = is_integer_type(a) and is_integer_type(b)
    both_boolean = is_boolean_type(a) and is_boolean_type(b)

    # floordiv and truncdiv need special handling for integer tensors on Triton,
    # see the discussion at https://github.com/openai/triton/issues/605
    if rounding_mode == "floor":
        assert not both_boolean, "floordiv operands can not be boolean at the same time"
        return floordiv(a, b) if both_integer else floor(div(a, b))
    if rounding_mode == "trunc":
        assert not both_boolean, "truncdiv operands can not be boolean at the same time"
        return truncdiv(a, b) if both_integer else trunc(div(a, b))
    return div(a, b)


@register_lowering([aten.mul], broadcast=True)
def mul(a, b):
    both_bool = is_boolean_type(a) and is_boolean_type(b)
    if both_bool:
        return logical_and(a, b)
    else:
        fn = ops_wrapper(aten.mul.__name__)
        return make_pointwise(fn)(a, b)


# NOTE: prims.div maps to a / b in C, so performs truncation division on
# integer inputs and true division for floating and complex inputs.
@register_lowering([prims.div], broadcast=True)
def div_prim(a, b):
    is_integral = is_boolean_type(a) or is_integer_type(a)

    if is_integral:
        return truncdiv(a, b)

    def fn(*args):
        return ops.div(*args)

    return make_pointwise(fn)(a, b)


div = register_lowering(
    [aten.true_divide, aten.div.Tensor],
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)(div_prim)


@register_lowering([aten.fmod, prims.fmod], broadcast=True)
def fmod(a, b):
    is_integral = is_boolean_type(a) or is_integer_type(a)

    if is_integral:

        def fn(a, b):
            return ops.mod(a, b)

    else:

        def fn(a, b):
            return ops.fmod(a, b)

    return make_pointwise(fn)(a, b)


@register_lowering(aten.rsqrt)
def rsqrt(x):
    dtype = x.get_dtype()
    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
        x = to_dtype(x, torch.get_default_dtype())

    def _rsqrt(x):
        return ops.rsqrt(x)

    return make_pointwise(_rsqrt)(x)


@register_lowering([aten.sum, prims.sum])
def sum_(x, axis=None, keepdims=False, *, dtype=None):
    if (
        is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
    ) and dtype is None:
        dtype = torch.int64

    fn = make_reduction("sum", override_return_dtype=dtype)
    return fn(x, axis, keepdims, dtype=dtype)


register_lowering(aten.max)(make_reduction("max"))
register_lowering(aten.min)(make_reduction("min"))
reduce_amax = register_lowering(aten.amax)(make_reduction("amax"))
reduce_amin = register_lowering(aten.amin)(make_reduction("amin"))
register_lowering(aten.any)(make_reduction("any", override_return_dtype=torch.bool))
reduce_argmax = register_lowering(aten.argmax)(
    make_reduction("argmax", override_return_dtype=torch.int64)
)
reduce_argmin = register_lowering(aten.argmin)(
    make_reduction("argmin", override_return_dtype=torch.int64)
)

add = register_pointwise(
    aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
)


def register_pointwise_numeric(op):
    return register_pointwise(
        op, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )


def register_pointwise_numeric_ldf64(op):
    return register_pointwise(
        op,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        use_libdevice_for_f64=True,
    )


exp = register_pointwise_numeric_ldf64(aten.exp)
exp2 = register_pointwise_numeric(aten.exp2)
expm1 = register_pointwise_numeric(aten.expm1)
relu = register_pointwise(aten.relu)
sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
square = register_pointwise(aten.square)
sub = register_pointwise(aten.sub, allow_alpha=True)
register_pointwise_numeric_ldf64(aten.cos)
register_pointwise_numeric_ldf64(aten.sin)
register_pointwise(aten.abs)
register_pointwise(aten.bitwise_and)
register_pointwise(aten.bitwise_not, override_fn_when_input_bool="logical_not")
register_pointwise(aten.bitwise_or)
register_pointwise(aten.bitwise_xor)
register_pointwise(aten.bitwise_left_shift)
# TODO(fdrocha): once https://github.com/openai/triton/pull/1153 is merged and we advance the triton pin past it
# this should be uncommented
# register_pointwise(aten.bitwise_right_shift)
register_pointwise_numeric(aten.lgamma)
erf = register_pointwise_numeric(aten.erf)
register_lowering(
    aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)(erf)
register_pointwise_numeric(aten.log1p)
register_pointwise_numeric(aten.tan)
register_pointwise_numeric(aten.tanh)
register_pointwise_numeric_ldf64(aten.log)
register_pointwise(aten.logical_not, convert_input_to_bool=True)
maximum = register_pointwise(aten.maximum)
minimum = register_pointwise(aten.minimum)
register_lowering(aten.clamp_min)(maximum)
register_lowering(aten.clamp_max)(minimum)
register_pointwise(aten.neg)
register_pointwise_numeric(aten.reciprocal)
register_pointwise(aten.remainder)
register_pointwise(aten.sign, override_fn_when_input_bool="identity")
register_pointwise(aten.ceil)
register_pointwise(aten.signbit, override_return_dtype=torch.bool)

register_pointwise(aten.le, type_promotion_kind=None, override_return_dtype=torch.bool)
register_pointwise(aten.lt, type_promotion_kind=None, override_return_dtype=torch.bool)
register_pointwise(aten.ge, type_promotion_kind=None, override_return_dtype=torch.bool)
register_pointwise(aten.gt, type_promotion_kind=None, override_return_dtype=torch.bool)
register_pointwise(aten.eq, type_promotion_kind=None, override_return_dtype=torch.bool)
register_pointwise(aten.ne, type_promotion_kind=None, override_return_dtype=torch.bool)

logical_and = register_pointwise(
    aten.logical_and,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
register_lowering(aten.__and__, type_promotion_kind=None)(logical_and)
register_lowering(aten.__or__, type_promotion_kind=None)(
    register_pointwise(
        aten.logical_or,
        type_promotion_kind=None,
        convert_input_to_bool=True,
        override_return_dtype=torch.bool,
    )
)
logical_xor = register_pointwise(
    aten.logical_xor,
    name="bitwise_xor",
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
register_lowering(aten.__xor__, type_promotion_kind=None)(logical_xor)

register_pointwise_numeric(aten.cosh)
register_pointwise_numeric(aten.sinh)
register_pointwise_numeric(aten.acos)
register_pointwise_numeric(aten.acosh)
register_pointwise_numeric(aten.asin)
register_pointwise_numeric(aten.asinh)
register_pointwise_numeric(aten.atan2)
register_pointwise_numeric(aten.atan)
register_pointwise_numeric(aten.atanh)
register_pointwise_numeric(aten.copysign)
register_pointwise_numeric(aten.erfc)
register_pointwise_numeric(aten.hypot)
register_pointwise_numeric(aten.log10)
register_pointwise_numeric(aten.nextafter)
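# register_inplace wraps an out-of-place lowering so the in-place aten op computes
# the result, casts it back to the original dtype, and writes it into the first
# argument via mutate_to.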
def register_inplace(aten_op, outplace_op):
    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        result = outplace_op(*args, **kwargs)
        result = to_dtype(result, args[0].get_dtype())
        return mutate_to(args[0], result)

    return fn


register_inplace(aten.add_, add)
register_inplace(aten.mul_, mul)
register_inplace(aten.div_.Tensor, div)
register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.sub_, sub)
register_inplace(aten.relu_, relu)
register_inplace(aten.sigmoid_, sigmoid)


@register_lowering(aten.sym_size)
def sym_size(a, dim):
    return a.get_size()[dim]


@register_lowering(aten.sym_stride)
def sym_stride(a, dim):
    return a.get_stride()[dim]


@register_lowering(aten.sym_numel)
def sym_numel(a):
    return a.get_numel()


for method, func in magic_methods.items():
    register_lowering(method_to_operator(method))(func)


@register_lowering(aten._foobar)
def foobar(self, *args, **kwargs):
    raise NotImplementedError("Helpful for debugging")


@register_lowering(torch.ops._inductor_test.realize)
def _realize(x):
    x.realize()
    return clone(x)


# populate lowerings defined in kernel/*
from . import kernel

import_submodule(kernel)