- import functools
- import operator
- import sys
- from enum import Enum
- from functools import partial, reduce
- from itertools import product
- from typing import Callable, cast, Iterable, List, Optional, Tuple, Union
- import torch
- import torch._prims as prims
- import torch._prims_common as utils
- import torch.nn.functional as F
- from torch import sym_float, sym_int, Tensor
- from torch._decomp import register_decomposition
- from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType
- from torch._prims_common.wrappers import (
- _maybe_convert_to_dtype,
- _maybe_resize_out,
- _safe_copy_out,
- out_wrapper,
- )
- from torch.fx.experimental.symbolic_shapes import guard_int
- from torch.utils._pytree import tree_flatten, tree_map
- DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
- # None of these functions are publicly accessible; get at them
- # from torch._decomp
- __all__: List[str] = []
- aten = torch._ops.ops.aten
- class Reduction(Enum):
- NONE = 0
- MEAN = 1
- SUM = 2
- # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided.
- # We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops.
- # We will need to validate the non-elementwise uses; an illustrative usage sketch follows the partial definitions below.
- def type_casts(
- f: Callable,
- type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
- compute_dtype_only: bool = False,
- ):
- @functools.wraps(f)
- def inner(*args, **kwargs):
- flat_args = [
- x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)
- ]
- computation_dtype, result_dtype = utils.elementwise_dtypes(
- *flat_args, type_promotion_kind=type_promotion
- )
- # TODO: pretty sure this is not quite right
- def increase_prec(x):
- if isinstance(x, Tensor):
- return x.to(computation_dtype)
- else:
- return x
- def decrease_prec(x):
- if isinstance(x, Tensor):
- return x.to(result_dtype)
- else:
- return x
- r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
- if compute_dtype_only:
- return r
- else:
- return tree_map(decrease_prec, r)
- return inner
- compute_only_pw_cast_for_opmath = partial(
- type_casts,
- type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
- compute_dtype_only=True,
- )
- pw_cast_for_opmath = partial(
- type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- pw_cast_for_int_to_real = partial(
- type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
- )
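- # Illustrative sketch (not part of the registry): how a decomposition might use `pw_cast_for_opmath`.
- # `_example_scaled_square` is a hypothetical helper shown only to clarify the wrapper's behavior:
- # the body runs in the computation dtype and the result is cast back to the promoted result dtype.
- #
- # @pw_cast_for_opmath
- # def _example_scaled_square(x: Tensor, scale: float) -> Tensor:
- #     # a float16 input arrives here already upcast to float32 (opmath)
- #     return scale * x * x
- #
- # _example_scaled_square(torch.randn(4, dtype=torch.half), 2.0).dtype  # expected: torch.float16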
- # This expands x until x.dim() == dim. Might be useful as an operator
- def _unsqueeze_to_dim(x: Tensor, dim: int):
- for _ in range(dim - x.dim()):
- x = x.unsqueeze(-1)
- return x
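- # For example (illustrative only): _unsqueeze_to_dim(torch.randn(2, 3), 4) has shape (2, 3, 1, 1);
- # trailing singleton dimensions are appended until the requested rank is reached.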
- @register_decomposition(aten.tanh_backward)
- @pw_cast_for_opmath
- def tanh_backward(out_grad: Tensor, y: Tensor):
- return out_grad * (1 - y * y).conj_physical()
- @register_decomposition(aten.sigmoid_backward)
- @pw_cast_for_opmath
- def sigmoid_backward(out_grad: Tensor, y: Tensor):
- return out_grad * (y * (1 - y)).conj_physical()
- @register_decomposition(aten.softplus_backward)
- @pw_cast_for_opmath
- def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
- z = (x * beta).exp()
- return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
- @register_decomposition(aten.elu_backward)
- @pw_cast_for_opmath
- def elu_backward(
- grad_output: Tensor,
- alpha: float,
- scale: float,
- input_scale: float,
- is_result: bool,
- self_or_result: Tensor,
- ):
- negcoef = alpha * scale
- poscoef = scale
- negiptcoef = input_scale
- if is_result:
- return torch.where(
- self_or_result <= 0,
- grad_output * negiptcoef * (self_or_result + negcoef),
- self_or_result * poscoef,
- )
- else:
- return torch.where(
- self_or_result <= 0,
- grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
- grad_output * poscoef,
- )
- @register_decomposition([aten.fill.Scalar])
- def fill_scalar(self, value):
- return torch.full_like(self, value)
- @register_decomposition([aten.fill.Tensor])
- def fill_tensor(self, value: Tensor):
- utils.check(
- value.dim() == 0,
- lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
- )
- return torch.full_like(self, value.item())
- @register_decomposition(aten.hardsigmoid)
- @pw_cast_for_opmath
- def hardsigmoid(self: Tensor) -> Tensor:
- return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
- @register_decomposition(aten.hardsigmoid_backward)
- @pw_cast_for_opmath
- def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
- return torch.where(
- (self > -3.0) & (self < 3.0),
- grad_output * (1.0 / 6.0),
- 0.0,
- )
- @register_decomposition(aten.hardtanh_backward)
- def hardtanh_backward(
- grad_output: Tensor, self: Tensor, min_val: float, max_val: float
- ):
- return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)
- @register_decomposition(aten.hardshrink_backward)
- def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
- return torch.where((self >= -lambd) & (self <= lambd), 0.0, grad_out)
- @register_decomposition(aten.hardswish)
- @pw_cast_for_opmath
- def hardswish(self: Tensor) -> Tensor:
- return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
- @register_decomposition(aten.hardswish_backward)
- @pw_cast_for_opmath
- def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
- return torch.where(
- self < -3,
- 0.0,
- torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
- )
- @register_decomposition(aten.threshold_backward)
- def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
- return torch.where(self <= threshold, 0.0, grad_output)
- @register_decomposition(aten.leaky_relu_backward)
- @pw_cast_for_opmath
- def leaky_relu_backward(
- grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
- ):
- return torch.where(self > 0, grad_output, grad_output * negative_slope)
- @register_decomposition(aten.gelu_backward)
- @pw_cast_for_opmath
- def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
- M_SQRT2 = 1.41421356237309504880
- M_SQRT1_2 = 0.70710678118654752440
- M_2_SQRTPI = 1.12837916709551257390
- if approximate == "tanh":
- kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
- kKappa = 0.044715
- x_sq = self * self
- x_cube = x_sq * self
- inner = kBeta * (self + kKappa * x_cube)
- tanh_inner = torch.tanh(inner)
- left = 0.5 * self
- right = 1 + tanh_inner
- left_derivative = 0.5 * right
- tanh_derivative = 1 - tanh_inner * tanh_inner
- inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
- right_derivative = left * tanh_derivative * inner_derivative
- return grad * (left_derivative + right_derivative)
- else:
- kAlpha = M_SQRT1_2
- kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
- cdf = 0.5 * (1 + torch.erf(self * kAlpha))
- pdf = kBeta * torch.exp(self * self * -0.5)
- return grad * (cdf + self * pdf)
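- # The tanh branch above differentiates 0.5 * x * (1 + tanh(inner)) with inner = kBeta * (x + kKappa * x^3):
- #   d/dx = 0.5 * (1 + tanh(inner)) + 0.5 * x * (1 - tanh(inner)^2) * kBeta * (1 + 3 * kKappa * x^2),
- # i.e. left_derivative + right_derivative. The erf branch uses d/dx [x * cdf(x)] = cdf(x) + x * pdf(x).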
- @register_decomposition(aten.mish_backward)
- @pw_cast_for_opmath
- def mish_backward(grad_output: Tensor, input: Tensor):
- input_tanh_softplus = torch.tanh(F.softplus(input))
- input_sigmoid = torch.sigmoid(input)
- out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
- return grad_output * (input_tanh_softplus + out)
- @register_decomposition(aten.silu)
- @pw_cast_for_opmath
- def silu(self: Tensor) -> Tensor:
- return self * torch.sigmoid(self)
- @register_decomposition(aten.silu_backward)
- @pw_cast_for_opmath
- def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
- sigmoid = 1 / (1 + torch.exp(-self))
- return grad_output * sigmoid * (1 + self * (1 - sigmoid))
- @register_decomposition(aten.softshrink_backward)
- def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
- return torch.where((self >= -lambd) & (self <= lambd), 0.0, grad_output)
- @register_decomposition(aten._prelu_kernel)
- def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
- return torch.where(self > 0, self, weight * self)
- @register_decomposition(aten._prelu_kernel_backward)
- def _prelu_kernel_backward(
- grad_output: Tensor,
- self: Tensor,
- weight: Tensor,
- ) -> Tuple[Tensor, Tensor]:
- input_grad = torch.where(self > 0, grad_output, weight * grad_output)
- weight_grad = torch.where(self > 0, 0.0, self * grad_output)
- return (input_grad, weight_grad)
- @register_decomposition(aten.rrelu_with_noise_backward)
- @pw_cast_for_opmath
- def rrelu_with_noise_backward(
- grad_output: Tensor,
- self: Tensor,
- noise: Tensor,
- lower: float,
- upper: float,
- training: bool,
- self_is_result: bool,
- ) -> Tensor:
- if training and upper - lower > 1e-6:
- return grad_output.mul(noise)
- else:
- negative_slope = (lower + upper) / 2
- return aten.leaky_relu_backward(
- grad_output, self, negative_slope, self_is_result
- )
- @register_decomposition(aten.log_sigmoid_backward)
- @pw_cast_for_opmath
- def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
- in_negative = self < 0
- max_deriv = torch.where(in_negative, 1, 0)
- sign = torch.where(in_negative, 1, -1)
- z = torch.exp(-torch.abs(self))
- return grad_output * (max_deriv - sign * (z / (1 + z)))
- # CPU has a special formula that uses buffer, but it is disabled here for convenience's sake
- # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
- def apply_loss_reduction(loss: Tensor, reduction: int):
- if reduction == Reduction.MEAN.value:
- return torch.mean(loss)
- elif reduction == Reduction.SUM.value:
- return torch.sum(loss)
- else:
- return loss
- def to_real_dtype(dtype: torch.dtype):
- if dtype == torch.complex32:
- return torch.float16
- elif dtype == torch.complex64:
- return torch.float32
- elif dtype == torch.complex128:
- return torch.float64
- # TODO: None of these loss castings are quite correct, see
- # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
- # perform the pointwise portion in opmath, but don't maintain it between the
- # pointwise portion and the reduction
- @register_decomposition(aten.mse_loss)
- @pw_cast_for_opmath
- def mse_loss(
- self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
- ) -> Tensor:
- loss = (self - target) ** 2
- return apply_loss_reduction(loss, reduction)
- @register_decomposition(aten.mse_loss_backward)
- @pw_cast_for_opmath
- def mse_loss_backward(
- grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
- ):
- norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
- return norm * (input - target) * grad_output
- @register_decomposition(aten.huber_loss_backward.default)
- @pw_cast_for_opmath
- def huber_loss_backward(
- grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
- ):
- norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
- x = self - target
- return torch.where(
- x < -delta,
- -norm * grad_output * delta,
- torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
- )
- # We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
- @register_decomposition(aten.huber_loss_backward.out)
- @pw_cast_for_opmath
- def huber_loss_backward_out(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- reduction: int,
- delta: float,
- grad_input: Tensor,
- ):
- result = huber_loss_backward(grad_output, self, target, reduction, delta)
- _maybe_resize_out(grad_input, result.shape)
- return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)
- def _nll_loss_backward(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor],
- reduction: int,
- ignore_index: int,
- total_weight: Tensor,
- ) -> Tensor:
- channel_dim = 0 if self.dim() < 2 else 1
- if reduction == Reduction.MEAN.value:
- grad_output = grad_output / total_weight
- target = target.unsqueeze(channel_dim)
- safe_target = torch.where(target != ignore_index, target, 0)
- grad_input = torch.zeros_like(self)
- grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
- if grad_input.dim() > grad_output.dim() > 0:
- grad_output = grad_output.unsqueeze(channel_dim)
- if weight is not None:
- new_shape = [1 for _ in range(self.dim())]
- new_shape[channel_dim] = weight.shape[0]
- weight = weight.reshape(new_shape)
- grad_output = grad_output * weight
- grad_output = torch.where(target != ignore_index, grad_output, 0)
- return grad_input * grad_output
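- # Net effect (illustrative summary of the code above): for each sample n with target[n] != ignore_index,
- # grad_input[n, target[n]] = -grad_output[n], scaled by weight[target[n]] if weight is given and by
- # 1 / total_weight for mean reduction; every other entry is zero.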
- @register_decomposition(aten.glu_backward)
- @pw_cast_for_opmath
- def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
- assert self.dim() > 0, "glu does not support 0-dimensional tensors"
- wrap_dim = utils.canonicalize_dim(self.dim(), dim)
- nIn = self.size(wrap_dim)
- assert (
- nIn % 2 == 0
- ), f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
- inputSize = nIn // 2
- firstHalf = self.narrow(wrap_dim, 0, inputSize)
- secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
- gradInputFirstHalf = torch.sigmoid(secondHalf)
- gradInputSecondHalf = (
- (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
- )
- gradInputFirstHalf = gradInputFirstHalf * grad_output
- return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)
- @register_decomposition(aten.nll_loss_backward)
- def nll_loss_backward(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor],
- reduction: int,
- ignore_index: int,
- total_weight: Tensor,
- ) -> Tensor:
- assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
- assert (
- target.dim() <= 1
- ), "0D or 1D target tensor expected, multi-target not supported"
- no_batch_dim = self.dim() == 1 and target.dim() == 0
- assert no_batch_dim or (
- self.shape[0] == target.shape[0]
- ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
- assert total_weight.numel() == 1, (
- "expected total_weight to be a single element tensor, got: ",
- f"{total_weight.shape} ({total_weight.numel()} elements)",
- )
- assert (
- weight is None or weight.numel() == self.shape[-1]
- ), "weight tensor should be defined either for all or no classes"
- if reduction == Reduction.NONE.value and self.dim() == 2:
- assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
- f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
- f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
- )
- else:
- assert (
- grad_output.dim() <= 1 and grad_output.numel() == 1
- ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
- return _nll_loss_backward(
- grad_output, self, target, weight, reduction, ignore_index, total_weight
- )
- @register_decomposition(aten.nll_loss2d_backward)
- def nll_loss2d_backward(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor],
- reduction: int,
- ignore_index: int,
- total_weight: Tensor,
- ) -> Tensor:
- assert (
- self.dim() == 4
- ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
- assert (
- target.dim() == 3
- ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
- assert (
- self.shape[0] == target.shape[0]
- and self.shape[2] == target.shape[1]
- and self.shape[3] == target.shape[2]
- ), f"size mismatch (got input: {self.shape}, target: {target.shape}"
- assert total_weight.numel() == 1, (
- "expected total_weight to be a single element tensor, "
- f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
- )
- return _nll_loss_backward(
- grad_output, self, target, weight, reduction, ignore_index, total_weight
- )
- @register_decomposition(aten.binary_cross_entropy)
- @pw_cast_for_opmath
- def binary_cross_entropy(
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor] = None,
- reduction: int = Reduction.MEAN.value,
- ) -> Tensor:
- # We cannot currently model this without introducing data-dependent control flow
- # TORCH_CHECK(
- # (input_val >= 0) && (input_val <= 1),
- # "all elements of input should be between 0 and 1"
- # )
- loss = (target - 1) * torch.maximum(
- torch.log1p(-self), self.new_full((), -100)
- ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
- if weight is not None:
- loss = loss * weight
- return apply_loss_reduction(loss, reduction)
- @register_decomposition(aten.binary_cross_entropy_backward)
- @pw_cast_for_opmath
- def binary_cross_entropy_backward(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor] = None,
- reduction: int = Reduction.MEAN.value,
- ) -> Tensor:
- EPSILON = 1e-12
- result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
- if weight is not None:
- result = result * weight
- if reduction == Reduction.MEAN.value:
- result = result / self.numel()
- return result
- @register_decomposition(aten.soft_margin_loss)
- @out_wrapper()
- @pw_cast_for_opmath
- def soft_margin_loss(
- input: Tensor,
- target: Tensor,
- reduction: int = Reduction.MEAN.value,
- ) -> Tensor:
- loss = torch.log1p(torch.exp(-input * target))
- return apply_loss_reduction(loss, reduction)
- @register_decomposition(aten.soft_margin_loss_backward)
- @pw_cast_for_opmath
- def soft_margin_loss_backward(
- grad_output: Tensor,
- self: Tensor,
- target: Tensor,
- reduction: int = Reduction.MEAN.value,
- ) -> Tensor:
- grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
- if reduction == Reduction.MEAN.value:
- grad_input = grad_input / self.numel()
- return grad_input
- @register_decomposition(aten._euclidean_dist)
- def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
- x1_norm = x1.pow(2).sum(-1, True)
- x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
- x2_norm = x2.pow(2).sum(-1, True)
- x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
- x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
- x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
- result = x1_.matmul(x2_.mT)
- return result.clamp_min(0).sqrt()
- @register_decomposition(aten.slice_backward)
- def slice_backward(
- grad_output: Tensor,
- input_sizes: List[int],
- dim: int,
- start: int,
- end: int,
- step: int,
- ):
- grad_input = grad_output.new_zeros(input_sizes)
- return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
- @register_decomposition(aten.slice.Tensor)
- def slice_forward(
- # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
- self: Tensor,
- dim: int = 0,
- start: Optional[int] = None,
- end: Optional[int] = None,
- step: int = 1,
- ):
- ndim = self.dim()
- if ndim == 0:
- raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
- dim = utils.canonicalize_dim(self.dim(), dim)
- sizes = list(self.size())
- strides = list(self.stride())
- if step <= 0:
- raise RuntimeError("slice step must be positive")
- start_val = start if start is not None else 0
- end_val = end if end is not None else sys.maxsize # 2^63 – 1
- if start_val < 0:
- start_val += sizes[dim]
- if end_val < 0:
- end_val += sizes[dim]
- if start_val < 0:
- start_val = 0
- elif start_val >= sizes[dim]:
- start_val = sizes[dim]
- if end_val < start_val:
- end_val = start_val
- elif end_val >= sizes[dim]:
- end_val = sizes[dim]
- storage_offset = self.storage_offset() + start_val * strides[dim]
- len = end_val - start_val
- sizes[dim] = (len + step - 1) // step
- strides[dim] *= step
- if self.is_quantized:
- raise NotImplementedError(
- "Slice decomposition for quantized tensors aren't implemented"
- )
- else:
- return self.as_strided(sizes, strides, storage_offset)
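- # Worked example (illustrative only): slicing a dimension of size 10 with start=-3, end=None, step=2
- # normalizes start_val to 7, clamps end_val to 10, and yields sizes[dim] = (3 + 2 - 1) // 2 = 2,
- # matching x[..., 7:10:2]; the stride along that dimension is multiplied by the step.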
- @register_decomposition(aten.select_backward)
- def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
- grad_input = grad_output.new_zeros(input_sizes)
- return torch.select_scatter(grad_input, grad_output, dim, index)
- @register_decomposition(aten.diagonal_backward)
- def diagonal_backward(
- grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
- ):
- grad_input = grad_output.new_zeros(input_sizes)
- return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
- def _cast_grad_to_input_dtype(
- grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
- ):
- if grad_output.dtype != input_dtype:
- grad_input = grad_input.to(input_dtype)
- return grad_input
- @register_decomposition(aten._softmax_backward_data)
- @compute_only_pw_cast_for_opmath
- def _softmax_backward_data(
- grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
- ):
- new_grad_output = grad_output * output
- grad_input = new_grad_output - output * torch.sum(
- new_grad_output, dim=dim, keepdim=True
- )
- # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
- # if grad_output.device == torch.device("cpu"):
- # return grad_input.contiguous()
- return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()
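- # The expression above is the softmax vector-Jacobian product written in terms of the output y:
- # grad_input_i = y_i * (grad_output_i - sum_j grad_output_j * y_j), factored here as
- # grad_output * output - output * sum(grad_output * output, dim).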
- @register_decomposition(aten._log_softmax_backward_data)
- @compute_only_pw_cast_for_opmath
- def _log_softmax_backward_data(
- grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
- ):
- grad_input = grad_output - torch.exp(output) * torch.sum(
- grad_output, dim=dim, keepdim=True
- )
- return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)
- def _im2col_col2im_indices_along_dim(
- input_d, kernel_d, dilation_d, padding_d, stride_d, device
- ):
- """Utility function to implement im2col and col2im"""
- blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
- arange_kw = partial(torch.arange, dtype=torch.int64, device=device)
- # Stride kernel over input and find starting indices along dim d
- blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)
- # Apply dilation on kernel and find its indices along dim d
- kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)
- # Broadcast and add kernel starting positions (indices) with
- # kernel_grid along dim d, to get block indices along dim d
- return blocks_d_indices + kernel_grid
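- # Example (illustrative only): input_d=4, kernel_d=2, dilation_d=1, padding_d=0, stride_d=1 gives
- # blocks_d_indices = [[0, 1, 2]] and kernel_grid = [[0], [1]], so the result is
- # [[0, 1, 2], [1, 2, 3]]: row k holds the input positions touched by kernel offset k in each block.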
- @register_decomposition(aten.im2col)
- @out_wrapper()
- @pw_cast_for_opmath
- def im2col(
- input: Tensor,
- kernel_size: List[int],
- dilation: List[int],
- padding: List[int],
- stride: List[int],
- ) -> Tensor:
- utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
- utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
- utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
- utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
- def check_positive(param, param_name, strict=True):
- cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
- utils.check(
- cond, lambda: f"{param_name} should be greater than zero, but got {param}"
- )
- check_positive(kernel_size, "kernel_size")
- check_positive(dilation, "dilation")
- check_positive(dilation, "padding", strict=False)
- check_positive(stride, "stride")
- shape = input.shape
- ndim = len(shape)
- utils.check(
- ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
- lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
- f"and non-zero dimensions, but got: {tuple(shape)}",
- )
- output_size = tuple(
- 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
- for out, pad, dil, ker, st in zip(
- shape[-2:], padding, dilation, kernel_size, stride
- )
- )
- utils.check(
- all(c > 0 for c in output_size),
- lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
- f"kernel_size={kernel_size}, dilation={dilation}, "
- f"padding={padding}, stride={stride}, "
- "the calculated shape of the array of sliding blocks "
- f"is {output_size}, but its components must be at least one.",
- )
- batched_input = ndim == 4
- if not batched_input:
- input = input.unsqueeze(0)
- batch_dim, channel_dim, input_h, input_w = input.shape
- stride_h, stride_w = stride
- padding_h, padding_w = padding
- dilation_h, dilation_w = dilation
- kernel_h, kernel_w = kernel_size
- blocks_row_indices = _im2col_col2im_indices_along_dim(
- input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
- )
- blocks_col_indices = _im2col_col2im_indices_along_dim(
- input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
- )
- # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
- # ugh
- padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
- blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
- output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
- output = output.permute(0, 1, 2, 4, 3, 5)
- num_blocks_row = blocks_row_indices.size(1)
- num_blocks_col = blocks_col_indices.size(1)
- output = output.reshape(
- batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
- )
- if not batched_input:
- output = output.squeeze(0)
- return output
- @register_decomposition(aten.col2im)
- @out_wrapper()
- @pw_cast_for_opmath
- def col2im(
- input: Tensor,
- output_size: List[int],
- kernel_size: List[int],
- dilation: List[int],
- padding: List[int],
- stride: List[int],
- ) -> Tensor:
- utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
- utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
- utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
- utils.check(len(padding) == 2, lambda: "only 2D padding supported")
- utils.check(len(stride) == 2, lambda: "only 2D stride supported")
- def check_positive(param, param_name, strict=True):
- cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
- utils.check(
- cond, lambda: f"{param_name} should be greater than zero, but got {param}"
- )
- check_positive(kernel_size, "kernel_size")
- check_positive(dilation, "dilation")
- check_positive(padding, "padding", strict=False)
- check_positive(stride, "stride")
- check_positive(output_size, "output_size")
- shape = input.shape
- ndim = len(shape)
- utils.check(
- ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
- lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
- f"and non-zero dimensions, but got: {tuple(shape)}",
- )
- prod_kernel_size = kernel_size[0] * kernel_size[1]
- utils.check(
- shape[-2] % prod_kernel_size == 0,
- lambda: "Expected size of input's first non-batch dimension to be divisible by the "
- f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
- f"kernel_size={kernel_size}",
- )
- col = [
- 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
- for out, pad, dil, ker, st in zip(
- output_size, padding, dilation, kernel_size, stride
- )
- ]
- L = col[0] * col[1]
- utils.check(
- shape[-1] == L,
- lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
- f"dilation={dilation}, padding={padding}, stride={stride}, "
- f"expected input.size(-1) to be {L} but got {shape[-1]}.",
- )
- utils.check(
- L > 0,
- lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
- f"dilation={dilation}, padding={padding}, stride={stride}, "
- f"expected input.size(-1) to be {L} but got {shape[-1]}.",
- )
- batched_input = ndim == 3
- if not batched_input:
- input = input.unsqueeze(0)
- shape = input.shape
- out_h, out_w = output_size
- stride_h, stride_w = stride
- padding_h, padding_w = padding
- dilation_h, dilation_w = dilation
- kernel_h, kernel_w = kernel_size
- # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
- input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
- input = input.permute(0, 1, 2, 4, 3, 5)
- indices_row = _im2col_col2im_indices_along_dim(
- out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
- )
- indices_row = _unsqueeze_to_dim(indices_row, 4)
- indices_col = _im2col_col2im_indices_along_dim(
- out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
- )
- output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
- output = input.new_zeros(
- [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
- )
- idx = (None, None, indices_row, indices_col)
- output = aten.index_put(output, idx, input, accumulate=True)
- output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))
- if not batched_input:
- output = output.squeeze(0)
- return output
- @register_decomposition(aten.native_dropout_backward)
- def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
- # According to the CUDA kernel implementation we should have this test;
- # but it seems to fail tests!
- # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
- # Mimicking the CUDA kernel's behavior for the output stride: the output follows the input's memory format.
- # This is different from TensorIterator's behavior.
- r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
- memory_format=utils.suggest_memory_format(grad_output)
- )
- return r
- @register_decomposition(aten.unfold_backward)
- def unfold_backward(
- grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
- ) -> Tensor:
- if len(input_size) == 0:
- return torch.squeeze_copy(grad, 0)
- dim = utils.canonicalize_dim(len(input_size), dimension)
- idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
- idx = idx.unfold(0, size, step).flatten()
- grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
- # nb. At the moment this generates two kernels in triton
- # It could potentially be fused into one call to scatter_reduce,
- # in the case step <= size provided scatter_reduce generates 1 kernel
- grad_input = grad.new_zeros(input_size)
- return torch.index_add(grad_input, dim, idx, grad)
- @register_decomposition(aten.logit_backward.default)
- @pw_cast_for_opmath
- def logit_backward(
- grad_output: Tensor, self: Tensor, eps: Optional[float] = None
- ) -> Tensor:
- if eps is not None:
- lo = eps
- hi = 1.0 - lo
- return torch.where(
- torch.logical_and(self >= lo, self <= hi),
- grad_output / (self * (1.0 - self)),
- 0.0,
- )
- else:
- return torch.where(
- torch.logical_and(self >= 0.0, self <= 1.0),
- grad_output / (self * (1.0 - self)),
- self.new_full((), float("nan")),
- )
- @register_decomposition(aten.native_dropout)
- def native_dropout(input: Tensor, p: float, train: Optional[bool]):
- if train:
- bool_mask = torch.rand_like(input) > p
- res = bool_mask * input * float(1.0 / (1.0 - p))
- return (res, bool_mask)
- else:
- return (input, torch.ones_like(input, dtype=torch.bool))
- @register_decomposition(aten._softmax)
- @out_wrapper()
- def _softmax(x: Tensor, dim: int, half_to_float: bool):
- # eager softmax returns a contiguous tensor. Ensure that decomp also returns
- # a contiguous tensor.
- x = x.contiguous()
- if half_to_float:
- assert x.dtype == torch.half
- computation_dtype, result_dtype = utils.elementwise_dtypes(
- x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- x = x.to(computation_dtype)
- if x.numel() == 0:
- unnormalized = torch.exp(x)
- else:
- x_max = torch.amax(x, dim, keepdim=True)
- unnormalized = torch.exp(x - x_max)
- result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
- if not half_to_float:
- result = result.to(result_dtype)
- return result
- @register_decomposition(aten._log_softmax)
- @out_wrapper()
- def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
- # eager log_softmax returns a contiguous tensor. Ensure that decomp also
- # returns a contiguous tensor.
- x = x.contiguous()
- if half_to_float:
- assert x.dtype == torch.half
- computation_dtype, result_dtype = utils.elementwise_dtypes(
- x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- x = x.to(computation_dtype)
- if x.numel() == 0:
- shifted = x
- else:
- x_max = torch.amax(x, dim, keepdim=True)
- shifted = x - x_max
- shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
- result = shifted - shifted_logsumexp
- if not half_to_float:
- result = result.to(result_dtype)
- return result
- @register_decomposition(aten.rsub.Tensor)
- def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
- return torch.sub(other, self, alpha=alpha)
- @register_decomposition(aten.rsub.Scalar)
- def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
- return torch.sub(other, self, alpha=alpha)
- @register_decomposition(aten.embedding)
- def embedding(
- weight: Tensor,
- indices: Tensor,
- padding_idx: int = -1,
- scale_grad_by_freq: bool = False,
- sparse: bool = False,
- ) -> Tensor:
- assert weight.dim() == 2, "'weight' must be 2-D"
- # Nb. scale_grad_by_freq is not used in the forward
- if indices.ndim <= 1:
- # We need this one as weight[indices] calls item() in these cases
- out = weight.index_select(0, indices)
- if indices.ndim == 0:
- out = out.squeeze(0)
- return out
- else:
- return weight[indices]
- @register_decomposition(aten.embedding_dense_backward)
- def embedding_dense_backward(
- grad_output: Tensor,
- indices: Tensor,
- num_weights: int,
- padding_idx: int,
- scale_grad_by_freq: bool,
- ):
- computation_dtype, result_dtype = utils.elementwise_dtypes(
- grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- grad_output = grad_output.to(computation_dtype)
- indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment]
- if scale_grad_by_freq:
- counts = indices.new_zeros((num_weights,))
- ones = torch.ones_like(indices)
- counts = counts.index_put([indices], ones, accumulate=True)
- grad_weights_scale = counts[indices]
- grad_output = grad_output / grad_weights_scale.unsqueeze(1)
- mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
- grad = grad_output.masked_fill(mask, 0)
- grad_weight = grad_output.new_zeros(
- (num_weights,) + grad_output.shape[indices.ndim :]
- )
- return grad_weight.index_put([indices], grad, accumulate=True).to(result_dtype)
- def prod(x: List[int]):
- r = 1
- for i in x:
- r *= i
- return r
- @register_decomposition([aten.split_with_sizes, aten.unsafe_split_with_sizes])
- def split_with_sizes(
- self: Tensor, split_sizes: List[int], dim: int = 0
- ) -> List[Tensor]:
- num_splits = len(split_sizes)
- splits = []
- start_idx = 0
- for i in range(num_splits):
- length = split_sizes[i]
- splits.append(self.narrow(dim, start_idx, length))
- start_idx += length
- return splits
- @register_decomposition([aten.split.Tensor, aten.unsafe_split.Tensor])
- def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
- input_sizes = self.shape
- dim_size = input_sizes[dim]
- if split_size == 0:
- assert dim_size == 0
- return [self]
- chunks = (dim_size + split_size - 1) // split_size
- chunks = guard_int(chunks)
- split_sizes = [split_size for i in range(chunks)]
- split_sizes[-1] = split_size - (split_size * chunks - dim_size)
- return torch.split(self, split_sizes, dim)
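- # Worked example (illustrative only): split_size=3 on a dimension of size 7 gives chunks=3 and
- # split_sizes=[3, 3, 1], since the last chunk absorbs the shortfall (3 - (3 * 3 - 7) = 1).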
- # TODO: this doesn't appear to have enough precision in bfloat16
- @register_decomposition(aten.addmm)
- @out_wrapper()
- @pw_cast_for_opmath
- def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
- if not self.is_floating_point() and not self.is_complex():
- beta = int(beta)
- alpha = int(alpha)
- out = alpha * torch.mm(mat1, mat2)
- if beta == 0:
- return out
- # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
- # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
- # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
- # This relies on TensorIterator's behavior of giving higher precedence to the stride of the first input.
- # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
- # This implementation is not ideal, and we should revisit this when we have a better solution.
- return out + beta * self
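- # Illustrative check of the stride concern above (not executed here; names are hypothetical):
- # self_t = torch.randn(3, 3).t()                             # non-contiguous view
- # mm_out = torch.mm(torch.randn(3, 2), torch.randn(2, 3))    # torch.mm output is contiguous
- # (mm_out + 2 * self_t).is_contiguous()    # expected True: the contiguous operand comes first
- # (2 * self_t + mm_out).is_contiguous()    # may be False, which is what this decomposition avoids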
- @register_decomposition(aten.native_group_norm_backward)
- @pw_cast_for_opmath
- def native_group_norm_backward(
- grad_output: Tensor,
- input: Tensor,
- mean: Tensor,
- rstd: Tensor,
- gamma: Optional[Tensor],
- N: int,
- C: int,
- HxW: int,
- group: int,
- output_mask: List[bool],
- ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
- utils.check_same_device(
- grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
- )
- utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
- utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
- utils.check(
- input.numel() == N * C * HxW,
- lambda: f"Expect input to have { N * C * HxW} elements",
- )
- utils.check(
- mean.shape == (N, group),
- lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
- )
- utils.check(
- gamma is None or gamma.numel() == C,
- lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
- )
- cpg, _rem = divmod(C, group)
- utils.check(
- _rem == 0,
- lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
- )
- # Compute Internal gradients
- ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
- db = grad_output.view(N, C, HxW).sum(dim=[2])
- d_input: Optional[Tensor] = None
- d_gamma: Optional[Tensor] = None
- d_bias: Optional[Tensor] = None
- if output_mask[0]:
- s = 1.0 / (HxW * cpg)
- if gamma is not None:
- ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
- db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
- c1 = torch.mul(
- rstd.unsqueeze(-1),
- gamma.reshape(1, group, cpg),
- )
- else:
- ds_val = ds.reshape(N, group, cpg).sum(2)
- db_val = db.reshape(N, group, cpg).sum(2)
- c1 = torch.mul(
- rstd.unsqueeze(-1),
- torch.ones((1, group, cpg), device=rstd.device),
- )
- c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
- c3 = -c2 * mean - db_val * rstd * s
- c1 = c1.unsqueeze(-1)
- c2 = _unsqueeze_to_dim(c2, 4)
- c3 = _unsqueeze_to_dim(c3, 4)
- d_input = (
- torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
- + torch.mul(input.reshape(N, group, cpg, HxW), c2)
- + c3
- )
- d_input = d_input.reshape(input.shape).to(input.dtype)
- if output_mask[1]:
- d_gamma = (
- (
- (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
- * rstd.unsqueeze(-1)
- )
- .sum(dim=[0])
- .reshape(C)
- )
- if output_mask[2]:
- d_bias = db.sum(dim=[0])
- return (d_input, d_gamma, d_bias)
- def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
- if x is not None:
- return x.to(dtype)
- return x
- # TODO: Take a closer look at the type promotion semantics
- @register_decomposition(aten.native_layer_norm_backward)
- def native_layer_norm_backward(
- grad_out: Tensor,
- input: Tensor,
- normalized_shape: List[int],
- mean: Tensor,
- rstd: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- output_mask: List[bool],
- ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
- input_shape = input.shape
- input_ndim = input.dim()
- computation_dtype = utils.get_computation_dtype(input.dtype)
- grad_out_cast, input_cast, weight_cast, bias_cast = [
- x.to(computation_dtype).contiguous() if x is not None else x
- for x in (grad_out, input, weight, bias)
- ]
- assert grad_out_cast is not None
- axis = input_ndim - len(normalized_shape)
- inner_dims = input_shape[axis:]
- outer_dims = input_shape[:axis]
- inner_dim_indices: List[int] = []
- outer_dim_indices: List[int] = []
- for i in range(input_ndim):
- if i >= axis:
- inner_dim_indices.append(i)
- else:
- outer_dim_indices.append(i)
- N = prod(inner_dims) # type: ignore[arg-type]
- M = prod(outer_dims) # type: ignore[arg-type]
- if M <= 0 or N <= 0:
- return (
- input.new_zeros(input_shape) if output_mask[0] else None,
- input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
- input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
- )
- x_hat = (input_cast - mean) * rstd
- if weight_cast is not None:
- grad_x_hat = grad_out_cast * weight_cast
- else:
- grad_x_hat = grad_out_cast
- a = grad_x_hat * N
- b = torch.sum(grad_x_hat, inner_dim_indices, True)
- c1 = torch.mul(grad_x_hat, x_hat)
- c2 = torch.sum(c1, inner_dim_indices, True)
- c3 = torch.mul(x_hat, c2)
- inner = a - b - c3
- d_input: Optional[Tensor] = None
- d_weight: Optional[Tensor] = None
- d_bias: Optional[Tensor] = None
- if output_mask[0]:
- d_input = (rstd / N) * inner
- if output_mask[1] and weight_cast is not None:
- if len(outer_dim_indices) > 0:
- d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
- else:
- d_weight = grad_out_cast * x_hat
- if output_mask[2] and bias_cast is not None:
- if len(outer_dim_indices) > 0:
- d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
- else:
- d_bias = grad_out_cast.clone()
- return (
- _maybe_cast(d_input, input.dtype),
- _maybe_cast(d_weight, input.dtype),
- _maybe_cast(d_bias, input.dtype),
- )
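- # In formula form (summary of the code above): with x_hat = (input - mean) * rstd and
- # g_hat = grad_out * weight, d_input = (rstd / N) * (N * g_hat - sum(g_hat) - x_hat * sum(g_hat * x_hat)),
- # where the sums run over the normalized (inner) dimensions and are kept broadcastable via keepdim.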
- def native_batch_norm_helper(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- training: bool,
- momentum: float,
- eps: float,
- functional: bool,
- ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
- reduction_dims = [0] + list(range(2, input.dim()))
- computation_dtype = utils.get_computation_dtype(input.dtype)
- new_running_mean = running_mean
- new_running_var = running_var
- if training:
- computation_dtype = utils.get_computation_dtype(input.dtype)
- input_acc = input.to(dtype=computation_dtype)
- biased_var, mean = torch.var_mean(
- input_acc, dim=reduction_dims, correction=0, keepdim=True
- )
- rstd = torch.rsqrt(biased_var + eps)
- output = (input - mean) * rstd
- save_mean = torch.squeeze(mean, reduction_dims)
- save_rstd = torch.squeeze(rstd, reduction_dims)
- if running_mean is not None:
- new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
- if not functional:
- running_mean.copy_(new_running_mean)
- if running_var is not None:
- n = input.numel() / input.shape[1]
- # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
- # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
- # numerics probably don't matter.
- squeezed_var = torch.squeeze(biased_var, reduction_dims)
- unbiased_var = squeezed_var * (n / (n - 1))
- new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
- if not functional:
- running_var.copy_(new_running_var)
- else:
- assert running_mean is not None and running_var is not None
- running_mean = running_mean.to(dtype=computation_dtype, copy=True)
- new_running_mean = running_mean
- running_var = running_var.to(dtype=computation_dtype, copy=True)
- new_running_var = running_var
- mean = running_mean
- invstd = 1 / (torch.sqrt(running_var + eps))
- # Very annoying inconsistency where CPU and CUDA give different shapes
- if input.device.type != "cpu":
- save_mean = running_mean
- save_rstd = invstd
- else:
- save_mean = input.new_zeros((0,))
- save_rstd = input.new_zeros((0,))
- mean = _unsqueeze_to_dim(mean, input.dim() - 1)
- invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
- output = (input - mean) * invstd
- if weight is None:
- weight = input.new_ones(())
- if bias is None:
- bias = input.new_zeros(())
- weight = _unsqueeze_to_dim(weight, input.dim() - 1)
- bias = _unsqueeze_to_dim(bias, input.dim() - 1)
- output = output * weight + bias
- if input.device.type == "cpu":
- save_mean = save_mean.to(dtype=input.dtype)
- save_rstd = save_rstd.to(dtype=input.dtype)
- return (
- output.to(dtype=input.dtype),
- save_mean,
- save_rstd,
- new_running_mean,
- new_running_var,
- )
- @register_decomposition(aten.native_batch_norm)
- def native_batch_norm(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
- input, weight, bias, running_mean, running_var, training, momentum, eps, False
- )
- return output, save_mean, save_rstd
- # TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
- # with our new correctly schema'd _native_batch_norm_legit and its variants, but
- # we cannot do that immediately in the C++ because it would be forwards incompatible
- # with some mobile use cases.
- #
- # Since this change is most impactful for aot autograd/functionalization, we simply
- # register this decomposition on the Autograd key for the python dispatcher (which is
- # currently only used by aot autograd/functionalization and no one else, really).
- # In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
- # to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
- @aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
- def native_batch_norm_decomposition(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- if running_mean is None and running_var is None:
- return aten._native_batch_norm_legit(
- input, weight, bias, training, momentum, eps
- )
- if running_mean is None:
- raise RuntimeError(
- "running_mean is None, but running_var is provided. "
- "They should both be None or both be provided."
- )
- if running_var is None:
- raise RuntimeError(
- "running_var is None, but running_mean is provided. "
- "They should both be None or both be provided."
- )
- return aten._native_batch_norm_legit(
- input, weight, bias, running_mean, running_var, training, momentum, eps
- )
- @aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
- def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
- dim_size = tensor.size(dim)
- split_size = (dim_size + chunks - 1) // chunks
- if split_size == 0 and dim_size == 0:
- split_sizes = [split_size for _ in range(chunks)]
- split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
- return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
- return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
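- # Illustrative note (an assumption for exposition, not part of the original file):
- # unsafe_chunk uses the same split-size arithmetic as torch.chunk, i.e. a length-10 dim
- # split into 3 chunks yields pieces of size ceil(10 / 3) = 4, 4 and the remaining 2.
- def _example_unsafe_chunk_sizes():  # pragma: no cover - illustrative only
-     t = torch.arange(10.0)
-     pieces = torch.chunk(t, 3)  # same sizes the decomposition above produces
-     assert [p.numel() for p in pieces] == [4, 4, 2]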
- @register_decomposition(aten._native_batch_norm_legit.default)
- def _native_batch_norm_legit(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
- input, weight, bias, running_mean, running_var, training, momentum, eps, False
- )
- return output, save_mean, save_rstd
- @register_decomposition(aten._native_batch_norm_legit.no_stats)
- def _native_batch_norm_legit_no_stats(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
- input, weight, bias, None, None, training, momentum, eps, False
- )
- return output, save_mean, save_rstd
- @register_decomposition(aten._native_batch_norm_legit_functional.default)
- def _native_batch_norm_legit_functional(
- input: Tensor,
- weight: Optional[Tensor],
- bias: Optional[Tensor],
- running_mean: Tensor,
- running_var: Tensor,
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
- (
- output,
- save_mean,
- save_rstd,
- new_running_mean,
- new_running_var,
- ) = native_batch_norm_helper(
- input, weight, bias, running_mean, running_var, training, momentum, eps, True
- )
- assert new_running_mean is not None, "new_running_mean should not be None"
- assert new_running_var is not None, "new_running_var should not be None"
- return output, save_mean, save_rstd, new_running_mean, new_running_var
- @register_decomposition(aten._fused_dropout)
- @pw_cast_for_opmath
- def _fused_dropout_decomposition(input, p, generator=None):
- assert generator is None
- mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
- res = mask.type_as(input) * input * (1.0 / p)
- return (res, mask)
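- # Illustrative note (assumption based on the code above, not part of the original file):
- # here `p` is the probability of *keeping* an element, and kept elements are rescaled by
- # 1 / p so the expected value of the output matches the input.
- def _example_fused_dropout_scaling():  # pragma: no cover - illustrative only
-     x = torch.ones(8)
-     out, mask = _fused_dropout_decomposition(x, 0.5)
-     assert mask.dtype == torch.uint8
-     assert set(out.unique().tolist()) <= {0.0, 2.0}  # kept values are scaled by 1 / 0.5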
- @register_decomposition(aten._to_copy)
- def _to_copy(
- x: Tensor,
- *,
- dtype: Optional[torch.dtype] = None,
- layout=None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- non_blocking: bool = False,
- memory_format: Optional[torch.memory_format] = None,
- ):
- assert not layout or layout == torch.strided, "TODO"
- assert not pin_memory, "TODO"
- if device is None and dtype is None and memory_format is None:
- return x.clone()
- dtype_converted = False
- if device is not None and device != x.device:
- # avoid conversions on cpu
- if dtype is not None and device.type == "cpu":
- x = torch._prims.convert_element_type(x, dtype)
- dtype_converted = True
- x = torch._prims.device_put(x, device)
- if dtype is not None and not dtype_converted:
- x = torch._prims.convert_element_type(x, dtype)
- if memory_format is not None: # no ref/prim for memory format
- return torch.clone(x, memory_format=memory_format)
- return x
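- # Illustrative sketch (not part of the original file; the helper name is hypothetical):
- # with only a dtype change, the decomposition reduces to a single element-type conversion
- # that still produces a fresh tensor.
- def _example_to_copy_dtype_only():  # pragma: no cover - illustrative only
-     x = torch.arange(4, dtype=torch.int32)
-     y = _to_copy(x, dtype=torch.float32)
-     assert y.dtype == torch.float32 and y.data_ptr() != x.data_ptr()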
- # Questionable decompositions
- # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
- # Note that this decomposition causes issues with in-place ops
- @register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
- def nop_decomposition(x):
- return aten.alias(x)
- # Also register to the Autograd dispatch key, so this decomp can run above autograd.
- # native_batch_norm needs to decompose into other ops before autograd.
- @aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
- @register_decomposition(aten.cudnn_batch_norm)
- def cudnn_batch_norm(
- input: Tensor,
- weight: Tensor,
- bias: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- training: bool,
- exponential_average_factor: float,
- epsilon: float,
- ):
- a, b, c = aten.native_batch_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- training,
- exponential_average_factor,
- epsilon,
- )
- # Cudnn returns running mean and variance when training is True
- if training:
- return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
- return (
- a,
- weight.new_zeros((0,)),
- weight.new_zeros((0,)),
- input.new_zeros((0,), dtype=torch.uint8),
- )
- def _broadcast_batch_norm_backward(x, broadcast_mask):
- for axis, mask in enumerate(broadcast_mask):
- if mask == 1 and not (axis < x.ndim and x.shape[axis] == broadcast_mask[axis]):
- x = x.unsqueeze(axis)
- return x
- @register_decomposition(aten.native_batch_norm_backward)
- def native_batch_norm_backward(
- grad_out: Tensor,
- input: Tensor,
- weight: Optional[Tensor],
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- save_mean: Optional[Tensor],
- save_invstd: Optional[Tensor],
- train: bool,
- eps: float,
- output_mask: List[bool],
- ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
- input_dtype = input.dtype
- if weight is not None:
- weight_dtype = weight.dtype
- else:
- weight_dtype = input_dtype
- computation_dtype = utils.get_computation_dtype(input.dtype)
- (
- grad_out_cast,
- input_cast,
- weight_cast,
- running_mean_cast,
- running_var_cast,
- save_mean_cast,
- save_invstd_cast,
- ) = [
- x.to(computation_dtype) if x is not None else x
- for x in (
- grad_out,
- input,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_invstd,
- )
- ]
- input_shape = input.shape
- input_rank = input.dim()
- assert input_rank >= 2, "rank of the input must be at least 2"
- axis = 1
- num_features = prod(list(input_shape)) / input_shape[axis]
- mean = save_mean_cast
- invstd = save_invstd_cast
- if train:
- assert save_mean_cast is not None and save_invstd_cast is not None
- else:
- assert running_mean_cast is not None and running_var_cast is not None
- mean = running_mean_cast
- invstd = torch.rsqrt(running_var_cast + eps)
- broadcast_mask: List[int] = [1] * input_rank
- broadcast_mask[axis] = input_shape[axis]
- reduction_axes: List[int] = []
- for i in range(input_rank):
- if i != axis:
- reduction_axes.append(i)
- mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type]
- norm = 1.0 / num_features
- grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type]
- dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator]
- grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
- proj_scale = _broadcast_batch_norm_backward(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) # type: ignore[operator]
- if weight_cast is None:
- grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type]
- else:
- grad_scale = _broadcast_batch_norm_backward(
- invstd * weight_cast, broadcast_mask
- )
- if train:
- proj = (input_cast - mean) * proj_scale # type: ignore[operator]
- grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
- else:
- grad_input = grad_out_cast * grad_scale
- if output_mask[1]:
- grad_weight = dot_p * invstd
- else:
- grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp
- if output_mask[2]:
- grad_bias = grad_output_sum
- else:
- grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp
- return (
- grad_input.to(input_dtype),
- _maybe_cast(grad_weight, weight_dtype),
- _maybe_cast(grad_bias, weight_dtype),
- )
- @register_decomposition(aten.cudnn_batch_norm_backward)
- def cudnn_batch_norm_backward(
- input: Tensor,
- grad_output: Tensor,
- weight: Tensor,
- running_mean: Optional[Tensor],
- running_var: Optional[Tensor],
- save_mean: Optional[Tensor],
- save_var: Optional[Tensor],
- epsilon: float,
- reserveSpace: Tensor,
- ):
- return aten.native_batch_norm_backward(
- grad_output,
- input,
- weight,
- running_mean,
- running_var,
- save_mean,
- save_var,
- True,
- epsilon,
- [True, True, True],
- )
- @register_decomposition(aten._adaptive_avg_pool2d)
- @pw_cast_for_opmath
- def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
- # Preconditions
- device = input.device
- shape = input.shape
- ndim = len(shape)
- utils.check(
- ndim in (3, 4),
- lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
- )
- for d in input.shape[-2:]:
- utils.check(
- d != 0,
- lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
- f"non-batch dimensions, but input has shape {tuple(shape)}.",
- )
- # Optimisation (we should also do this in the kernel implementation)
- if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
- stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
- kernel = tuple(
- i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
- )
- return torch.nn.functional.avg_pool2d(input, kernel, stride)
- def start_index(a, b, c):
- return torch.div(a * c, b, rounding_mode="trunc")
- def end_index(a, b, c):
- return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")
- def compute_idx(in_size, out_size):
- orange = torch.arange(out_size, device=device, dtype=torch.int64)
- i0 = start_index(orange, out_size, in_size)
- # Let length = end_index - start_index, i.e. the length of the pooling kernels
- # length.max() can be computed analytically as follows:
- maxlength = in_size // out_size + 1
- in_size_mod = in_size % out_size
- # adaptive = True iff there are kernels with different lengths
- adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
- if adaptive:
- maxlength += 1
- elif in_size_mod == 0:
- maxlength -= 1
- range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
- idx = i0.unsqueeze(-1) + range_max
- if adaptive:
- # Need to clamp to avoid accessing out-of-bounds memory
- # TODO make minimum accept scalars
- maxval = torch.scalar_tensor(
- in_size - 1, dtype=idx.dtype, device=idx.device
- )
- idx = torch.minimum(idx, maxval)
- # Compute the lengths
- i1 = end_index(orange, out_size, in_size)
- length = i1 - i0
- else:
- length = maxlength
- return idx, length, range_max, adaptive
- # length is not None if it's constant, otherwise we'll need to compute it
- idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
- idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])
- vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
- # Shortcut for the simpler case
- if not adaptive_h and not adaptive_w:
- return torch.mean(vals, dim=(-3, -1))
- def maybe_mask(vals, length, range_max, adaptive, dim):
- if isinstance(length, IntLike):
- return vals, length
- else:
- # zero-out the things we didn't really want to select
- assert dim < 0
- # hack
- mask = range_max >= length.unsqueeze(-1)
- if dim == -2:
- mask = _unsqueeze_to_dim(mask, 4)
- vals = torch.masked_fill(vals, mask, 0.0)
- # Compute the length of each window
- length = _unsqueeze_to_dim(length, -dim)
- return vals, length
- vals, length_h = maybe_mask(
- vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
- )
- vals, length_w = maybe_mask(
- vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
- )
- # We unroll the sum as we assume that the kernels are going to be small
- ret = None
- for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
- if ret is None:
- ret = vals[..., i, :, j]
- else:
- ret = ret + vals[..., i, :, j]
- return ret / (length_h * length_w)
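- # Illustrative sanity check (not part of the original file): the decomposition should agree
- # with eager adaptive_avg_pool2d on both the evenly divisible fast path and the adaptive path.
- def _example_check_adaptive_avg_pool2d():  # pragma: no cover - illustrative only
-     x = torch.randn(1, 2, 6, 7)
-     for osize in ((3, 7), (4, 5)):  # divisible and non-divisible output sizes
-         torch.testing.assert_close(
-             adaptive_avg_pool2d(x, osize),
-             torch.nn.functional.adaptive_avg_pool2d(x, osize),
-         )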
- @register_decomposition(aten.index_add_)
- def index_add_(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- tensor: TensorLike,
- *,
- alpha: NumberType = 1,
- ):
- return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)
- @register_decomposition(aten.index_add)
- @out_wrapper()
- def index_add(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- tensor: TensorLike,
- *,
- alpha: NumberType = 1,
- ):
- return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
- def _index_add(
- x: TensorLike,
- dim: int,
- index: TensorLike,
- tensor: TensorLike,
- *,
- inplace: bool,
- alpha: NumberType = 1,
- ):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- if alpha != 1:
- python_type = utils.dtype_to_type(x.dtype)
- utils.check(
- python_type == bool
- or utils.is_weakly_lesser_type(type(alpha), python_type),
- lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
- )
- tensor = tensor * alpha
- # Treat scalars as elements of \R^1
- zero_dim = x.ndim == 0
- x1 = x.unsqueeze(0) if zero_dim else x
- idx = (None,) * dim + (index,)
- index_put = aten.index_put_ if inplace else aten.index_put
- out = index_put(x1, idx, tensor, accumulate=True)
- if inplace:
- return x
- else:
- return out.squeeze(0) if zero_dim else out.contiguous()
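- # Illustrative sketch (not part of the original file; the helper name is hypothetical):
- # index_add scatters rows of `tensor` into `x` at positions given by `index` along `dim`,
- # accumulating duplicates, which is why the decomposition uses index_put with accumulate=True.
- def _example_index_add():  # pragma: no cover - illustrative only
-     x = torch.zeros(3, 2)
-     src = torch.ones(2, 2)
-     out = index_add(x, 0, torch.tensor([1, 1]), src)
-     torch.testing.assert_close(out[1], torch.full((2,), 2.0))  # both rows accumulate at index 1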
- @register_decomposition(aten.index_copy_)
- def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- return _index_copy(x, dim, index, tensor, inplace=True)
- @register_decomposition(aten.index_copy)
- @out_wrapper()
- def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
- return _index_copy(x, dim, index, tensor, inplace=False)
- def _index_copy(
- x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
- ):
- dim = utils.canonicalize_dims(x.ndim, dim)
- utils.check(
- index.ndim <= 1,
- lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
- )
- # Treat scalars as elements of \R^1
- zero_dim = x.ndim == 0
- x1 = x.unsqueeze(0) if zero_dim else x
- idx = (None,) * dim + (index,)
- index_put = aten.index_put_ if inplace else aten.index_put
- out = index_put(x1, idx, tensor)
- if inplace:
- return x
- else:
- return out.squeeze(0) if zero_dim else out.contiguous()
- # nb: Should use acc_t, not op_math
- @register_decomposition(aten.log_sigmoid_forward)
- @out_wrapper("output", "buffer")
- @pw_cast_for_opmath
- def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
- min = torch.minimum(self.new_zeros(()), self)
- z = torch.exp(-torch.abs(self))
- if self.is_cuda:
- buffer = self.new_zeros((0,))
- else:
- buffer = z
- return min - torch.log1p(z), buffer
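- # Illustrative sanity check (not part of the original file): the first output is the usual
- # numerically stable log-sigmoid, min(0, x) - log1p(exp(-|x|)).
- def _example_log_sigmoid_forward():  # pragma: no cover - illustrative only
-     x = torch.randn(5)
-     out, _buffer = log_sigmoid_forward(x)
-     torch.testing.assert_close(out, torch.nn.functional.logsigmoid(x))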
- @register_decomposition(aten.uniform)
- def uniform(
- x: Tensor,
- low: Union[bool, int, float] = 0.0,
- high: Union[bool, int, float] = 1.0,
- ):
- return prims._uniform_helper(
- x.shape,
- low=sym_float(low),
- high=sym_float(high),
- dtype=x.dtype,
- device=x.device,
- )
- @register_decomposition(aten.uniform_)
- def uniform_(self, low=0, high=1, generator=None):
- assert generator is None
- return self.copy_((high - low) * torch.rand_like(self) + low)
- # aten/src/ATen/native/UpSample.cpp compute_output_size
- def upsample_compute_output_size(input_size, output_size, scale_factors):
- spatial_dimensions = len(input_size) - 2
- if output_size is not None:
- utils.check(
- scale_factors is None,
- lambda: "Must specify exactly one of output_size and scale_factors",
- )
- utils.check(len(output_size) == spatial_dimensions, lambda: "")
- return output_size
- if scale_factors is not None:
- # NB: this isn't necessary lol
- utils.check(
- output_size is None,
- lambda: "Must specify exactly one of output_size and scale_factors",
- )
- utils.check(len(scale_factors) == spatial_dimensions, lambda: "")
- output_size = []
- for i, s in enumerate(scale_factors):
- if int(s) == s:
- output_size.append(input_size[i + 2] * int(s))
- else:
- output_size.append(sym_int(input_size[i + 2] * s))
- return output_size
- utils.check(
- False, lambda: "Must specify exactly one of output_size and scale_factors"
- )
- def get_scale_value(scales, idx):
- if scales is None:
- return None
- return scales[idx]
- @register_decomposition(aten.upsample_nearest1d.vec)
- @aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
- def upsample_nearest1d_vec(input, output_size, scale_factors):
- osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
- scale = get_scale_value(scale_factors, 0)
- return upsample_nearest1d(input, osize, scale)
- @register_decomposition(aten.upsample_nearest2d.vec)
- @aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
- def upsample_nearest2d_vec(input, output_size, scale_factors):
- osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
- scale_h = get_scale_value(scale_factors, 0)
- scale_w = get_scale_value(scale_factors, 1)
- return upsample_nearest2d(input, osize, scale_h, scale_w)
- @register_decomposition(aten.upsample_nearest3d.vec)
- @aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
- def upsample_nearest3d_vec(input, output_size, scale_factors):
- osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
- scale_d = get_scale_value(scale_factors, 0)
- scale_h = get_scale_value(scale_factors, 1)
- scale_w = get_scale_value(scale_factors, 2)
- return upsample_nearest3d(input, osize, scale_d, scale_h, scale_w)
- def _compute_upsample_nearest_indices(input, output_size, scales):
- # For each dim in output_size, compute the set of input indices used
- # to produce the upsampled output.
- indices = []
- num_spatial_dims = len(output_size)
- for d in range(num_spatial_dims):
- # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
- # Indices are computed as following:
- # scale = isize / osize
- # input_index = floor(output_index * scale)
- # Same as OpenCV INTER_NEAREST
- osize = output_size[d]
- output_indices = torch.arange(osize, dtype=input.dtype, device=input.device)
- isize = input.shape[-num_spatial_dims + d]
- scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
- input_indices = (output_indices * scale).to(torch.int64)
- for _ in range(num_spatial_dims - 1 - d):
- input_indices = input_indices.unsqueeze(-1)
- indices.append(input_indices)
- return tuple(indices)
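- # Illustrative worked example (assumption for exposition, not part of the original file):
- # upsampling a length-3 row to length 6 uses scale = 3 / 6 = 0.5, so the gathered source
- # indices are floor([0, 1, 2, 3, 4, 5] * 0.5) = [0, 0, 1, 1, 2, 2], i.e. each input element
- # repeats twice.
- def _example_upsample_nearest_indices():  # pragma: no cover - illustrative only
-     x = torch.arange(3.0).view(1, 1, 3)
-     (idx,) = _compute_upsample_nearest_indices(x, (6,), (None,))
-     assert idx.tolist() == [0, 0, 1, 1, 2, 2]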
- @register_decomposition(aten.upsample_nearest1d.default)
- @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
- @pw_cast_for_opmath
- def upsample_nearest1d(
- input: Tensor,
- output_size: List[int],
- scales: Optional[float] = None,
- ) -> Tensor:
- (l_indices,) = _compute_upsample_nearest_indices(input, output_size, (scales,))
- result = input[:, :, l_indices]
- return result
- @register_decomposition(aten.upsample_nearest2d.default)
- @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
- @pw_cast_for_opmath
- def upsample_nearest2d(
- input: Tensor,
- output_size: List[int],
- scales_h: Optional[float] = None,
- scales_w: Optional[float] = None,
- ) -> Tensor:
- h_indices, w_indices = _compute_upsample_nearest_indices(
- input, output_size, (scales_h, scales_w)
- )
- result = input[:, :, h_indices, w_indices]
- # convert output to correct memory format, if necessary
- memory_format = utils.suggest_memory_format(input)
- # following "heuristic: only use channels_last path when it's faster than the contiguous path"
- _, n_channels, _, _ = input.shape
- if input.device.type == "cuda" and n_channels < 4:
- memory_format = torch.contiguous_format
- result = result.contiguous(memory_format=memory_format)
- return result
- @register_decomposition(aten.upsample_nearest3d.default)
- @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
- @pw_cast_for_opmath
- def upsample_nearest3d(
- input: Tensor,
- output_size: List[int],
- scales_d: Optional[float] = None,
- scales_h: Optional[float] = None,
- scales_w: Optional[float] = None,
- ) -> Tensor:
- d_indices, h_indices, w_indices = _compute_upsample_nearest_indices(
- input, output_size, (scales_d, scales_h, scales_w)
- )
- result = input[:, :, d_indices, h_indices, w_indices]
- return result
- def gather_params(params, has_biases, has_projections):
- if has_biases and has_projections:
- group_size = 5
- elif has_biases:
- group_size = 4
- elif has_projections:
- group_size = 3
- else:
- group_size = 2
- assert len(params) % group_size == 0, len(params)
- return [
- tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
- ]
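- # Illustrative note (assumption for exposition, not part of the original file): with biases
- # and no projections the flat `params` list is grouped four at a time per layer/direction as
- # (w_ih, w_hh, b_ih, b_hh); projections append an extra hr weight to each group.
- def _example_gather_params_grouping():  # pragma: no cover - illustrative only
-     flat = list(range(8))  # placeholders for two (w_ih, w_hh, b_ih, b_hh) groups
-     groups = gather_params(flat, has_biases=True, has_projections=False)
-     assert groups == [(0, 1, 2, 3), (4, 5, 6, 7)]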
- def params_hiddens(params, hiddens, i, bidirectional):
- if bidirectional:
- cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
- bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
- else:
- cur_params, cur_hidden = params[i], hiddens[i]
- bidir_params, bidir_hidden = None, None
- return cur_params, cur_hidden, bidir_params, bidir_hidden
- def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
- assert last_batch_size > batch_size
- hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
- return cur_hidden.narrow(0, 0, batch_size)
- def update_hidden_for_packed_reverse(
- cur_hidden, last_batch_size, batch_size, inp_hidden
- ):
- if last_batch_size == batch_size:
- return cur_hidden
- assert last_batch_size < batch_size
- return torch.concat(
- (
- cur_hidden,
- inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
- )
- )
- def one_layer_rnn_data(
- inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
- ):
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- step_output = []
- hiddens: List["torch.Tensor"] = []
- last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
- cur_hidden = hidden.narrow(0, 0, last_batch_size)
- split_inp = torch.split(inp, list(batch_sizes))
- if reverse:
- split_inp = split_inp[::-1]
- for inp in split_inp:
- i = inp.shape[0]
- if last_batch_size == i:
- pass # don't update cur_hidden
- # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
- elif reverse:
- cur_hidden = update_hidden_for_packed_reverse(
- cur_hidden, last_batch_size, i, hidden
- )
- else:
- cur_hidden = update_hidden_for_packed(
- cur_hidden, last_batch_size, i, hiddens
- )
- cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
- last_batch_size = i
- step_output.append(cur_hidden)
- if reverse:
- step_output.reverse()
- else:
- hiddens.append(cur_hidden)
- hiddens.reverse()
- out = torch.cat(step_output, 0)
- hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
- return out, hidden_out
- def rnn_cell(nonlinearity):
- def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
- return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
- return inner
- def rnn_cell_data(nonlinearity):
- def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
- i = F.linear(i, ih_weight, ih_bias)
- return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
- return inner
- def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- precomputed_input = F.linear(inp, ih_weight, ih_bias)
- precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
- cur_hidden = hidden.unsqueeze(0)
- step_output = []
- for i in precomputed_input:
- cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
- step_output.append(cur_hidden)
- if reverse:
- step_output.reverse()
- out = torch.cat(step_output, 0)
- return out, cur_hidden.squeeze(0)
- def _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- layer_fn,
- ):
- input = input.transpose(0, 1) if batch_first else input
- final_hiddens = []
- for i in range(num_layers):
- cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
- params, hidden, i, bidirectional
- )
- dropout = dropout if (train and i < num_layers - 1) else 0.0
- fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
- final_hiddens.append(fwd_hidden)
- if bidirectional:
- bwd_inp, bwd_hidden = layer_fn(
- input, bidir_hidden, bidir_params, has_biases, reverse=True
- )
- final_hiddens.append(bwd_hidden)
- if bidirectional:
- input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)
- else:
- input = fwd_inp
- if dropout != 0 and train and i < num_layers - 1:
- input = torch.dropout(input, dropout, train=True)
- input = input.transpose(0, 1) if batch_first else input
- return input, final_hiddens
- @register_decomposition(aten.rnn_tanh.input)
- @aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
- def rnn_tanh_input(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- hidden = hx.unbind(0)
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
- )
- return out, torch.stack(final_hiddens, 0)
- @register_decomposition(aten.rnn_relu.input)
- @aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
- def rnn_relu_input(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- hidden = hx.unbind(0)
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
- )
- return out, torch.stack(final_hiddens, 0)
- @register_decomposition(aten.rnn_relu.data)
- @aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
- def rnn_relu_data(
- data,
- batch_sizes,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- ):
- hidden = hx.unbind(0)
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- data,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- False,
- partial(
- one_layer_rnn_data,
- batch_sizes=batch_sizes,
- hidden_fn=rnn_cell_data(torch.relu),
- ),
- )
- return out, torch.stack(final_hiddens, 0)
- @register_decomposition(aten.rnn_tanh.data)
- @aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
- def rnn_tanh_data(
- data,
- batch_sizes,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- ):
- hidden = hx.unbind(0)
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- data,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- False,
- partial(
- one_layer_rnn_data,
- batch_sizes=batch_sizes,
- hidden_fn=rnn_cell_data(torch.tanh),
- ),
- )
- return out, torch.stack(final_hiddens, 0)
- def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
- gates = F.linear(hx, hh_weight, hh_bias) + inp
- chunked_gates = gates.chunk(4, chunk_dim)
- in_gate = chunked_gates[0].sigmoid()
- forget_gate = chunked_gates[1].sigmoid()
- cell_gate = chunked_gates[2].tanh()
- out_gate = chunked_gates[3].sigmoid()
- cy = forget_gate * cx + (in_gate * cell_gate)
- hy = out_gate * cy.tanh()
- hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
- return hy, cy
- def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- hr_weight = (
- params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
- )
- hx = hidden[0].unsqueeze(0)
- cx = hidden[1].unsqueeze(0)
- precomputed_input = F.linear(inp, ih_weight, ih_bias)
- precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
- step_output = []
- for inp in precomputed_input:
- hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
- step_output.append(hx)
- if reverse:
- step_output.reverse()
- out = torch.cat(step_output, 0)
- return out, (hx.squeeze(1), cx.squeeze(1))
- def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
- ih_weight = params[0]
- hh_weight = params[1]
- ih_bias = params[2] if has_biases else None
- hh_bias = params[3] if has_biases else None
- hr_weight = (
- params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
- )
- step_output = []
- hiddens = []
- last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
- split_inp = torch.split(inp, list(batch_sizes))
- if reverse:
- split_inp = split_inp[::-1]
- orig_hx = hidden[0]
- orig_cx = hidden[1]
- hx, cx = orig_hx.narrow(0, 0, last_batch_size), orig_cx.narrow(
- 0, 0, last_batch_size
- )
- for inp in split_inp:
- i = inp.shape[0]
- inp = F.linear(inp, ih_weight, ih_bias)
- # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
- if i < last_batch_size:
- hiddens.append(
- (
- hx.narrow(0, i, last_batch_size - i),
- cx.narrow(0, i, last_batch_size - i),
- )
- )
- hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
- # this will only happen when reverse=True
- if i > last_batch_size:
- hx = torch.concat(
- (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
- )
- cx = torch.concat(
- (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
- )
- hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
- last_batch_size = i
- step_output.append(hx)
- if reverse:
- step_output.reverse()
- hidden_out = (hx, cx)
- else:
- hiddens.append((hx, cx))
- hiddens.reverse()
- hidden0, hidden1 = zip(*hiddens)
- hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
- out = torch.cat(step_output, 0)
- return out, hidden_out
- @register_decomposition(aten.lstm.input)
- @aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.lstm.input.py_impl(DispatchKey.Autograd)
- def lstm_impl(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- assert len(hx) == 2, "lstm expects two hidden states"
- params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
- hidden = list(zip(hx[0], hx[1]))
- out, final_hiddens = _rnn_helper(
- input,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- one_layer_lstm,
- )
- final_hiddens = list(zip(*final_hiddens))
- return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
- @register_decomposition(aten.lstm.data)
- @aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.lstm.data.py_impl(DispatchKey.Autograd)
- def lstm_data_impl(
- data,
- batch_sizes,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- ):
- assert len(hx) == 2, "lstm expects two hidden states"
- params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
- hidden = list(zip(hx[0], hx[1]))
- out, final_hiddens = _rnn_helper(
- data,
- hidden,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- False,
- partial(one_layer_lstm_data, batch_sizes=batch_sizes),
- )
- final_hiddens = list(zip(*final_hiddens))
- return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
- def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
- chunked_igates = inp.chunk(3, 1)
- chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
- reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
- input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
- new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
- return (cur_hidden - new_gate) * input_gate + new_gate
- def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
- chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
- chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
- reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
- input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
- new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
- return (cur_hidden - new_gate) * input_gate + new_gate
- @register_decomposition(aten.gru.data)
- @aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.gru.data.py_impl(DispatchKey.Autograd)
- def gru_impl_data(
- data,
- batch_sizes,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- ):
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- data,
- hx.unbind(0),
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- False,
- partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
- )
- return out, torch.stack(final_hiddens, 0)
- @register_decomposition(aten.gru.input)
- @aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.gru.input.py_impl(DispatchKey.Autograd)
- def gru_impl(
- input,
- hx,
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- ):
- params = gather_params(params, has_biases, False)
- out, final_hiddens = _rnn_helper(
- input,
- hx.unbind(0),
- params,
- has_biases,
- num_layers,
- dropout,
- train,
- bidirectional,
- batch_first,
- partial(one_layer_rnn, hidden_fn=gru_cell),
- )
- return out, torch.stack(final_hiddens, 0)
- @register_decomposition(aten.upsample_bilinear2d.vec)
- @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
- def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors):
- osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
- scale_h = get_scale_value(scale_factors, 0)
- scale_w = get_scale_value(scale_factors, 1)
- return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w)
- @register_decomposition(aten.upsample_bilinear2d.default)
- @aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
- @pw_cast_for_opmath
- def upsample_bilinear2d(
- input: Tensor,
- output_size: List[int],
- align_corners: bool,
- scales_h: Optional[float] = None,
- scales_w: Optional[float] = None,
- ) -> Tensor:
- # get dimensions of original image
- n_batch, n_channels, in_h, in_w = input.shape
- out_h = output_size[0]
- out_w = output_size[1]
- # Calculate horizontal and vertical scaling factor
- # TODO: Figure out if scales_h/scales_w matters here
- if out_h > 1:
- if align_corners:
- h_scale_factor = (in_h - 1) / (out_h - 1)
- else:
- h_scale_factor = (
- in_h / (in_h * scales_h) if scales_h is not None else in_h / out_h
- )
- else:
- h_scale_factor = 0.0
- if out_w > 1:
- if align_corners:
- w_scale_factor = (in_w - 1) / (out_w - 1)
- else:
- w_scale_factor = (
- in_w / (in_w * scales_w) if scales_w is not None else in_w / out_w
- )
- else:
- w_scale_factor = 0.0
- i = torch.arange(out_h, dtype=input.dtype, device=input.device)
- j = torch.arange(out_w, dtype=input.dtype, device=input.device)
- if align_corners:
- x = h_scale_factor * i
- y = w_scale_factor * j
- else:
- x = (h_scale_factor * (i + 0.5) - 0.5).clamp(min=0.0)
- y = (w_scale_factor * (j + 0.5) - 0.5).clamp(min=0.0)
- x_floor = x.to(torch.int64)
- x_ceil = torch.ceil(x).clamp(max=in_h - 1).to(torch.int64)
- y_floor = y.to(torch.int64)
- y_ceil = torch.ceil(y).clamp(max=in_w - 1).to(torch.int64)
- x_view = x.unsqueeze(1)
- x_floor_view = x_floor.unsqueeze(1)
- x_ceil_view = x_ceil.unsqueeze(1)
- v1 = input[:, :, x_floor_view, y_floor]
- v2 = input[:, :, x_ceil_view, y_floor]
- v3 = input[:, :, x_floor_view, y_ceil]
- v4 = input[:, :, x_ceil_view, y_ceil]
- xscale2 = x_view - x_floor_view
- xscale1 = 1.0 - xscale2
- yscale2 = y - y_floor
- yscale1 = 1.0 - yscale2
- q1 = torch.mul(v1, xscale1) + torch.mul(v2, xscale2)
- q2 = torch.mul(v3, xscale1) + torch.mul(v4, xscale2)
- result = torch.mul(q1, yscale1) + torch.mul(q2, yscale2)
- # convert output to correct memory format, if necessary
- memory_format = utils.suggest_memory_format(input)
- # following "heuristic: only use channels_last path when it's faster than the contiguous path"
- if input.device.type == "cuda" and n_channels < 16:
- memory_format = torch.contiguous_format
- result = result.contiguous(memory_format=memory_format)
- return result
- # We should be applying decompositions after all transformations
- @register_decomposition(aten.is_same_size.default)
- def is_same_size(a: Tensor, b: Tensor) -> bool:
- return a.shape == b.shape
- @register_decomposition([aten._reshape_alias, aten._unsafe_view])
- def _reshape_alias(x, shape, *args):
- return aten.view(x, shape)
- @register_decomposition(aten.nll_loss_forward)
- def nll_loss_forward(
- self: Tensor,
- target: Tensor,
- weight: Optional[Tensor],
- reduction: int,
- ignore_index: int,
- ) -> Tuple[Tensor, Tensor]:
- assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
- assert (
- target.dim() <= 1
- ), "0D or 1D target tensor expected, multi-target not supported"
- no_batch_dim = self.dim() == 1 and target.dim() == 0
- assert no_batch_dim or (
- self.shape[0] == target.shape[0]
- ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
- n_classes = self.shape[-1]
- assert weight is None or (
- weight.dim() == 1 and weight.numel() == n_classes
- ), f"weight tensor should be defined either for all {n_classes} classes or no classes but got weight tensor of shape: {weight.shape}" # noqa: B950
- # self can be [N, C] or [C]
- # target can be [N] or []
- n_dims = self.dim()
- channel_dim = 1
- if n_dims < 2:
- channel_dim = 0
- if weight is not None:
- w = weight.unsqueeze(0) if n_dims > 1 else weight
- self = self * w
- safe_target = torch.where(target != ignore_index, target, 0)
- safe_target_ = safe_target.unsqueeze(channel_dim)
- # target can be [N, 1] or [1]
- result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
- result = torch.where(target != ignore_index, result, 0)
- if reduction == Reduction.NONE.value and n_dims > 1:
- total_weight = self.new_full((), 0.0)
- return result, total_weight
- if weight is not None:
- w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
- wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
- wsum = torch.where(target != ignore_index, wsum, 0)
- total_weight = wsum.sum()
- else:
- total_weight = (target != ignore_index).sum().to(self)
- if reduction == Reduction.SUM.value:
- result = result.sum()
- elif reduction == Reduction.MEAN.value:
- result = result.sum() / total_weight
- return result, total_weight
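- # Illustrative sanity check (not part of the original file): the decomposition should match
- # eager nll_loss for the common mean-reduction case.
- def _example_check_nll_loss_forward():  # pragma: no cover - illustrative only
-     logp = torch.log_softmax(torch.randn(4, 3), dim=1)
-     target = torch.tensor([0, 2, 1, 2])
-     loss, _total_weight = nll_loss_forward(logp, target, None, Reduction.MEAN.value, -100)
-     torch.testing.assert_close(loss, torch.nn.functional.nll_loss(logp, target))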
- # These are adapted from aten/src/ATen/native/UpSample.h, which is based on
- # https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
- def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
- return ((A + 2) * x - (A + 3)) * x * x + 1
- def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
- return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
- def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
- A = -0.75
- return (
- _upsample_cubic_convolution2(t + 1.0, A),
- _upsample_cubic_convolution1(t, A),
- _upsample_cubic_convolution1(1.0 - t, A),
- _upsample_cubic_convolution2(2.0 - t, A),
- )
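- # Illustrative property check (not part of the original file): for t in [0, 1] the four
- # cubic convolution coefficients form a partition of unity, so interpolating a constant
- # signal returns that constant.
- def _example_cubic_coefficients_sum_to_one():  # pragma: no cover - illustrative only
-     t = torch.rand(10)
-     coeffs = _upsample_get_cubic_coefficients(t)
-     torch.testing.assert_close(sum(coeffs), torch.ones_like(t))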
- def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
- coeffs2 = _upsample_get_cubic_coefficients(ts)
- return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))
- # Need this instead of just sum() to keep mypy happy
- def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
- return reduce(torch.add, ts)
- @register_decomposition(aten.grid_sampler_2d)
- @pw_cast_for_opmath
- def grid_sampler_2d(
- a: Tensor,
- grid: Tensor,
- interpolation_mode: int = 0,
- padding_mode: int = 0,
- align_corners: bool = False,
- ) -> Tensor:
- utils.check(
- interpolation_mode in (0, 1, 2),
- lambda: f"Invalid interpolation mode {interpolation_mode}",
- )
- utils.check(
- padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
- )
- def unnormalize(coords: Tensor, size: int) -> Tensor:
- # Rescale coordinates from [-1, 1] to:
- # [0, size - 1] if align_corners is True
- # [-.5, size -.5] if align_corners is False
- mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
- ofs = size * 0.5 - 0.5
- return coords * mul + ofs
- # Reflects coordinates until they fall between low and high (inclusive).
- # The bounds are passed as twice their value so that half-integer values
- # can be represented as ints.
- def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
- if twice_low == twice_high:
- return torch.zeros_like(coords)
- coords_min = twice_low / 2
- coords_span = (twice_high - twice_low) / 2
- coords2 = (coords - coords_min).abs()
- extra = torch.fmod(coords2, coords_span)
- flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
- return torch.where(
- flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
- )
- def compute_coordinates(coords: Tensor, size: int) -> Tensor:
- if padding_mode == 0: # Zero
- return coords
- elif padding_mode == 1: # Borders
- return torch.clamp(coords, 0, size - 1)
- else: # padding_mode == 2, Reflection
- if align_corners:
- coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
- else:
- coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
- return torch.clamp(coords_reflected, 0, size - 1)
- def compute_source_index(coords: Tensor, size: int) -> Tensor:
- coords_un = unnormalize(coords, size)
- return compute_coordinates(coords_un, size)
- N, C, iH, iW = a.shape
- _, oH, oW, _ = grid.shape
- def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
- return torch.logical_and(
- 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
- )
- N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
- C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
- def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
- cond = in_bounds_cond(xs, ys)
- # To clip to inside valid coordinates, we map the coordinates
- # to (x, y) = (0, 0) and also set the weight to 0
- # We also change the shape of the tensor to the appropriate one for
- # broadcasting with N_idx, C_idx for the purposes of advanced indexing
- return tuple(
- torch.where(cond, t, 0).view(N, 1, oH, oW)
- for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
- )
- def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
- # Perform clipping, index into input tensor and multiply by weight
- idx_x, idx_y, w_ = clip(ix, iy, w)
- return a[N_idx, C_idx, idx_y, idx_x] * w_
- x = grid[..., 0]
- y = grid[..., 1]
- if interpolation_mode == 0: # Bilinear
- ix = compute_source_index(x, iW)
- iy = compute_source_index(y, iH)
- ix_nw, iy_nw = ix.floor(), iy.floor()
- ix_ne, iy_ne = ix_nw + 1, iy_nw
- ix_sw, iy_sw = ix_nw, iy_nw + 1
- ix_se, iy_se = ix_ne, iy_sw
- w_nw = (ix_se - ix) * (iy_se - iy)
- w_ne = (ix - ix_sw) * (iy_sw - iy)
- w_sw = (ix_ne - ix) * (iy - iy_ne)
- w_se = (ix - ix_nw) * (iy - iy_nw)
- return _sum_tensors(
- get_summand(ix, iy, w)
- for (ix, iy, w) in (
- (ix_nw, iy_nw, w_nw),
- (ix_ne, iy_ne, w_ne),
- (ix_sw, iy_sw, w_sw),
- (ix_se, iy_se, w_se),
- )
- )
- elif interpolation_mode == 1: # Nearest
- ix = compute_source_index(x, iW)
- iy = compute_source_index(y, iH)
- ix_nearest = ix.round()
- iy_nearest = iy.round()
- return get_summand(ix_nearest, iy_nearest, 1)
- else: # interpolation_mode == 2, Bicubic
- ix = unnormalize(x, iW)
- iy = unnormalize(y, iH)
- ix_nw = ix.floor()
- iy_nw = iy.floor()
- tx = ix - ix_nw
- ty = iy - iy_nw
- def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
- x = compute_coordinates(ix, iW)
- y = compute_coordinates(iy, iH)
- return get_summand(x, y, 1)
- def get_coeff(ofs: int) -> Tensor:
- iy_ofs = iy_nw + (ofs - 1)
- cs = (
- get_value_bounded(ix_nw - 1, iy_ofs),
- get_value_bounded(ix_nw, iy_ofs),
- get_value_bounded(ix_nw + 1, iy_ofs),
- get_value_bounded(ix_nw + 2, iy_ofs),
- )
- return _upsample_cubic_interp1d(cs, tx.unsqueeze(1))
- coeffs = tuple((get_coeff(ofs) for ofs in range(4)))
- return _upsample_cubic_interp1d(coeffs, ty.unsqueeze(1))
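- # Illustrative sanity check (not part of the original file): the decomposition should agree
- # with eager grid_sample for bilinear interpolation with zero padding.
- def _example_check_grid_sampler_2d():  # pragma: no cover - illustrative only
-     a = torch.randn(1, 2, 5, 5)
-     grid = torch.rand(1, 4, 4, 2) * 2 - 1  # normalized coordinates in [-1, 1]
-     torch.testing.assert_close(
-         grid_sampler_2d(a, grid, interpolation_mode=0, padding_mode=0, align_corners=False),
-         torch.nn.functional.grid_sample(
-             a, grid, mode="bilinear", padding_mode="zeros", align_corners=False
-         ),
-     )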
- @register_decomposition(aten.mv)
- @out_wrapper()
- @pw_cast_for_opmath
- def mv(self, vec):
- utils.check(
- self.dim() == 2 and vec.dim() == 1,
- lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
- )
- utils.check(
- self.size(1) == vec.size(0),
- lambda: f"size mismatch, got {self.size(0)}x{self.size(1)},{vec.size(0)}",
- )
- return (self * vec).sum(dim=1)
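- # Illustrative sanity check (not part of the original file): the broadcasted multiply plus
- # row-wise sum above is just a matrix-vector product.
- def _example_check_mv():  # pragma: no cover - illustrative only
-     m, v = torch.randn(3, 4), torch.randn(4)
-     torch.testing.assert_close(mv(m, v), m @ v)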
- @register_decomposition(aten.dot)
- @out_wrapper()
- @pw_cast_for_opmath
- def dot(self, other):
- if self.is_complex():
- if self.is_conj():
- if other.is_conj():
- return torch.dot(self.conj(), other.conj()).conj()
- else:
- return torch.vdot(self.conj(), other)
- elif other.is_conj():
- return torch.vdot(other.conj(), self)
- utils.check(
- self.dim() == 1 and other.dim() == 1,
- lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
- )
- utils.check(
- self.dtype == other.dtype,
- lambda: f"dot : expected both vectors to have same dtype, but found {self.dtype} and {other.dtype}",
- )
- def numel_error():
- return (
- f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
- f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
- )
- utils.check(self.numel() == other.numel(), numel_error)
- return (self * other).sum()
- @register_decomposition(aten.binary_cross_entropy_with_logits)
- def binary_cross_entropy_with_logits(
- self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
- ):
- max_val = (-self).clamp_min(0)
- if pos_weight is not None:
- log_weight = (pos_weight - 1) * target + 1
- loss = (1 - target) * self + log_weight * (
- ((-max_val).exp() + (-self - max_val).exp()).log() + max_val
- )
- else:
- loss = (
- (1 - target) * self
- + max_val
- + ((-max_val).exp() + (-self - max_val).exp()).log()
- )
- if weight is not None:
- loss = loss * weight
- return apply_loss_reduction(loss, reduction)
- def should_fold(tensor1: torch.Tensor, dim_tensor2: int) -> bool:
- dim_tensor1 = tensor1.ndim
- if dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
- t1_sizes_ptr = tensor1.shape
- t1_strides = tensor1.stride()
- if (
- dim_tensor1 == 3
- and dim_tensor2 == 2
- and t1_strides[-1] != 1
- and t1_strides[0] == t1_sizes_ptr[1] * t1_sizes_ptr[2]
- ):
- # First dim is slowest moving, and then the following two dims are
- # transposed. This can happen for example by permute(0, 2, 1).
- # First 2 dims could be folded to use mm but would require permutation
- # with actual data movement, which can be instead handled by BMM with each
- # GEMM transposed.
- # This can be generalized to a tensor with dim X + Y + Z where X, Y, and Z
- # dims are contiguous, Y dims and Z dims are transposed, and X, Y, Z > 0.
- # For example, this can happen by permute(0, 1, 5, 2, 3, 4), where X = 2,
- # Y = 3, and Z = 1.
- return False
- else:
- return True
- else:
- return False
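- # Illustrative worked example (assumption for exposition, not part of the original file):
- # for a contiguous (2, 3, 4) tensor, permute(0, 2, 1) gives shape (2, 4, 3) with strides
- # (12, 1, 4); the last stride is not 1 and the leading stride equals 4 * 3, so should_fold
- # hits the transposed-BMM special case and returns False.
- def _example_should_fold_permuted():  # pragma: no cover - illustrative only
-     t = torch.randn(2, 3, 4).permute(0, 2, 1)
-     assert t.stride() == (12, 1, 4)
-     assert not should_fold(t, 2)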
- @aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
- def matmul(tensor1, tensor2):
- dim_tensor1 = tensor1.dim()
- dim_tensor2 = tensor2.dim()
- assert dim_tensor1 != 0 and dim_tensor2 != 0
- if dim_tensor1 == 1 and dim_tensor2 == 1:
- return torch.dot(tensor1, tensor2)
- elif dim_tensor1 == 2 and dim_tensor2 == 1:
- return torch.mv(tensor1, tensor2)
- elif dim_tensor1 == 1 and dim_tensor2 == 2:
- return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
- elif dim_tensor1 == 2 and dim_tensor2 == 2:
- return torch.mm(tensor1, tensor2)
- elif should_fold(tensor1, dim_tensor2) or should_fold(tensor2, dim_tensor1):
- # NB: Much of this was written with Copilot! (although still had to fix a bunch of issues)
- # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
- # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
- # and some condition on the strides is fulfilled
- # optimization: use mm instead of bmm by folding the batch of the larger tensor
- # into its leading matrix dimension
- transpose = dim_tensor2 > dim_tensor1
- t1 = tensor2.mT if transpose else tensor1
- t2 = (
- tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
- )
- # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
- # and t1 and t2 are matmul-compatible
- # Why not t1.view(-1, sizes_1[-1])?
- # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
- # This can happen in e.g. [3, 5, 0] @ [0, 0].
- sizes_1 = t1.shape
- output_shape = list(sizes_1[:-1])
- folded_dim1 = reduce(operator.mul, output_shape)
- # Readjust output_shape if we are multiplying by a matrix
- t2_is_matrix = t2.dim() == 2
- if t2_is_matrix:
- output_shape.append(t2.shape[1])
- t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
- if t2_is_matrix:
- # FIXME This path always does an unnecessary copy when transpose == True as the returned
- # result from BLAS is already C-transposed
- output = t1_folded.mm(t2).view(output_shape)
- return output.mT.contiguous() if transpose else output
- else:
- return t1_folded.mv(t2).view(output_shape)
- elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
- # We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists of batch dims);
- # we track m1 vs m2 separately even though they must match for nicer error messages
- n = tensor1.size(-2) if dim_tensor1 > 1 else 1
- m1 = tensor1.size(-1)
- batch_tensor1 = tensor1.shape[:-2]
- m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
- p = tensor2.size(-1) if dim_tensor2 > 1 else 1
- batch_tensor2: List[int] = []
- # TODO: handling of slice
- for i in range(dim_tensor2 - 2):
- batch_tensor2.append(tensor2.size(i))
- # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
- expand_batch_portion = list(
- torch.broadcast_shapes(batch_tensor1, batch_tensor2)
- )
- tensor1_expand_size = expand_batch_portion + [n, m1]
- tensor2_expand_size = expand_batch_portion + [m2, p]
- expand_batch_product = prod(expand_batch_portion)
- # HACK: We need reshape with symint support
- tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
- expand_batch_product, n, m1
- )
- tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
- expand_batch_product, m2, p
- )
- output_shape = expand_batch_portion
- if dim_tensor1 > 1:
- output_shape.append(n)
- if dim_tensor2 > 1:
- output_shape.append(p)
- return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
- else:
- utils.check(False, lambda: "both arguments to matmul need to be at least 1D")
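- # Illustrative sanity check (not part of the original file): the decomposition should agree
- # with torch.matmul on both the folded-mm path and the broadcasted-bmm path.
- def _example_check_matmul_decomp():  # pragma: no cover - illustrative only
-     a, b = torch.randn(5, 3, 4), torch.randn(4, 2)
-     torch.testing.assert_close(matmul(a, b), a @ b)  # folded into a single mm
-     c, d = torch.randn(2, 1, 3, 4), torch.randn(5, 4, 6)
-     torch.testing.assert_close(matmul(c, d), c @ d)  # broadcast batch dims, then bmm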
- @register_decomposition(aten.upsample_bicubic2d.default)
- @pw_cast_for_opmath
- def upsample_bicubic2d_default(
- a: Tensor,
- output_size: Tuple[int, int],
- align_corners: bool,
- scale_h: Optional[float] = None,
- scale_w: Optional[float] = None,
- ) -> Tensor:
- N, C, iH, iW = a.shape
- oH, oW = output_size
- def compute_scale(in_size, out_size, align_corners, scale=None):
- if align_corners:
- return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
- else:
- return 1 / scale if scale is not None and scale > 0 else in_size / out_size
- def compute_source_index(scale, dst_index, align_corners):
- if align_corners:
- return scale * dst_index
- else:
- return scale * (dst_index + 0.5) - 0.5
- height_scale = compute_scale(iH, oH, align_corners, scale_h)
- width_scale = compute_scale(iW, oW, align_corners, scale_w)
- N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
- C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
- out_y = torch.arange(oH, device=a.device).view((1, 1, oH, 1))
- out_x = torch.arange(oW, device=a.device).view((1, 1, 1, oW))
- real_x = compute_source_index(width_scale, out_x, align_corners)
- in_x = real_x.floor()
- t_x = real_x - in_x
- ix = in_x.to(dtype=torch.int64)
- real_y = compute_source_index(height_scale, out_y, align_corners)
- in_y = real_y.floor()
- t_y = real_y - in_y
- iy = in_y.to(dtype=torch.int64)
- iys_ofs = (iy - 1, iy, iy + 1, iy + 2)
- ixs_ofs = (ix - 1, ix, ix + 1, ix + 2)
- def load_bounded(ys, xs):
- y_idx = torch.clamp(ys, 0, iH - 1)
- x_idx = torch.clamp(xs, 0, iW - 1)
- return a[N_idx, C_idx, y_idx, x_idx]
- def get_x_interp(y):
- coeffs_x = tuple((load_bounded(y, x_ofs) for x_ofs in ixs_ofs))
- return _upsample_cubic_interp1d(coeffs_x, t_x)
- coeffs_y = tuple((get_x_interp(y_ofs) for y_ofs in iys_ofs))
- result = _upsample_cubic_interp1d(coeffs_y, t_y)
- # convert output to correct memory format, if necessary
- memory_format = utils.suggest_memory_format(a)
- result = result.contiguous(memory_format=memory_format)
- return result
- @register_decomposition(aten.upsample_bicubic2d.vec)
- @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
- @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
- @out_wrapper()
- @pw_cast_for_opmath
- def upsample_bicubic2d_vec(
- a: Tensor,
- output_size: Optional[Tuple[int, int]],
- align_corners: bool,
- scale_factors: Optional[Tuple[float, float]] = None,
- ) -> Tensor:
- utils.check(
- bool(output_size) + bool(scale_factors) == 1,
- lambda: "Must specify exactly one of output_size and scale_factors.",
- )
- if output_size is None:
- assert scale_factors is not None
- output_size = cast(
- Tuple[int, int],
- tuple(
- sym_int(sym_float(w) * scale)
- for w, scale in zip(a.shape[2:], scale_factors)
- ),
- )
- scale_h, scale_w = scale_factors if scale_factors else (None, None)
- return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
- def register_inplace(aten_op, outplace_op):
- @register_decomposition(aten_op)
- def inplace_op(*args, **kwargs):
- out = outplace_op(*args, **kwargs)
- return args[0].copy_(out)
- return inplace_op
- register_inplace(aten.addbmm_, aten.addbmm)
- register_inplace(aten.addmm_, aten.addmm)
- register_inplace(aten.addmv_, aten.addmv)
- register_inplace(aten.baddbmm_, aten.baddbmm)
- register_inplace(aten.cumprod_, aten.cumprod)
- register_inplace(aten.fill_, aten.fill)
- register_inplace(aten.gelu_, aten.gelu)
- register_inplace(aten.hardswish_, aten.hardswish)
- register_inplace(aten.hardtanh_, aten.hardtanh)
- register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
- register_inplace(aten.index_put_, aten.index_put)
- register_inplace(aten.index_reduce_, aten.index_reduce)
- register_inplace(aten.leaky_relu_, aten.leaky_relu)
- register_inplace(aten.logit_, aten.logit)
- register_inplace(aten.relu_, aten.relu)
- register_inplace(aten.renorm_, aten.renorm)
- register_inplace(aten.round_, aten.round)
- register_inplace(aten.scatter_, aten.scatter)
- register_inplace(aten.scatter_add_, aten.scatter_add)
- register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
- register_inplace(aten.silu_, aten.silu)
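- # Illustrative note (not part of the original file): every in-place op registered above is
- # decomposed as "compute the out-of-place result, then copy_ it back into the first
- # argument", e.g. relu_(x) behaves like x.copy_(torch.relu(x)).
- def _example_register_inplace_pattern():  # pragma: no cover - illustrative only
-     x = torch.tensor([-1.0, 2.0])
-     y = x.clone()
-     y.copy_(torch.relu(y))  # what the generated relu_ decomposition does
-     torch.testing.assert_close(y, torch.relu(x))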
|