# decompositions.py
import functools
import operator
import sys
from enum import Enum
from functools import partial, reduce
from itertools import product
from typing import Callable, cast, Iterable, List, Optional, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch.nn.functional as F
from torch import sym_float, sym_int, Tensor
from torch._decomp import register_decomposition
from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    out_wrapper,
)
from torch.fx.experimental.symbolic_shapes import guard_int
from torch.utils._pytree import tree_flatten, tree_map

DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch._ops.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(
    f: Callable,
    type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
    compute_dtype_only: bool = False,
):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [
            x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)
        ]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        if compute_dtype_only:
            return r
        else:
            return tree_map(decrease_prec, r)

    return inner


compute_only_pw_cast_for_opmath = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    compute_dtype_only=True,
)
pw_cast_for_opmath = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
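
# Example (illustrative; `my_decomp` is a hypothetical decomposition used only
# to show the wrapper's behavior): with DEFAULT promotion, pw_cast_for_opmath
# upcasts Tensor arguments to the computation dtype, runs the wrapped function,
# and casts the result back to the promoted result dtype, e.g.
#
#     @pw_cast_for_opmath
#     def my_decomp(x: Tensor, y: Tensor) -> Tensor:
#         return x * y
#
# Called on float16 inputs, the multiply runs in float32 and the output is
# cast back to float16.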


# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()


@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()


@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
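
# Note: softplus(x) = (1 / beta) * log(1 + exp(beta * x)), so
# d/dx softplus(x) = exp(beta * x) / (exp(beta * x) + 1) = z / (z + 1).
# Above `threshold`, softplus is treated as the identity and the gradient
# passes through unchanged, matching the forward's linear regime.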


@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )
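
# Note: for x > 0, ELU is scale * x, so the gradient is grad_output * scale
# (poscoef) in both branches. For x <= 0, ELU is
# scale * alpha * (exp(input_scale * x) - 1); when `is_result` is True the
# saved output y is used to recover the derivative as
# input_scale * (y + scale * alpha), which equals
# input_scale * scale * alpha * exp(input_scale * x).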


@register_decomposition([aten.fill.Scalar])
def fill_scalar(self, value):
    return torch.full_like(self, value)


@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
    utils.check(
        value.dim() == 0,
        lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
    )
    return torch.full_like(self, value.item())


@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        0.0,
    )


@register_decomposition(aten.hardtanh_backward)
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)


@register_decomposition(aten.hardshrink_backward)
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    return torch.where((self >= -lambd) & (self <= lambd), 0.0, grad_out)


@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        0.0,
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, 0.0, grad_output)


@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)


@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "tanh":
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)
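
# Note: exact GELU is x * Phi(x) with Phi the standard normal CDF, so its
# derivative is Phi(x) + x * phi(x), where phi(x) = exp(-x^2 / 2) / sqrt(2 * pi);
# kAlpha is 1 / sqrt(2) and kBeta is 1 / sqrt(2 * pi). The "tanh" branch instead
# differentiates the approximation 0.5 * x * (1 + tanh(inner)) by the product
# rule, where inner = sqrt(2 / pi) * (x + 0.044715 * x^3).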


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)


@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))
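
# Note: silu(x) = x * sigmoid(x), so by the product rule
# d/dx silu(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
#             = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
# which is the closed form used above.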


@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    return torch.where((self >= -lambd) & (self <= lambd), 0.0, grad_output)


@register_decomposition(aten._prelu_kernel)
def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
    return torch.where(self > 0, self, weight * self)


@register_decomposition(aten._prelu_kernel_backward)
def _prelu_kernel_backward(
    grad_output: Tensor,
    self: Tensor,
    weight: Tensor,
) -> Tuple[Tensor, Tensor]:
    input_grad = torch.where(self > 0, grad_output, weight * grad_output)
    weight_grad = torch.where(self > 0, 0.0, self * grad_output)
    return (input_grad, weight_grad)


@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(
            grad_output, self, negative_slope, self_is_result
        )


@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
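
# Note: d/dx log(sigmoid(x)) = 1 - sigmoid(x). Writing it with z = exp(-|x|)
# keeps the expression numerically stable for large |x|: for x < 0 it evaluates
# to 1 - z / (1 + z) and for x >= 0 to z / (1 + z), both of which equal
# 1 - sigmoid(x) without overflowing exp.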


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64


# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction
@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output


@register_decomposition(aten.huber_loss_backward.default)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )
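
# Note: the Huber loss is 0.5 * x^2 for |x| <= delta and
# delta * (|x| - 0.5 * delta) otherwise (with x = self - target), so its
# derivative is x inside the quadratic region and +/- delta outside it;
# `norm` folds in the 1 / numel factor of the mean reduction.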


# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
@register_decomposition(aten.huber_loss_backward.out)
@pw_cast_for_opmath
def huber_loss_backward_out(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int,
    delta: float,
    grad_input: Tensor,
):
    result = huber_loss_backward(grad_output, self, target, reduction, delta)
    _maybe_resize_out(grad_input, result.shape)
    return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    return grad_input * grad_output
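
# Note: NLL loss contributes -log_prob[target] per sample, so the gradient is
# -1 at each (sample, target-class) location and 0 elsewhere. The scatter
# writes those -1s, and the final multiply folds in grad_output, the optional
# per-class weight, and masks out ignore_index entries.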


@register_decomposition(aten.glu_backward)
@pw_cast_for_opmath
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    wrap_dim = utils.canonicalize_dim(self.dim(), dim)
    nIn = self.size(wrap_dim)
    assert (
        nIn % 2 == 0
    ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
    inputSize = nIn // 2
    firstHalf = self.narrow(wrap_dim, 0, inputSize)
    secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
    gradInputFirstHalf = torch.sigmoid(secondHalf)
    gradInputSecondHalf = (
        (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
    )
    gradInputFirstHalf = gradInputFirstHalf * grad_output
    return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)
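
# Note: glu(x) splits x into halves (a, b) along `dim` and returns
# a * sigmoid(b), so the gradient w.r.t. a is sigmoid(b) * grad_output and the
# gradient w.r.t. b is a * sigmoid(b) * (1 - sigmoid(b)) * grad_output; the two
# halves are then concatenated back along the same dimension.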


@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log1p(-self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    loss = torch.log1p(torch.exp(-input * target))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.soft_margin_loss_backward)
@pw_cast_for_opmath
def soft_margin_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
    if reduction == Reduction.MEAN.value:
        grad_input = grad_input / self.numel()
    return grad_input


@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()
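
# Note: the concatenation trick computes all pairwise squared distances with a
# single matmul via the identity
#     ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a . b:
# the row [-2 * x1, ||x1||^2, 1] dotted with the row [x2, 1, ||x2||^2] yields
# exactly that expansion; clamp_min(0) guards against small negative values
# from floating-point round-off before the sqrt.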


@register_decomposition(aten.slice_backward)
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.slice.Tensor)
def slice_forward(
    # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
    self: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    ndim = self.dim()
    if ndim == 0:
        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
    dim = utils.canonicalize_dim(self.dim(), dim)
    sizes = list(self.size())
    strides = list(self.stride())

    if step <= 0:
        raise RuntimeError("slice step must be positive")

    start_val = start if start is not None else 0
    end_val = end if end is not None else sys.maxsize  # 2^63 - 1

    if start_val < 0:
        start_val += sizes[dim]
    if end_val < 0:
        end_val += sizes[dim]

    if start_val < 0:
        start_val = 0
    elif start_val >= sizes[dim]:
        start_val = sizes[dim]

    if end_val < start_val:
        end_val = start_val
    elif end_val >= sizes[dim]:
        end_val = sizes[dim]

    storage_offset = self.storage_offset() + start_val * strides[dim]
    len = end_val - start_val
    sizes[dim] = (len + step - 1) // step
    strides[dim] *= step

    if self.is_quantized:
        raise NotImplementedError(
            "Slice decomposition for quantized tensors is not implemented"
        )
    else:
        return self.as_strided(sizes, strides, storage_offset)


@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)


def _cast_grad_to_input_dtype(
    grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
):
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input


@register_decomposition(aten._softmax_backward_data)
@compute_only_pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    new_grad_output = grad_output * output
    grad_input = new_grad_output - output * torch.sum(
        new_grad_output, dim=dim, keepdim=True
    )

    # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
    # if grad_output.device == torch.device("cpu"):
    #     return grad_input.contiguous()

    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()
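
# Note: with y = softmax(x), the vector-Jacobian product is
#     dL/dx_i = y_i * (g_i - sum_j g_j * y_j),
# which is what the expression above computes: new_grad_output is g * y, and
# the subtracted term is y times the per-slice sum of g * y along `dim`.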


@register_decomposition(aten._log_softmax_backward_data)
@compute_only_pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)


def _im2col_col2im_indices_along_dim(
    input_d, kernel_d, dilation_d, padding_d, stride_d, device
):
    """Utility function to implement im2col and col2im"""
    blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)

    arange_kw = partial(torch.arange, dtype=torch.int64, device=device)

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    return blocks_d_indices + kernel_grid


@register_decomposition(aten.im2col)
@out_wrapper()
@pw_cast_for_opmath
def im2col(
    input: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
    utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
    utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
    utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        utils.check(
            cond,
            lambda: f"{param_name} should be greater "
            f"{'than' if strict else 'than or equal to'} zero, but got {param}",
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")

    shape = input.shape
    ndim = len(shape)
    utils.check(
        ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
        lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    output_size = tuple(
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            shape[-2:], padding, dilation, kernel_size, stride
        )
    )
    utils.check(
        all(c > 0 for c in output_size),
        lambda: f"Given an input with spatial size {tuple(shape[-2:])}, "
        f"kernel_size={kernel_size}, dilation={dilation}, "
        f"padding={padding}, stride={stride}, "
        "the calculated shape of the array of sliding blocks "
        f"is {output_size}, but its components must be at least one.",
    )
    batched_input = ndim == 4
    if not batched_input:
        input = input.unsqueeze(0)

    batch_dim, channel_dim, input_h, input_w = input.shape

    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size

    blocks_row_indices = _im2col_col2im_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    blocks_col_indices = _im2col_col2im_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )

    # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
    # ugh
    padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))

    blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
    output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
    output = output.permute(0, 1, 2, 4, 3, 5)
    num_blocks_row = blocks_row_indices.size(1)
    num_blocks_col = blocks_col_indices.size(1)
    output = output.reshape(
        batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
    )

    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.col2im)
@out_wrapper()
@pw_cast_for_opmath
def col2im(
    input: Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
    utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
    utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
    utils.check(len(padding) == 2, lambda: "only 2D padding supported")
    utils.check(len(stride) == 2, lambda: "only 2D stride supported")

    def check_positive(param, param_name, strict=True):
        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
        utils.check(
            cond,
            lambda: f"{param_name} should be greater "
            f"{'than' if strict else 'than or equal to'} zero, but got {param}",
        )

    check_positive(kernel_size, "kernel_size")
    check_positive(dilation, "dilation")
    check_positive(padding, "padding", strict=False)
    check_positive(stride, "stride")
    check_positive(output_size, "output_size")

    shape = input.shape
    ndim = len(shape)
    utils.check(
        ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
        lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
        f"and non-zero dimensions, but got: {tuple(shape)}",
    )
    prod_kernel_size = kernel_size[0] * kernel_size[1]
    utils.check(
        shape[-2] % prod_kernel_size == 0,
        lambda: "Expected size of input's first non-batch dimension to be divisible by the "
        f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
        f"kernel_size={kernel_size}",
    )
    col = [
        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
        for out, pad, dil, ker, st in zip(
            output_size, padding, dilation, kernel_size, stride
        )
    ]
    L = col[0] * col[1]
    utils.check(
        shape[-1] == L,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    utils.check(
        L > 0,
        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
        f"dilation={dilation}, padding={padding}, stride={stride}, "
        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
    )
    batched_input = ndim == 3
    if not batched_input:
        input = input.unsqueeze(0)

    shape = input.shape

    out_h, out_w = output_size
    stride_h, stride_w = stride
    padding_h, padding_w = padding
    dilation_h, dilation_w = dilation
    kernel_h, kernel_w = kernel_size

    # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
    input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
    input = input.permute(0, 1, 2, 4, 3, 5)

    indices_row = _im2col_col2im_indices_along_dim(
        out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
    )
    indices_row = _unsqueeze_to_dim(indices_row, 4)
    indices_col = _im2col_col2im_indices_along_dim(
        out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
    )

    output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
    output = input.new_zeros(
        [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
    )
    idx = (None, None, indices_row, indices_col)
    output = aten.index_put(output, idx, input, accumulate=True)
    output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))

    if not batched_input:
        output = output.squeeze(0)
    return output


@register_decomposition(aten.native_dropout_backward)
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    # According to the CUDA kernel implementation we should have this test;
    # but it seems to fail tests!
    # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")

    # Mimicking CUDA kernel's behavior for output stride: output follows input's memory format
    # This is different from TensorIterator's behavior
    r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
        memory_format=utils.suggest_memory_format(grad_output)
    )
    return r


@register_decomposition(aten.unfold_backward)
def unfold_backward(
    grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
) -> Tensor:
    if len(input_size) == 0:
        return torch.squeeze_copy(grad, 0)
    dim = utils.canonicalize_dim(len(input_size), dimension)
    idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
    idx = idx.unfold(0, size, step).flatten()
    grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
    # nb. At the moment this generates two kernels in triton
    # It could potentially be fused into one call to scatter_reduce,
    # in the case step <= size provided scatter_reduce generates 1 kernel
    grad_input = grad.new_zeros(input_size)
    return torch.index_add(grad_input, dim, idx, grad)


@register_decomposition(aten.logit_backward.default)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    if eps is not None:
        lo = eps
        hi = 1.0 - lo
        return torch.where(
            torch.logical_and(self >= lo, self <= hi),
            grad_output / (self * (1.0 - self)),
            0.0,
        )
    else:
        return torch.where(
            torch.logical_and(self >= 0.0, self <= 1.0),
            grad_output / (self * (1.0 - self)),
            self.new_full((), float("nan")),
        )


@register_decomposition(aten.native_dropout)
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    if train:
        bool_mask = torch.rand_like(input) > p
        res = bool_mask * input * float(1.0 / (1.0 - p))
        return (res, bool_mask)
    else:
        return (input, torch.ones_like(input, dtype=torch.bool))


@register_decomposition(aten._softmax)
@out_wrapper()
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager softmax returns a contiguous tensor. Ensure that decomp also returns
    # a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        unnormalized = torch.exp(x)
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        unnormalized = torch.exp(x - x_max)
    result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
    if not half_to_float:
        result = result.to(result_dtype)
    return result
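
# Note: subtracting the per-slice maximum before exponentiating leaves the
# result unchanged (softmax is invariant to adding a constant along `dim`) but
# keeps exp() from overflowing for large inputs; the numel() == 0 branch avoids
# calling amax on an empty tensor.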


@register_decomposition(aten._log_softmax)
@out_wrapper()
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    # eager log_softmax returns a contiguous tensor. Ensure that decomp also
    # returns a contiguous tensor.
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        shifted = x - x_max
    shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result


@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.embedding)
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    assert weight.dim() == 2, "'weight' must be 2-D"
    # Nb. scale_grad_by_freq is not used in the forward
    if indices.ndim <= 1:
        # We need this one as weight[indices] calls item() in these cases
        out = weight.index_select(0, indices)
        if indices.ndim == 0:
            out = out.squeeze(0)
        return out
    else:
        return weight[indices]


@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    grad_output = grad_output.to(computation_dtype)
    indices = _maybe_convert_to_dtype(indices, torch.long)  # type: ignore[assignment]
    if scale_grad_by_freq:
        counts = indices.new_zeros((num_weights,))
        ones = torch.ones_like(indices)
        counts = counts.index_put([indices], ones, accumulate=True)
        grad_weights_scale = counts[indices]
        grad_output = grad_output / grad_weights_scale.unsqueeze(1)
    mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
    grad = grad_output.masked_fill(mask, 0)
    grad_weight = grad_output.new_zeros(
        (num_weights,) + grad_output.shape[indices.ndim :]
    )
    return grad_weight.index_put([indices], grad, accumulate=True).to(result_dtype)


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0
    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits


@register_decomposition([aten.split.Tensor, aten.unsafe_split.Tensor])
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    chunks = (dim_size + split_size - 1) // split_size
    chunks = guard_int(chunks)
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)
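
# Note: the last chunk absorbs the remainder, e.g. dim_size = 10 with
# split_size = 4 gives chunks = 3 and split_sizes = [4, 4, 2], since
# 4 - (4 * 3 - 10) = 2.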


# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@out_wrapper()
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return out

    # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
    # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
    # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
    # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
    # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
    # This implementation is not ideal, and we should revisit this when we have a better solution.
    return out + beta * self


@register_decomposition(aten.native_group_norm_backward)
@pw_cast_for_opmath
def native_group_norm_backward(
    grad_output: Tensor,
    input: Tensor,
    mean: Tensor,
    rstd: Tensor,
    gamma: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    utils.check_same_device(
        grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
    )
    utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
    utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
    utils.check(
        input.numel() == N * C * HxW,
        lambda: f"Expect input to have {N * C * HxW} elements",
    )
    utils.check(
        mean.shape == (N, group),
        lambda: f"Expect mean to have shape ({N}, {group}), but got {mean.shape}",
    )
    utils.check(
        gamma is None or gamma.numel() == C,
        lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
    )

    cpg, _rem = divmod(C, group)
    utils.check(
        _rem == 0,
        lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
    )

    # Compute Internal gradients
    ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
    db = grad_output.view(N, C, HxW).sum(dim=[2])

    d_input: Optional[Tensor] = None
    d_gamma: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        s = 1.0 / (HxW * cpg)
        if gamma is not None:
            ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                gamma.reshape(1, group, cpg),
            )
        else:
            ds_val = ds.reshape(N, group, cpg).sum(2)
            db_val = db.reshape(N, group, cpg).sum(2)
            c1 = torch.mul(
                rstd.unsqueeze(-1),
                torch.ones((1, group, cpg), device=rstd.device),
            )
        c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
        c3 = -c2 * mean - db_val * rstd * s

        c1 = c1.unsqueeze(-1)
        c2 = _unsqueeze_to_dim(c2, 4)
        c3 = _unsqueeze_to_dim(c3, 4)
        d_input = (
            torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
            + torch.mul(input.reshape(N, group, cpg, HxW), c2)
            + c3
        )
        d_input = d_input.reshape(input.shape).to(input.dtype)
    if output_mask[1]:
        d_gamma = (
            (
                (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
                * rstd.unsqueeze(-1)
            )
            .sum(dim=[0])
            .reshape(C)
        )
    if output_mask[2]:
        d_bias = db.sum(dim=[0])

    return (d_input, d_gamma, d_bias)
  1024. def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
  1025. if x is not None:
  1026. return x.to(dtype)
  1027. return x
# TODO: Take a closer look at the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast, input_cast, weight_cast, bias_cast = [
        x.to(computation_dtype).contiguous() if x is not None else x
        for x in (grad_out, input, weight, bias)
    ]
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
        )

    x_hat = (input_cast - mean) * rstd
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)

    inner = a - b - c3
    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        d_input = (rstd / N) * inner

    if output_mask[1] and weight_cast is not None:
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
        else:
            d_weight = grad_out_cast * x_hat

    if output_mask[2] and bias_cast is not None:
        if len(outer_dim_indices) > 0:
            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
        else:
            d_bias = grad_out_cast.clone()

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
        _maybe_cast(d_bias, input.dtype),
    )
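

# Illustrative sketch (not part of the upstream module): one way to sanity-check the
# decomposition above is to compare it against autograd through the eager op, in double
# precision so that any mismatch is not hidden by tolerance. The helper name below is
# our own and is purely for demonstration; it is never called at import time.
def _example_check_native_layer_norm_backward():
    import torch

    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w = torch.randn(8, dtype=torch.double, requires_grad=True)
    b = torch.randn(8, dtype=torch.double, requires_grad=True)
    out, mean, rstd = torch.ops.aten.native_layer_norm(x, [8], w, b, 1e-5)
    go = torch.randn_like(out)
    dx_ref, dw_ref, db_ref = torch.autograd.grad(out, (x, w, b), go)
    dx, dw, db = native_layer_norm_backward(
        go, x, [8], mean, rstd, w, b, [True, True, True]
    )
    torch.testing.assert_close(dx, dx_ref)
    torch.testing.assert_close(dw, dw_ref)
    torch.testing.assert_close(db, db_ref)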
def native_batch_norm_helper(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
    functional: bool,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    reduction_dims = [0] + list(range(2, input.dim()))
    computation_dtype = utils.get_computation_dtype(input.dtype)
    new_running_mean = running_mean
    new_running_var = running_var
    if training:
        computation_dtype = utils.get_computation_dtype(input.dtype)
        input_acc = input.to(dtype=computation_dtype)
        biased_var, mean = torch.var_mean(
            input_acc, dim=reduction_dims, correction=0, keepdim=True
        )
        rstd = torch.rsqrt(biased_var + eps)

        output = (input - mean) * rstd

        save_mean = torch.squeeze(mean, reduction_dims)
        save_rstd = torch.squeeze(rstd, reduction_dims)
        if running_mean is not None:
            new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
            if not functional:
                running_mean.copy_(new_running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            squeezed_var = torch.squeeze(biased_var, reduction_dims)
            unbiased_var = squeezed_var * (n / (n - 1))
            new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
            if not functional:
                running_var.copy_(new_running_var)
    else:
        assert running_mean is not None and running_var is not None
        running_mean = running_mean.to(dtype=computation_dtype, copy=True)
        new_running_mean = running_mean
        running_var = running_var.to(dtype=computation_dtype, copy=True)
        new_running_var = running_var
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type != "cpu":
            save_mean = running_mean
            save_rstd = invstd
        else:
            save_mean = input.new_zeros((0,))
            save_rstd = input.new_zeros((0,))
        mean = _unsqueeze_to_dim(mean, input.dim() - 1)
        invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
        output = (input - mean) * invstd

    if weight is None:
        weight = input.new_ones(())

    if bias is None:
        bias = input.new_zeros(())

    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
    output = output * weight + bias
    if input.device.type == "cpu":
        save_mean = save_mean.to(dtype=input.dtype)
        save_rstd = save_rstd.to(dtype=input.dtype)
    return (
        output.to(dtype=input.dtype),
        save_mean,
        save_rstd,
        new_running_mean,
        new_running_var,
    )


@register_decomposition(aten.native_batch_norm)
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
        input, weight, bias, running_mean, running_var, training, momentum, eps, False
    )
    return output, save_mean, save_rstd
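

# Illustrative sketch (not part of the upstream module): the decomposition above is
# expected to match the eager kernel up to floating-point tolerance; comparing in
# double keeps the comparison tight. The helper name is our own, for demonstration
# only, and is never called at import time.
def _example_check_native_batch_norm():
    import torch

    x = torch.randn(2, 3, 4, 5, dtype=torch.double)
    w, b = torch.randn(3, dtype=torch.double), torch.randn(3, dtype=torch.double)
    rm, rv = torch.zeros(3, dtype=torch.double), torch.ones(3, dtype=torch.double)
    ref = torch.ops.aten.native_batch_norm(
        x, w, b, rm.clone(), rv.clone(), True, 0.1, 1e-5
    )
    out = native_batch_norm(x, w, b, rm.clone(), rv.clone(), True, 0.1, 1e-5)
    # output, save_mean and save_rstd should all agree with eager
    for r, o in zip(ref, out):
        torch.testing.assert_close(o, r)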
# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
# with our new correctly schema'd _native_batch_norm_legit and its variants, but
# we cannot do that immediately in the C++ because it would be forwards incompatible
# with some mobile use cases.
#
# Since this change is most impactful for aot autograd/functionalization, we simply
# register this decomposition on the Autograd key for the python dispatcher (which is
# currently only used by aot autograd/functionalization and no one else, really).
# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
def native_batch_norm_decomposition(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    if running_mean is None and running_var is None:
        return aten._native_batch_norm_legit(
            input, weight, bias, training, momentum, eps
        )
    if running_mean is None:
        raise RuntimeError(
            "running_mean is None, but running_var is provided. "
            "They should both be None or both be provided."
        )
    if running_var is None:
        raise RuntimeError(
            "running_var is None, but running_mean is provided. "
            "They should both be None or both be provided."
        )
    return aten._native_batch_norm_legit(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    )


@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
    dim_size = tensor.size(dim)
    split_size = (dim_size + chunks - 1) // chunks

    if split_size == 0 and dim_size == 0:
        split_sizes = [split_size for _ in range(chunks)]
        split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
        return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
    return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
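

# Worked example (not from the original source): chunking a dimension of size 10 into
# 3 chunks gives split_size = ceil(10 / 3) = 4, and the trailing chunk absorbs the
# remainder, i.e. sizes [4, 4, 2]. The helper below just replays that arithmetic and
# is never called at import time.
def _example_chunk_split_sizes(dim_size: int, chunks: int) -> List[int]:
    split_size = (dim_size + chunks - 1) // chunks
    sizes = [split_size] * chunks
    # the last chunk is whatever is left over
    sizes[-1] = split_size - (split_size * chunks - dim_size)
    return sizes  # _example_chunk_split_sizes(10, 3) == [4, 4, 2]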
@register_decomposition(aten._native_batch_norm_legit.default)
def _native_batch_norm_legit(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Tensor,
    running_var: Tensor,
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
        input, weight, bias, running_mean, running_var, training, momentum, eps, False
    )
    return output, save_mean, save_rstd


@register_decomposition(aten._native_batch_norm_legit.no_stats)
def _native_batch_norm_legit_no_stats(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
        input, weight, bias, None, None, training, momentum, eps, False
    )
    return output, save_mean, save_rstd


@register_decomposition(aten._native_batch_norm_legit_functional.default)
def _native_batch_norm_legit_functional(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Tensor,
    running_var: Tensor,
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    (
        output,
        save_mean,
        save_rstd,
        new_running_mean,
        new_running_var,
    ) = native_batch_norm_helper(
        input, weight, bias, running_mean, running_var, training, momentum, eps, True
    )
    assert new_running_mean is not None, "new_running_mean should not be None"
    assert new_running_var is not None, "new_running_var should not be None"
    return output, save_mean, save_rstd, new_running_mean, new_running_var


@register_decomposition(aten._fused_dropout)
@pw_cast_for_opmath
def _fused_dropout_decomposition(input, p, generator=None):
    assert generator is None
    mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
    res = mask.type_as(input) * input * (1.0 / p)
    return (res, mask)
@register_decomposition(aten._to_copy)
def _to_copy(
    x: Tensor,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    non_blocking: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    assert not layout or layout == torch.strided, "TODO"
    assert not pin_memory, "TODO"
    if device is None and dtype is None and memory_format is None:
        return x.clone()
    dtype_converted = False
    if device is not None and device != x.device:
        # avoid conversions on cpu
        if dtype is not None and device.type == "cpu":
            x = torch._prims.convert_element_type(x, dtype)
            dtype_converted = True
        x = torch._prims.device_put(x, device)
    if dtype is not None and not dtype_converted:
        x = torch._prims.convert_element_type(x, dtype)
    if memory_format is not None:  # no ref/prim for memory format
        return torch.clone(x, memory_format=memory_format)
    return x


# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
def nop_decomposition(x):
    return aten.alias(x)
# Also register to the Autograd dispatch key, so this decomp can run above autograd.
# native_batch_norm needs to decompose into other ops before autograd.
@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    # cuDNN returns the running mean and variance when training is True
    if training:
        return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
        input.new_zeros((0,), dtype=torch.uint8),
    )


def _broadcast_batch_norm_backward(x, broadcast_mask):
    for axis, mask in enumerate(broadcast_mask):
        if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]):
            x = x.unsqueeze(axis)
    return x
  1362. @register_decomposition(aten.native_batch_norm_backward)
  1363. def native_batch_norm_backward(
  1364. grad_out: Tensor,
  1365. input: Tensor,
  1366. weight: Optional[Tensor],
  1367. running_mean: Optional[Tensor],
  1368. running_var: Optional[Tensor],
  1369. save_mean: Optional[Tensor],
  1370. save_invstd: Optional[Tensor],
  1371. train: bool,
  1372. eps: float,
  1373. output_mask: List[bool],
  1374. ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  1375. input_dtype = input.dtype
  1376. if weight is not None:
  1377. weight_dtype = weight.dtype
  1378. else:
  1379. weight_dtype = input_dtype
  1380. computation_dtype = utils.get_computation_dtype(input.dtype)
  1381. (
  1382. grad_out_cast,
  1383. input_cast,
  1384. weight_cast,
  1385. running_mean_cast,
  1386. running_var_cast,
  1387. save_mean_cast,
  1388. save_invstd_cast,
  1389. ) = [
  1390. x.to(computation_dtype) if x is not None else x
  1391. for x in (
  1392. grad_out,
  1393. input,
  1394. weight,
  1395. running_mean,
  1396. running_var,
  1397. save_mean,
  1398. save_invstd,
  1399. )
  1400. ]
  1401. input_shape = input.shape
  1402. input_rank = input.dim()
  1403. assert input_rank >= 2, "rank of the input must be at least 2"
  1404. axis = 1
  1405. num_features = prod(list(input_shape)) / input_shape[axis]
  1406. mean = save_mean_cast
  1407. invstd = save_invstd_cast
  1408. if train:
  1409. assert save_mean_cast is not None and save_invstd_cast is not None
  1410. else:
  1411. assert running_mean_cast is not None and running_var_cast is not None
  1412. mean = running_mean_cast
  1413. invstd = torch.rsqrt(running_var_cast + eps)
  1414. broadcast_mask: List[int] = [1] * input_rank
  1415. broadcast_mask[axis] = input_shape[axis]
  1416. reduction_axes: List[int] = []
  1417. for i in range(input_rank):
  1418. if i != axis:
  1419. reduction_axes.append(i)
  1420. mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type]
  1421. norm = 1.0 / num_features
  1422. grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type]
  1423. dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator]
  1424. grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
  1425. proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator]
  1426. if weight_cast is None:
  1427. grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type]
  1428. else:
  1429. grad_scale = _broadcast_batch_norm_backward(
  1430. invstd * weight_cast, broadcast_mask
  1431. )
  1432. if train:
  1433. proj = (input_cast - mean) * proj_scale # type: ignore[operator]
  1434. grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
  1435. else:
  1436. grad_input = grad_out_cast * grad_scale
  1437. if output_mask[1]:
  1438. grad_weight = dot_p * invstd
  1439. else:
  1440. grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp
  1441. if output_mask[2]:
  1442. grad_bias = grad_output_sum
  1443. else:
  1444. grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp
  1445. return (
  1446. grad_input.to(input_dtype),
  1447. _maybe_cast(grad_weight, weight_dtype),
  1448. _maybe_cast(grad_bias, weight_dtype),
  1449. )
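

# Illustrative sketch (not part of the upstream module): the batch-norm backward
# decomposition above can be cross-checked against autograd through the eager forward,
# again in double precision. The helper name is our own, for demonstration only, and
# is never called at import time.
def _example_check_native_batch_norm_backward():
    import torch

    x = torch.randn(2, 3, 5, dtype=torch.double, requires_grad=True)
    w = torch.randn(3, dtype=torch.double, requires_grad=True)
    b = torch.randn(3, dtype=torch.double, requires_grad=True)
    rm, rv = torch.zeros(3, dtype=torch.double), torch.ones(3, dtype=torch.double)
    out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
        x, w, b, rm, rv, True, 0.1, 1e-5
    )
    go = torch.randn_like(out)
    dx_ref, dw_ref, db_ref = torch.autograd.grad(out, (x, w, b), go)
    dx, dw, db = native_batch_norm_backward(
        go, x, w, rm, rv, save_mean, save_invstd, True, 1e-5, [True, True, True]
    )
    torch.testing.assert_close(dx, dx_ref)
    torch.testing.assert_close(dw, dw_ref)
    torch.testing.assert_close(db, db_ref)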
@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    return aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        True,
        epsilon,
        [True, True, True],
    )
@register_decomposition(aten._adaptive_avg_pool2d)
@pw_cast_for_opmath
def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
    # Preconditions
    device = input.device
    shape = input.shape
    ndim = len(shape)
    utils.check(
        ndim in (3, 4),
        lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
    )
    for d in input.shape[-2:]:
        utils.check(
            d != 0,
            lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
            f"non-batch dimensions, but input has shape {tuple(shape)}.",
        )

    # Optimisation (we should also do this in the kernel implementation)
    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
        kernel = tuple(
            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
        )
        return torch.nn.functional.avg_pool2d(input, kernel, stride)

    def start_index(a, b, c):
        return torch.div(a * c, b, rounding_mode="trunc")

    def end_index(a, b, c):
        return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")

    def compute_idx(in_size, out_size):
        orange = torch.arange(out_size, device=device, dtype=torch.int64)
        i0 = start_index(orange, out_size, in_size)
        # Let length = end_index - start_index, i.e. the length of the pooling kernels
        # length.max() can be computed analytically as follows:
        maxlength = in_size // out_size + 1
        in_size_mod = in_size % out_size
        # adaptive = True iff there are kernels with different lengths
        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
        if adaptive:
            maxlength += 1
        elif in_size_mod == 0:
            maxlength -= 1

        range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
        idx = i0.unsqueeze(-1) + range_max
        if adaptive:
            # Need to clamp to avoid accessing out-of-bounds memory
            # TODO make minimum accept scalars
            maxval = torch.scalar_tensor(
                in_size - 1, dtype=idx.dtype, device=idx.device
            )
            idx = torch.minimum(idx, maxval)

            # Compute the lengths
            i1 = end_index(orange, out_size, in_size)
            length = i1 - i0
        else:
            length = maxlength
        return idx, length, range_max, adaptive

    # length is not None if it's constant, otherwise we'll need to compute it
    idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
    idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])

    vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
    # Shortcut for the simpler case
    if not adaptive_h and not adaptive_w:
        return torch.mean(vals, dim=(-3, -1))

    def maybe_mask(vals, length, range_max, adaptive, dim):
        if isinstance(length, IntLike):
            return vals, length
        else:
            # zero-out the things we didn't really want to select
            assert dim < 0
            # hack
            mask = range_max >= length.unsqueeze(-1)
            if dim == -2:
                mask = _unsqueeze_to_dim(mask, 4)
            vals = torch.masked_fill(vals, mask, 0.0)
            # Compute the length of each window
            length = _unsqueeze_to_dim(length, -dim)
            return vals, length

    vals, length_h = maybe_mask(
        vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
    )
    vals, length_w = maybe_mask(
        vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
    )

    # We unroll the sum as we assume that the kernels are going to be small
    ret = None
    for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
        if ret is None:
            ret = vals[..., i, :, j]
        else:
            ret = ret + vals[..., i, :, j]
    return ret / (length_h * length_w)
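

# Illustrative sketch (not part of the upstream module): an output size of (5, 7) for a
# (7, 9) spatial input neither divides evenly nor yields equal-length windows, so it
# exercises the adaptive masking path above; the result is expected to agree with the
# eager kernel. The helper name is our own and is never called at import time.
def _example_check_adaptive_avg_pool2d():
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 2, 7, 9, dtype=torch.double)
    ref = F.adaptive_avg_pool2d(x, (5, 7))
    out = adaptive_avg_pool2d(x, (5, 7))
    torch.testing.assert_close(out, ref)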
@register_decomposition(aten.index_add_)
def index_add_(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)


@register_decomposition(aten.index_add)
@out_wrapper()
def index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)


def _index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    inplace: bool,
    alpha: NumberType = 1,
):
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if alpha != 1:
        python_type = utils.dtype_to_type(x.dtype)
        utils.check(
            python_type == bool
            or utils.is_weakly_lesser_type(type(alpha), python_type),
            lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
        )
        tensor = tensor * alpha
    # Treat scalars as elements of \R^1
    zero_dim = x.ndim == 0
    x1 = x.unsqueeze(0) if zero_dim else x
    idx = (None,) * dim + (index,)
    index_put = aten.index_put_ if inplace else aten.index_put
    out = index_put(x1, idx, tensor, accumulate=True)
    if inplace:
        return x
    else:
        return out.squeeze(0) if zero_dim else out.contiguous()
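

# Illustrative sketch (not part of the upstream module): index_add implemented via
# index_put with accumulate=True should agree with the eager op, including when the
# index contains repeats. The helper name is our own, for demonstration only, and is
# never called at import time.
def _example_check_index_add():
    import torch

    x = torch.zeros(5, 3, dtype=torch.double)
    index = torch.tensor([0, 4, 0])
    src = torch.randn(3, 3, dtype=torch.double)
    ref = torch.index_add(x, 0, index, src, alpha=2.0)
    out = index_add(x, 0, index, src, alpha=2.0)
    torch.testing.assert_close(out, ref)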
@register_decomposition(aten.index_copy_)
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    return _index_copy(x, dim, index, tensor, inplace=True)


@register_decomposition(aten.index_copy)
@out_wrapper()
def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    return _index_copy(x, dim, index, tensor, inplace=False)


def _index_copy(
    x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
):
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    # Treat scalars as elements of \R^1
    zero_dim = x.ndim == 0
    x1 = x.unsqueeze(0) if zero_dim else x
    idx = (None,) * dim + (index,)
    index_put = aten.index_put_ if inplace else aten.index_put
    out = index_put(x1, idx, tensor)
    if inplace:
        return x
    else:
        return out.squeeze(0) if zero_dim else out.contiguous()


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper("output", "buffer")
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer
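

# Illustrative sketch (not part of the upstream module): the forward uses the stable
# identity log(sigmoid(x)) = min(0, x) - log1p(exp(-|x|)), so it should agree with
# F.logsigmoid even for large-magnitude inputs. The helper name is our own, for
# demonstration only, and is never called at import time.
def _example_check_log_sigmoid_forward():
    import torch
    import torch.nn.functional as F

    x = torch.randn(16, dtype=torch.double) * 50.0
    out, _buffer = log_sigmoid_forward(x)
    torch.testing.assert_close(out, F.logsigmoid(x))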
@register_decomposition(aten.uniform)
def uniform(
    x: Tensor,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
):
    return prims._uniform_helper(
        x.shape,
        low=sym_float(low),
        high=sym_float(high),
        dtype=x.dtype,
        device=x.device,
    )


@register_decomposition(aten.uniform_)
def uniform_(self, low=0, high=1, generator=None):
    assert generator is None
    return self.copy_((high - low) * torch.rand_like(self) + low)


# aten/src/ATen/native/UpSample.cpp compute_output_size
def upsample_compute_output_size(input_size, output_size, scale_factors):
    spatial_dimensions = len(input_size) - 2
    if output_size is not None:
        utils.check(
            scale_factors is None,
            lambda: "Must specify exactly one of output_size and scale_factors",
        )
        utils.check(len(output_size) == spatial_dimensions, lambda: "")
        return output_size
    if scale_factors is not None:
        # NB: this isn't necessary lol
        utils.check(
            output_size is None,
            lambda: "Must specify exactly one of output_size and scale_factors",
        )
        utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
        output_size = []
        for i, s in enumerate(scale_factors):
            if int(s) == s:
                output_size.append(input_size[i + 2] * int(s))
            else:
                output_size.append(sym_int(input_size[i + 2] * s))
        return output_size
    utils.check(
        False, lambda: "Must specify exactly one of output_size and scale_factors"
    )
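

# Worked example (not from the original source): for a (N, C, 8, 10) input,
# scale_factors=(2, 3) produces [16, 30] via the integer-scale branch, while an
# explicit output_size is validated and returned unchanged; fractional scales go
# through the sym_int branch instead. Never called at import time.
def _example_upsample_output_size():
    assert upsample_compute_output_size((1, 3, 8, 10), None, (2, 3)) == [16, 30]
    assert upsample_compute_output_size((1, 3, 8, 10), (16, 30), None) == (16, 30)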
def get_scale_value(scales, idx):
    if scales is None:
        return None
    return scales[idx]


@register_decomposition(aten.upsample_nearest1d.vec)
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest1d_vec(input, output_size, scale_factors):
    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
    scale = get_scale_value(scale_factors, 0)

    return upsample_nearest1d(input, osize, scale)


@register_decomposition(aten.upsample_nearest2d.vec)
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest2d_vec(input, output_size, scale_factors):
    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
    scale_h = get_scale_value(scale_factors, 0)
    scale_w = get_scale_value(scale_factors, 1)

    return upsample_nearest2d(input, osize, scale_h, scale_w)


@register_decomposition(aten.upsample_nearest3d.vec)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
def upsample_nearest3d_vec(input, output_size, scale_factors):
    osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
    scale_d = get_scale_value(scale_factors, 0)
    scale_h = get_scale_value(scale_factors, 1)
    scale_w = get_scale_value(scale_factors, 2)

    return upsample_nearest3d(input, osize, scale_d, scale_h, scale_w)


def _compute_upsample_nearest_indices(input, output_size, scales):
    # For each dim in output_size, compute the set of input indices used
    # to produce the upsampled output.
    indices = []
    num_spatial_dims = len(output_size)
    for d in range(num_spatial_dims):
        # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
        # Indices are computed as follows:
        # scale = isize / osize
        # input_index = floor(output_index * scale)
        # Same as OpenCV INTER_NEAREST
        osize = output_size[d]
        output_indices = torch.arange(osize, dtype=input.dtype, device=input.device)
        isize = input.shape[-num_spatial_dims + d]
        scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
        input_indices = (output_indices * scale).to(torch.int64)
        for _ in range(num_spatial_dims - 1 - d):
            input_indices = input_indices.unsqueeze(-1)
        indices.append(input_indices)
    return tuple(indices)
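

# Worked example (not from the original source): upsampling a length-3 row to length 6
# uses scale = 3 / 6 = 0.5, so the source index for output position i is floor(i * 0.5),
# i.e. [0, 0, 1, 1, 2, 2]. The helper name is our own and is never called at import time.
def _example_upsample_nearest_indices():
    import torch

    inp = torch.arange(3.0).view(1, 1, 3)
    (idx,) = _compute_upsample_nearest_indices(inp, (6,), (None,))
    assert idx.tolist() == [0, 0, 1, 1, 2, 2]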
@register_decomposition(aten.upsample_nearest1d.default)
@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest1d(
    input: Tensor,
    output_size: List[int],
    scales: Optional[float] = None,
) -> Tensor:
    (l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,))
    result = input[:, :, l_indices]
    return result


@register_decomposition(aten.upsample_nearest2d.default)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    h_indices, w_indices = _compute_upsample_nearest_indices(
        input, output_size, (scales_h, scales_w)
    )
    result = input[:, :, h_indices, w_indices]

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    result = result.contiguous(memory_format=memory_format)
    return result


@register_decomposition(aten.upsample_nearest3d.default)
@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
@pw_cast_for_opmath
def upsample_nearest3d(
    input: Tensor,
    output_size: List[int],
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
        input, output_size, (scales_d, scales_h, scales_w)
    )
    result = input[:, :, d_indices, h_indices, w_indices]
    return result


def gather_params(params, has_biases, has_projections):
    if has_biases and has_projections:
        group_size = 5
    elif has_biases:
        group_size = 4
    elif has_projections:
        group_size = 3
    else:
        group_size = 2

    assert len(params) % group_size == 0, len(params)
    return [
        tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
    ]
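

# Worked example (not from the original source): an RNN with biases and no projections
# stores 4 flat tensors per layer/direction (w_ih, w_hh, b_ih, b_hh), so 12 flat params
# regroup into 3 tuples of 4. The helper name is our own and is never called at import
# time.
def _example_gather_params():
    import torch

    flat = [torch.empty(0) for _ in range(12)]
    groups = gather_params(flat, has_biases=True, has_projections=False)
    assert len(groups) == 3 and all(len(g) == 4 for g in groups)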
  1807. def params_hiddens(params, hiddens, i, bidirectional):
  1808. if bidirectional:
  1809. cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
  1810. bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
  1811. else:
  1812. cur_params, cur_hidden = params[i], hiddens[i]
  1813. bidir_params, bidir_hidden = None, None
  1814. return cur_params, cur_hidden, bidir_params, bidir_hidden
  1815. def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
  1816. assert last_batch_size > batch_size
  1817. hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
  1818. return cur_hidden.narrow(0, 0, batch_size)
  1819. def update_hidden_for_packed_reverse(
  1820. cur_hidden, last_batch_size, batch_size, inp_hidden
  1821. ):
  1822. if last_batch_size == batch_size:
  1823. return cur_hidden
  1824. assert last_batch_size < batch_size
  1825. return torch.concat(
  1826. (
  1827. cur_hidden,
  1828. inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
  1829. )
  1830. )
  1831. def one_layer_rnn_data(
  1832. inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
  1833. ):
  1834. ih_weight = params[0]
  1835. hh_weight = params[1]
  1836. ih_bias = params[2] if has_biases else None
  1837. hh_bias = params[3] if has_biases else None
  1838. step_output = []
  1839. hiddens: List["torch.Tensor"] = []
  1840. last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
  1841. cur_hidden = hidden.narrow(0, 0, last_batch_size)
  1842. split_inp = torch.split(inp, list(batch_sizes))
  1843. if reverse:
  1844. split_inp = split_inp[::-1]
  1845. for inp in split_inp:
  1846. i = inp.shape[0]
  1847. if last_batch_size == i:
  1848. pass # don't update cur_hidden
  1849. # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
  1850. elif reverse:
  1851. cur_hidden = update_hidden_for_packed_reverse(
  1852. cur_hidden, last_batch_size, i, hidden
  1853. )
  1854. else:
  1855. cur_hidden = update_hidden_for_packed(
  1856. cur_hidden, last_batch_size, i, hiddens
  1857. )
  1858. cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
  1859. last_batch_size = i
  1860. step_output.append(cur_hidden)
  1861. if reverse:
  1862. step_output.reverse()
  1863. else:
  1864. hiddens.append(cur_hidden)
  1865. hiddens.reverse()
  1866. out = torch.cat(step_output, 0)
  1867. hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
  1868. return out, hidden_out
  1869. def rnn_cell(nonlinearity):
  1870. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  1871. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  1872. return inner
  1873. def rnn_cell_data(nonlinearity):
  1874. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  1875. i = F.linear(i, ih_weight, ih_bias)
  1876. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  1877. return inner
  1878. def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
  1879. ih_weight = params[0]
  1880. hh_weight = params[1]
  1881. ih_bias = params[2] if has_biases else None
  1882. hh_bias = params[3] if has_biases else None
  1883. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  1884. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  1885. cur_hidden = hidden.unsqueeze(0)
  1886. step_output = []
  1887. for i in precomputed_input:
  1888. cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
  1889. step_output.append(cur_hidden)
  1890. if reverse:
  1891. step_output.reverse()
  1892. out = torch.cat(step_output, 0)
  1893. return out, cur_hidden.squeeze(0)
def _rnn_helper(
    input,
    hidden,
    params,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
    layer_fn,
):
    input = input.transpose(0, 1) if batch_first else input
    final_hiddens = []

    for i in range(num_layers):
        cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
            params, hidden, i, bidirectional
        )
        dropout = dropout if (train and i < num_layers - 1) else 0.0
        fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
        final_hiddens.append(fwd_hidden)

        if bidirectional:
            bwd_inp, bwd_hidden = layer_fn(
                input, bidir_hidden, bidir_params, has_biases, reverse=True
            )
            final_hiddens.append(bwd_hidden)

        if bidirectional:
            input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)
        else:
            input = fwd_inp

        if dropout != 0 and train and i < num_layers - 1:
            input = torch.dropout(input, dropout, train=True)

    input = input.transpose(0, 1) if batch_first else input
    return input, final_hiddens
  1928. @register_decomposition(aten.rnn_tanh.input)
  1929. @aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  1930. @aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
  1931. def rnn_tanh_input(
  1932. input,
  1933. hx,
  1934. params,
  1935. has_biases,
  1936. num_layers,
  1937. dropout,
  1938. train,
  1939. bidirectional,
  1940. batch_first,
  1941. ):
  1942. hidden = hx.unbind(0)
  1943. params = gather_params(params, has_biases, False)
  1944. out, final_hiddens = _rnn_helper(
  1945. input,
  1946. hidden,
  1947. params,
  1948. has_biases,
  1949. num_layers,
  1950. dropout,
  1951. train,
  1952. bidirectional,
  1953. batch_first,
  1954. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
  1955. )
  1956. return out, torch.stack(final_hiddens, 0)
  1957. @register_decomposition(aten.rnn_relu.input)
  1958. @aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  1959. @aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
  1960. def rnn_relu_input(
  1961. input,
  1962. hx,
  1963. params,
  1964. has_biases,
  1965. num_layers,
  1966. dropout,
  1967. train,
  1968. bidirectional,
  1969. batch_first,
  1970. ):
  1971. hidden = hx.unbind(0)
  1972. params = gather_params(params, has_biases, False)
  1973. out, final_hiddens = _rnn_helper(
  1974. input,
  1975. hidden,
  1976. params,
  1977. has_biases,
  1978. num_layers,
  1979. dropout,
  1980. train,
  1981. bidirectional,
  1982. batch_first,
  1983. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
  1984. )
  1985. return out, torch.stack(final_hiddens, 0)
  1986. @register_decomposition(aten.rnn_relu.data)
  1987. @aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  1988. @aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
  1989. def rnn_relu_data(
  1990. data,
  1991. batch_sizes,
  1992. hx,
  1993. params,
  1994. has_biases,
  1995. num_layers,
  1996. dropout,
  1997. train,
  1998. bidirectional,
  1999. ):
  2000. hidden = hx.unbind(0)
  2001. params = gather_params(params, has_biases, False)
  2002. out, final_hiddens = _rnn_helper(
  2003. data,
  2004. hidden,
  2005. params,
  2006. has_biases,
  2007. num_layers,
  2008. dropout,
  2009. train,
  2010. bidirectional,
  2011. False,
  2012. partial(
  2013. one_layer_rnn_data,
  2014. batch_sizes=batch_sizes,
  2015. hidden_fn=rnn_cell_data(torch.relu),
  2016. ),
  2017. )
  2018. return out, torch.stack(final_hiddens, 0)
  2019. @register_decomposition(aten.rnn_tanh.data)
  2020. @aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2021. @aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
  2022. def rnn_tanh_data(
  2023. data,
  2024. batch_sizes,
  2025. hx,
  2026. params,
  2027. has_biases,
  2028. num_layers,
  2029. dropout,
  2030. train,
  2031. bidirectional,
  2032. ):
  2033. hidden = hx.unbind(0)
  2034. params = gather_params(params, has_biases, False)
  2035. out, final_hiddens = _rnn_helper(
  2036. data,
  2037. hidden,
  2038. params,
  2039. has_biases,
  2040. num_layers,
  2041. dropout,
  2042. train,
  2043. bidirectional,
  2044. False,
  2045. partial(
  2046. one_layer_rnn_data,
  2047. batch_sizes=batch_sizes,
  2048. hidden_fn=rnn_cell_data(torch.tanh),
  2049. ),
  2050. )
  2051. return out, torch.stack(final_hiddens, 0)
  2052. def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
  2053. gates = F.linear(hx, hh_weight, hh_bias) + inp
  2054. chunked_gates = gates.chunk(4, chunk_dim)
  2055. in_gate = chunked_gates[0].sigmoid()
  2056. forget_gate = chunked_gates[1].sigmoid()
  2057. cell_gate = chunked_gates[2].tanh()
  2058. out_gate = chunked_gates[3].sigmoid()
  2059. cy = forget_gate * cx + (in_gate * cell_gate)
  2060. hy = out_gate * cy.tanh()
  2061. hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
  2062. return hy, cy
  2063. def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
  2064. ih_weight = params[0]
  2065. hh_weight = params[1]
  2066. ih_bias = params[2] if has_biases else None
  2067. hh_bias = params[3] if has_biases else None
  2068. hr_weight = (
  2069. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  2070. )
  2071. hx = hidden[0].unsqueeze(0)
  2072. cx = hidden[1].unsqueeze(0)
  2073. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  2074. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  2075. step_output = []
  2076. for inp in precomputed_input:
  2077. hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
  2078. step_output.append(hx)
  2079. if reverse:
  2080. step_output.reverse()
  2081. out = torch.cat(step_output, 0)
    return out, (hx.squeeze(0), cx.squeeze(0))
  2083. def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
  2084. ih_weight = params[0]
  2085. hh_weight = params[1]
  2086. ih_bias = params[2] if has_biases else None
  2087. hh_bias = params[3] if has_biases else None
  2088. hr_weight = (
  2089. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  2090. )
  2091. step_output = []
  2092. hiddens = []
  2093. last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
  2094. split_inp = torch.split(inp, list(batch_sizes))
  2095. if reverse:
  2096. split_inp = split_inp[::-1]
  2097. orig_hx = hidden[0]
  2098. orig_cx = hidden[1]
  2099. hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow(
  2100. 0, 0, last_batch_size
  2101. )
  2102. for inp in split_inp:
  2103. i = inp.shape[0]
  2104. inp = F.linear(inp, ih_weight, ih_bias)
  2105. # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
  2106. if i < last_batch_size:
  2107. hiddens.append(
  2108. (
  2109. hx.narrow(0, i, last_batch_size - i),
  2110. cx.narrow(0, i, last_batch_size - i),
  2111. )
  2112. )
  2113. hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
  2114. # this will only happen when reverse=True
  2115. if i > last_batch_size:
  2116. hx = torch.concat(
  2117. (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
  2118. )
  2119. cx = torch.concat(
  2120. (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
  2121. )
  2122. hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
  2123. last_batch_size = i
  2124. step_output.append(hx)
  2125. if reverse:
  2126. step_output.reverse()
  2127. hidden_out = (hx, cx)
  2128. else:
  2129. hiddens.append((hx, cx))
  2130. hiddens.reverse()
  2131. hidden0, hidden1 = zip(*hiddens)
  2132. hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
  2133. out = torch.cat(step_output, 0)
  2134. return out, hidden_out
  2135. @register_decomposition(aten.lstm.input)
  2136. @aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  2137. @aten.lstm.input.py_impl(DispatchKey.Autograd)
  2138. def lstm_impl(
  2139. input,
  2140. hx,
  2141. params,
  2142. has_biases,
  2143. num_layers,
  2144. dropout,
  2145. train,
  2146. bidirectional,
  2147. batch_first,
  2148. ):
  2149. assert len(hx) == 2, "lstm expects two hidden states"
  2150. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  2151. hidden = list(zip(hx[0], hx[1]))
  2152. out, final_hiddens = _rnn_helper(
  2153. input,
  2154. hidden,
  2155. params,
  2156. has_biases,
  2157. num_layers,
  2158. dropout,
  2159. train,
  2160. bidirectional,
  2161. batch_first,
  2162. one_layer_lstm,
  2163. )
  2164. final_hiddens = list(zip(*final_hiddens))
  2165. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  2166. @register_decomposition(aten.lstm.data)
  2167. @aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2168. @aten.lstm.data.py_impl(DispatchKey.Autograd)
  2169. def lstm_data_impl(
  2170. data,
  2171. batch_sizes,
  2172. hx,
  2173. params,
  2174. has_biases,
  2175. num_layers,
  2176. dropout,
  2177. train,
  2178. bidirectional,
  2179. ):
  2180. assert len(hx) == 2, "lstm expects two hidden states"
  2181. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  2182. hidden = list(zip(hx[0], hx[1]))
  2183. out, final_hiddens = _rnn_helper(
  2184. data,
  2185. hidden,
  2186. params,
  2187. has_biases,
  2188. num_layers,
  2189. dropout,
  2190. train,
  2191. bidirectional,
  2192. False,
  2193. partial(one_layer_lstm_data, batch_sizes=batch_sizes),
  2194. )
  2195. final_hiddens = list(zip(*final_hiddens))
  2196. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  2197. def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2198. chunked_igates = inp.chunk(3, 1)
  2199. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
  2200. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  2201. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  2202. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  2203. return (cur_hidden - new_gate) * input_gate + new_gate
  2204. def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2205. chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
  2206. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
  2207. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  2208. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  2209. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  2210. return (cur_hidden - new_gate) * input_gate + new_gate
  2211. @register_decomposition(aten.gru.data)
  2212. @aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  2213. @aten.gru.data.py_impl(DispatchKey.Autograd)
  2214. def gru_impl_data(
  2215. data,
  2216. batch_sizes,
  2217. hx,
  2218. params,
  2219. has_biases,
  2220. num_layers,
  2221. dropout,
  2222. train,
  2223. bidirectional,
  2224. ):
  2225. params = gather_params(params, has_biases, False)
  2226. out, final_hiddens = _rnn_helper(
  2227. data,
  2228. hx.unbind(0),
  2229. params,
  2230. has_biases,
  2231. num_layers,
  2232. dropout,
  2233. train,
  2234. bidirectional,
  2235. False,
  2236. partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
  2237. )
  2238. return out, torch.stack(final_hiddens, 0)
  2239. @register_decomposition(aten.gru.input)
  2240. @aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  2241. @aten.gru.input.py_impl(DispatchKey.Autograd)
  2242. def gru_impl(
  2243. input,
  2244. hx,
  2245. params,
  2246. has_biases,
  2247. num_layers,
  2248. dropout,
  2249. train,
  2250. bidirectional,
  2251. batch_first,
  2252. ):
  2253. params = gather_params(params, has_biases, False)
  2254. out, final_hiddens = _rnn_helper(
  2255. input,
  2256. hx.unbind(0),
  2257. params,
  2258. has_biases,
  2259. num_layers,
  2260. dropout,
  2261. train,
  2262. bidirectional,
  2263. batch_first,
  2264. partial(one_layer_rnn, hidden_fn=gru_cell),
  2265. )
  2266. return out, torch.stack(final_hiddens, 0)
  2267. @register_decomposition(aten.upsample_bilinear2d.vec)
  2268. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2269. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
  2270. def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors):
  2271. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  2272. scale_h = get_scale_value(scale_factors, 0)
  2273. scale_w = get_scale_value(scale_factors, 1)
  2274. return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w)
  2275. @register_decomposition(aten.upsample_bilinear2d.default)
  2276. @aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
  2277. @pw_cast_for_opmath
  2278. def upsample_bilinear2d(
  2279. input: Tensor,
  2280. output_size: List[int],
  2281. align_corners: bool,
  2282. scales_h: Optional[float] = None,
  2283. scales_w: Optional[float] = None,
  2284. ) -> Tensor:
  2285. # get dimensions of original image
  2286. n_batch, n_channels, in_h, in_w = input.shape
  2287. out_h = output_size[0]
  2288. out_w = output_size[1]
  2289. # Calculate horizontal and vertical scaling factor
  2290. # TODO: Figure out if scales_h/scales_w matters here
  2291. if out_h > 1:
  2292. if align_corners:
  2293. h_scale_factor = (in_h - 1) / (out_h - 1)
  2294. else:
  2295. h_scale_factor = (
  2296. in_h / (in_h * scales_h) if scales_h is not None else in_h / out_h
  2297. )
  2298. else:
  2299. h_scale_factor = 0.0
  2300. if out_w > 1:
  2301. if align_corners:
  2302. w_scale_factor = (in_w - 1) / (out_w - 1)
  2303. else:
  2304. w_scale_factor = (
  2305. in_w / (in_w * scales_w) if scales_w is not None else in_w / out_w
  2306. )
  2307. else:
  2308. w_scale_factor = 0.0
  2309. i = torch.arange(out_h, dtype=input.dtype, device=input.device)
  2310. j = torch.arange(out_w, dtype=input.dtype, device=input.device)
  2311. if align_corners:
  2312. x = h_scale_factor * i
  2313. y = w_scale_factor * j
  2314. else:
  2315. x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0)
  2316. y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0)
  2317. x_floor = x.to(torch.int64)
  2318. x_ceil = torch.ceil(x).clamp(max=in_h - 1).to(torch.int64)
  2319. y_floor = y.to(torch.int64)
  2320. y_ceil = torch.ceil(y).clamp(max=in_w - 1).to(torch.int64)
  2321. x_view = x.unsqueeze(1)
  2322. x_floor_view = x_floor.unsqueeze(1)
  2323. x_ceil_view = x_ceil.unsqueeze(1)
  2324. v1 = input[:, :, x_floor_view, y_floor]
  2325. v2 = input[:, :, x_ceil_view, y_floor]
  2326. v3 = input[:, :, x_floor_view, y_ceil]
  2327. v4 = input[:, :, x_ceil_view, y_ceil]
  2328. xscale2 = x_view - x_floor_view
  2329. xscale1 = 1.0 - xscale2
  2330. yscale2 = y - y_floor
  2331. yscale1 = 1.0 - yscale2
  2332. q1 = torch.mul(v1, xscale1) + torch.mul(v2, xscale2)
  2333. q2 = torch.mul(v3, xscale1) + torch.mul(v4, xscale2)
  2334. result = torch.mul(q1, yscale1) + torch.mul(q2, yscale2)
  2335. # convert output to correct memory format, if necessary
  2336. memory_format = utils.suggest_memory_format(input)
  2337. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  2338. if input.device.type == "cuda" and n_channels < 16:
  2339. memory_format = torch.contiguous_format
  2340. result = result.contiguous(memory_format=memory_format)
  2341. return result
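

# Illustrative sketch (not part of the upstream module): the bilinear decomposition
# above is expected to agree with the eager kernel; small floating-point differences
# are possible, so modest tolerances are used. The helper name is our own, for
# demonstration only, and is never called at import time.
def _example_check_upsample_bilinear2d():
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 5, 7, dtype=torch.double)
    ref = F.interpolate(x, size=(10, 14), mode="bilinear", align_corners=False)
    out = upsample_bilinear2d(x, [10, 14], False)
    torch.testing.assert_close(out, ref, rtol=1e-6, atol=1e-6)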
  2342. # We should be applying decompositions after all transformations
  2343. @register_decomposition(aten.is_same_size.default)
  2344. def is_same_size(a: Tensor, b: Tensor) -> bool:
  2345. return a.shape == b.shape
  2346. @register_decomposition([aten._reshape_alias, aten._unsafe_view])
  2347. def _reshape_alias(x, shape, *args):
  2348. return aten.view(x, shape)
  2349. @register_decomposition(aten.nll_loss_forward)
  2350. def nll_loss_forward(
  2351. self: Tensor,
  2352. target: Tensor,
  2353. weight: Optional[Tensor],
  2354. reduction: int,
  2355. ignore_index: int,
  2356. ) -> Tuple[Tensor, Tensor]:
  2357. assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
  2358. assert (
  2359. target.dim() <= 1
  2360. ), "0D or 1D target tensor expected, multi-target not supported"
  2361. no_batch_dim = self.dim() == 1 and target.dim() == 0
  2362. assert no_batch_dim or (
  2363. self.shape[0] == target.shape[0]
  2364. ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
  2365. n_classes = self.shape[-1]
  2366. assert weight is None or (
  2367. weight.dim() == 1 and weight.numel() == n_classes
  2368. ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}" # noqa: B950
  2369. # self can be [N, C] or [C]
  2370. # target can be [N] or []
  2371. n_dims = self.dim()
  2372. channel_dim = 1
  2373. if n_dims < 2:
  2374. channel_dim = 0
  2375. if weight is not None:
  2376. w = weight.unsqueeze(0) if n_dims > 1 else weight
  2377. self = self * w
  2378. safe_target = torch.where(target != ignore_index, target, 0)
  2379. safe_target_ = safe_target.unsqueeze(channel_dim)
  2380. # target can be [N, 1] or [1]
  2381. result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
  2382. result = torch.where(target != ignore_index, result, 0)
  2383. if reduction == Reduction.NONE.value and n_dims > 1:
  2384. total_weight = self.new_full((), 0.0)
  2385. return result, total_weight
  2386. if weight is not None:
  2387. w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
  2388. wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
  2389. wsum = torch.where(target != ignore_index, wsum, 0)
  2390. total_weight = wsum.sum()
  2391. else:
  2392. total_weight = (target != ignore_index).sum().to(self)
  2393. if reduction == Reduction.SUM.value:
  2394. result = result.sum()
  2395. elif reduction == Reduction.MEAN.value:
  2396. result = result.sum() / total_weight
  2397. return result, total_weight
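

# Illustrative sketch (not part of the upstream module): with class weights and mean
# reduction, the gather-based formulation above should match F.nll_loss, and the
# returned total_weight is the sum of the selected class weights. The helper name is
# our own, for demonstration only, and is never called at import time.
def _example_check_nll_loss_forward():
    import torch
    import torch.nn.functional as F

    log_probs = F.log_softmax(torch.randn(4, 5, dtype=torch.double), dim=1)
    target = torch.tensor([0, 3, 4, 1])
    weight = torch.rand(5, dtype=torch.double)
    result, total_weight = nll_loss_forward(
        log_probs, target, weight, Reduction.MEAN.value, -100
    )
    torch.testing.assert_close(result, F.nll_loss(log_probs, target, weight))
    torch.testing.assert_close(total_weight, weight[target].sum())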
# These are adapted from aten/src/ATen/native/UpSample.h, which is based on
# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
    return ((A + 2) * x - (A + 3)) * x * x + 1


def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
    return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A


def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
    A = -0.75
    return (
        _upsample_cubic_convolution2(t + 1.0, A),
        _upsample_cubic_convolution1(t, A),
        _upsample_cubic_convolution1(1.0 - t, A),
        _upsample_cubic_convolution2(2.0 - t, A),
    )


def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
    coeffs2 = _upsample_get_cubic_coefficients(ts)
    return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))


# Need this instead of just sum() to keep mypy happy
def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
    return reduce(torch.add, ts)
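

# Worked example (not from the original source): the Keys cubic kernel with A = -0.75
# reproduces constants, so for any t in [0, 1] the four tap weights above sum to 1.
# The helper name is our own and is never called at import time.
def _example_cubic_coefficients_sum_to_one():
    import torch

    t = torch.rand(8, dtype=torch.double)
    coeffs = _upsample_get_cubic_coefficients(t)
    torch.testing.assert_close(_sum_tensors(coeffs), torch.ones_like(t))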


@register_decomposition(aten.grid_sampler_2d)
@pw_cast_for_opmath
def grid_sampler_2d(
    a: Tensor,
    grid: Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> Tensor:
    utils.check(
        interpolation_mode in (0, 1, 2),
        lambda: f"Invalid interpolation mode {interpolation_mode}",
    )
    utils.check(
        padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
    )

    def unnormalize(coords: Tensor, size: int) -> Tensor:
        # Rescale coordinates from [-1, 1] to:
        #   [0, size - 1] if align_corners is True
        #   [-.5, size - .5] if align_corners is False
        mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
        ofs = size * 0.5 - 0.5
        return coords * mul + ofs

    # Reflects coordinates until they fall between low and high (inclusive).
    # The bounds are passed as twice their value so that half-integer values
    # can be represented as ints.
    def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
        if twice_low == twice_high:
            return torch.zeros_like(coords)
        coords_min = twice_low / 2
        coords_span = (twice_high - twice_low) / 2
        coords2 = (coords - coords_min).abs()
        extra = torch.fmod(coords2, coords_span)
        flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
        return torch.where(
            flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
        )

    def compute_coordinates(coords: Tensor, size: int) -> Tensor:
        if padding_mode == 0:  # Zero
            return coords
        elif padding_mode == 1:  # Borders
            return torch.clamp(coords, 0, size - 1)
        else:  # padding_mode == 2, Reflection
            if align_corners:
                coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
            else:
                coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
            return torch.clamp(coords_reflected, 0, size - 1)

    def compute_source_index(coords: Tensor, size: int) -> Tensor:
        coords_un = unnormalize(coords, size)
        return compute_coordinates(coords_un, size)

    N, C, iH, iW = a.shape
    _, oH, oW, _ = grid.shape

    def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
        return torch.logical_and(
            0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
        )

    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)

    def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
        cond = in_bounds_cond(xs, ys)
        # To clip to inside valid coordinates, we map the coordinates
        # to (x, y) = (0, 0) and also set the weight to 0
        # We also change the shape of the tensor to the appropriate one for
        # broadcasting with N_idx, C_idx for the purposes of advanced indexing
        return tuple(
            torch.where(cond, t, 0).view(N, 1, oH, oW)
            for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
        )

    def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
        # Perform clipping, index into input tensor and multiply by weight
        idx_x, idx_y, w_ = clip(ix, iy, w)
        return a[N_idx, C_idx, idx_y, idx_x] * w_

    x = grid[..., 0]
    y = grid[..., 1]

    if interpolation_mode == 0:  # Bilinear
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)

        ix_nw, iy_nw = ix.floor(), iy.floor()
        ix_ne, iy_ne = ix_nw + 1, iy_nw
        ix_sw, iy_sw = ix_nw, iy_nw + 1
        ix_se, iy_se = ix_ne, iy_sw

        w_nw = (ix_se - ix) * (iy_se - iy)
        w_ne = (ix - ix_sw) * (iy_sw - iy)
        w_sw = (ix_ne - ix) * (iy - iy_ne)
        w_se = (ix - ix_nw) * (iy - iy_nw)

        return _sum_tensors(
            get_summand(ix, iy, w)
            for (ix, iy, w) in (
                (ix_nw, iy_nw, w_nw),
                (ix_ne, iy_ne, w_ne),
                (ix_sw, iy_sw, w_sw),
                (ix_se, iy_se, w_se),
            )
        )
    elif interpolation_mode == 1:  # Nearest
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)

        ix_nearest = ix.round()
        iy_nearest = iy.round()

        return get_summand(ix_nearest, iy_nearest, 1)
    else:  # interpolation_mode == 2, Bicubic
        ix = unnormalize(x, iW)
        iy = unnormalize(y, iH)

        ix_nw = ix.floor()
        iy_nw = iy.floor()

        tx = ix - ix_nw
        ty = iy - iy_nw

        def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
            x = compute_coordinates(ix, iW)
            y = compute_coordinates(iy, iH)
            return get_summand(x, y, 1)

        def get_coeff(ofs: int) -> Tensor:
            iy_ofs = iy_nw + (ofs - 1)
            cs = (
                get_value_bounded(ix_nw - 1, iy_ofs),
                get_value_bounded(ix_nw, iy_ofs),
                get_value_bounded(ix_nw + 1, iy_ofs),
                get_value_bounded(ix_nw + 2, iy_ofs),
            )
            return _upsample_cubic_interp1d(cs, tx.unsqueeze(1))

        coeffs = tuple(get_coeff(ofs) for ofs in range(4))
        return _upsample_cubic_interp1d(coeffs, ty.unsqueeze(1))
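
# Example (sketch): on a small input this decomposition should agree with the
# eager kernel reached through torch.nn.functional.grid_sample (bilinear
# interpolation, zero padding, and align_corners=False are the matching
# defaults), up to floating-point tolerance.
#
#   x = torch.randn(1, 3, 8, 8)
#   grid = torch.rand(1, 5, 5, 2) * 2 - 1
#   torch.testing.assert_close(
#       grid_sampler_2d(x, grid),
#       torch.nn.functional.grid_sample(x, grid, align_corners=False),
#   )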


@register_decomposition(aten.mv)
@out_wrapper()
@pw_cast_for_opmath
def mv(self, vec):
    utils.check(
        self.dim() == 2 and vec.dim() == 1,
        lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
    )
    utils.check(
        self.size(1) == vec.size(0),
        lambda: f"size mismatch, got {self.size(0)}x{self.size(1)},{vec.size(0)}",
    )
    return (self * vec).sum(dim=1)
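
# Example (sketch): the broadcast-multiply-and-reduce above matches a plain
# matrix-vector product.
#
#   A = torch.randn(4, 3)
#   v = torch.randn(3)
#   torch.testing.assert_close(mv(A, v), A @ v)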


@register_decomposition(aten.dot)
@out_wrapper()
@pw_cast_for_opmath
def dot(self, other):
    if self.is_complex():
        if self.is_conj():
            if other.is_conj():
                return torch.dot(self.conj(), other.conj()).conj()
            else:
                return torch.vdot(self.conj(), other)
        elif other.is_conj():
            return torch.vdot(other.conj(), self)

    utils.check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )
    utils.check(
        self.dtype == other.dtype,
        lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
    )

    def numel_error():
        return (
            f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
            f" same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
        )

    utils.check(self.numel() == other.numel(), numel_error)
    return (self * other).sum()
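
# Example (sketch): for real 1D inputs the decomposition reduces to an
# elementwise product followed by a sum, matching torch.dot.
#
#   a = torch.randn(5)
#   b = torch.randn(5)
#   torch.testing.assert_close(dot(a, b), torch.dot(a, b))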


@register_decomposition(aten.binary_cross_entropy_with_logits)
def binary_cross_entropy_with_logits(
    self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
):
    max_val = (-self).clamp_min(0)
    if pos_weight is not None:
        log_weight = (pos_weight - 1) * target + 1
        loss = (1 - target) * self + log_weight * (
            ((-max_val).exp() + (-self - max_val).exp()).log() + max_val
        )
    else:
        loss = (
            (1 - target) * self
            + max_val
            + ((-max_val).exp() + (-self - max_val).exp()).log()
        )

    if weight is not None:
        loss = loss * weight

    return apply_loss_reduction(loss, reduction)
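
# Example (sketch): max_val = max(-x, 0) gives the usual numerically stable
# evaluation of log(1 + exp(-x)); the result should match the eager functional
# op with the default "mean" reduction.
#
#   logits = torch.randn(4, 3)
#   target = torch.rand(4, 3)
#   torch.testing.assert_close(
#       binary_cross_entropy_with_logits(logits, target),
#       torch.nn.functional.binary_cross_entropy_with_logits(logits, target),
#   )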


def should_fold(tensor1: torch.Tensor, dim_tensor2: int) -> bool:
    dim_tensor1 = tensor1.ndim
    if dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
        t1_sizes_ptr = tensor1.shape
        t1_strides = tensor1.stride()
        if (
            dim_tensor1 == 3
            and dim_tensor2 == 2
            and t1_strides[-1] != 1
            and t1_strides[0] == t1_sizes_ptr[1] * t1_sizes_ptr[2]
        ):
            # First dim is slowest moving, and then the following two dims are
            # transposed. This can happen, for example, via permute(0, 2, 1).
            # The first two dims could be folded to use mm, but that would require a
            # permutation with actual data movement, which can instead be handled by
            # bmm with each GEMM transposed.
            # This can be generalized to a tensor with dims X + Y + Z where X, Y, and Z
            # dims are contiguous, Y dims and Z dims are transposed, and X, Y, Z > 0.
            # For example, this can happen via permute(0, 1, 5, 2, 3, 4), where X = 2,
            # Y = 3, and Z = 1.
            return False
        else:
            return True
    else:
        return False
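
# Example (sketch): a 3D operand whose last two dims are transposed in memory
# is declined (bmm with transposed GEMMs is preferable), while a contiguous
# batch folds its batch dimension into a single mm.
#
#   t = torch.randn(2, 4, 3).permute(0, 2, 1)  # shape (2, 3, 4), strides (12, 1, 3)
#   should_fold(t, 2)                          # False
#   should_fold(torch.randn(2, 3, 4), 2)       # True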


@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def matmul(tensor1, tensor2):
    dim_tensor1 = tensor1.dim()
    dim_tensor2 = tensor2.dim()
    assert dim_tensor1 != 0 and dim_tensor2 != 0
    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return torch.dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return torch.mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return torch.mm(tensor1, tensor2)
    elif should_fold(tensor1, dim_tensor2) or should_fold(tensor2, dim_tensor1):
        # NB: Much of this was written with Copilot! (although still had to fix a bunch of issues)
        # dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
        # dim_tensor2 >= 3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
        # and some condition on the strides is fulfilled
        # optimization: use mm instead of bmm by folding the batch of the larger tensor
        # into its leading matrix dimension
        transpose = dim_tensor2 > dim_tensor1
        t1 = tensor2.mT if transpose else tensor1
        t2 = (
            tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
        )
        # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
        # and t1 and t2 are matmul-compatible

        # Why not t1.view(-1, sizes_1[-1])?
        # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
        # This can happen in e.g. [3, 5, 0] @ [0, 0].
        sizes_1 = t1.shape
        output_shape = list(sizes_1[:-1])
        folded_dim1 = reduce(operator.mul, output_shape)

        # Readjust output_shape if we are multiplying by a matrix
        t2_is_matrix = t2.dim() == 2
        if t2_is_matrix:
            output_shape.append(t2.shape[1])

        t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
        if t2_is_matrix:
            # FIXME This path always does an unnecessary copy when transpose == True as the returned
            # result from BLAS is already C-transposed
            output = t1_folded.mm(t2).view(output_shape)
            return output.mT.contiguous() if transpose else output
        else:
            return t1_folded.mv(t2).view(output_shape)
    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
        # we track m1 vs m2 separately even though they must match for nicer error messages
        n = tensor1.size(-2) if dim_tensor1 > 1 else 1
        m1 = tensor1.size(-1)
        batch_tensor1 = tensor1.shape[:-2]
        m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
        p = tensor2.size(-1) if dim_tensor2 > 1 else 1

        batch_tensor2: List[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2.size(i))

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = list(
            torch.broadcast_shapes(batch_tensor1, batch_tensor2)
        )

        tensor1_expand_size = expand_batch_portion + [n, m1]
        tensor2_expand_size = expand_batch_portion + [m2, p]

        expand_batch_product = prod(expand_batch_portion)

        # HACK: We need reshape with symint support
        tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
            expand_batch_product, n, m1
        )
        tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
            expand_batch_product, m2, p
        )

        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)
        if dim_tensor2 > 1:
            output_shape.append(p)

        return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
    else:
        utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
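
# Example (sketch): the general path broadcasts the batch dimensions and uses
# a single bmm; the result should match the eager operator.
#
#   a = torch.randn(2, 1, 4, 5)
#   b = torch.randn(3, 5, 6)
#   out = matmul(a, b)          # batch dims broadcast to (2, 3)
#   assert out.shape == (2, 3, 4, 6)
#   torch.testing.assert_close(out, a @ b)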


@register_decomposition(aten.upsample_bicubic2d.default)
@pw_cast_for_opmath
def upsample_bicubic2d_default(
    a: Tensor,
    output_size: Tuple[int, int],
    align_corners: bool,
    scale_h: Optional[float] = None,
    scale_w: Optional[float] = None,
) -> Tensor:
    N, C, iH, iW = a.shape
    oH, oW = output_size

    def compute_scale(in_size, out_size, align_corners, scale=None):
        if align_corners:
            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
        else:
            return 1 / scale if scale is not None and scale > 0 else in_size / out_size

    def compute_source_index(scale, dst_index, align_corners):
        if align_corners:
            return scale * dst_index
        else:
            return scale * (dst_index + 0.5) - 0.5

    height_scale = compute_scale(iH, oH, align_corners, scale_h)
    width_scale = compute_scale(iW, oW, align_corners, scale_w)

    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
    out_y = torch.arange(oH, device=a.device).view((1, 1, oH, 1))
    out_x = torch.arange(oW, device=a.device).view((1, 1, 1, oW))

    real_x = compute_source_index(width_scale, out_x, align_corners)
    in_x = real_x.floor()
    t_x = real_x - in_x
    ix = in_x.to(dtype=torch.int64)

    real_y = compute_source_index(height_scale, out_y, align_corners)
    in_y = real_y.floor()
    t_y = real_y - in_y
    iy = in_y.to(dtype=torch.int64)

    iys_ofs = (iy - 1, iy, iy + 1, iy + 2)
    ixs_ofs = (ix - 1, ix, ix + 1, ix + 2)

    def load_bounded(ys, xs):
        y_idx = torch.clamp(ys, 0, iH - 1)
        x_idx = torch.clamp(xs, 0, iW - 1)
        return a[N_idx, C_idx, y_idx, x_idx]

    def get_x_interp(y):
        coeffs_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs)
        return _upsample_cubic_interp1d(coeffs_x, t_x)

    coeffs_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs)
    result = _upsample_cubic_interp1d(coeffs_y, t_y)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(a)
    result = result.contiguous(memory_format=memory_format)
    return result
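
# Example (sketch): the decomposition should agree with the eager bicubic
# kernel reached through torch.nn.functional.interpolate, up to floating-point
# tolerance.
#
#   x = torch.randn(1, 3, 4, 4)
#   torch.testing.assert_close(
#       upsample_bicubic2d_default(x, (8, 8), align_corners=False),
#       torch.nn.functional.interpolate(
#           x, size=(8, 8), mode="bicubic", align_corners=False
#       ),
#   )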


@register_decomposition(aten.upsample_bicubic2d.vec)
@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_vec(
    a: Tensor,
    output_size: Optional[Tuple[int, int]],
    align_corners: bool,
    scale_factors: Optional[Tuple[float, float]] = None,
) -> Tensor:
    utils.check(
        bool(output_size) + bool(scale_factors) == 1,
        lambda: "Must specify exactly one of output_size and scale_factors.",
    )
    if output_size is None:
        assert scale_factors is not None
        output_size = cast(
            Tuple[int, int],
            tuple(
                sym_int(sym_float(w) * scale)
                for w, scale in zip(a.shape[2:], scale_factors)
            ),
        )
    scale_h, scale_w = scale_factors if scale_factors else (None, None)
    return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
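
# Example (sketch): the .vec overload accepts scale_factors instead of an
# explicit output_size and forwards to the decomposition above.
#
#   x = torch.randn(1, 3, 4, 4)
#   y = upsample_bicubic2d_vec(x, None, align_corners=False, scale_factors=(2.0, 2.0))
#   assert y.shape == (1, 3, 8, 8)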


def register_inplace(aten_op, outplace_op):
    @register_decomposition(aten_op)
    def inplace_op(*args, **kwargs):
        out = outplace_op(*args, **kwargs)
        return args[0].copy_(out)

    return inplace_op
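

# The registrations below route each in-place variant through its out-of-place
# decomposition and copy the result back into the first argument; conceptually,
# aten.relu_(x) behaves like x.copy_(aten.relu(x)).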
register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
register_inplace(aten.baddbmm_, aten.baddbmm)
register_inplace(aten.cumprod_, aten.cumprod)
register_inplace(aten.fill_, aten.fill)
register_inplace(aten.gelu_, aten.gelu)
register_inplace(aten.hardswish_, aten.hardswish)
register_inplace(aten.hardtanh_, aten.hardtanh)
register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
register_inplace(aten.index_put_, aten.index_put)
register_inplace(aten.index_reduce_, aten.index_reduce)
register_inplace(aten.leaky_relu_, aten.leaky_relu)
register_inplace(aten.logit_, aten.logit)
register_inplace(aten.relu_, aten.relu)
register_inplace(aten.renorm_, aten.renorm)
register_inplace(aten.round_, aten.round)
register_inplace(aten.scatter_, aten.scatter)
register_inplace(aten.scatter_add_, aten.scatter_add)
register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
register_inplace(aten.silu_, aten.silu)