- import math
- from typing import List, Optional, Union
- import torch
- import torch._prims_common as utils
- from torch import Tensor
- from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
- from torch._ops import OpOverload
- from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
- from torch._prims_common import (
- check,
- corresponding_complex_dtype,
- corresponding_real_dtype,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- IntLike,
- make_contiguous_strides_for,
- )
- from torch._prims_common.wrappers import out_wrapper
- from torch._refs import _broadcast_shapes
- from torch._subclasses.fake_tensor import check_no_bool_index_tensors
- from torch.utils._pytree import tree_map
- aten = torch.ops.aten
- _meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
- def register_meta(op):
- def wrapper(fn):
- def register(op):
- _add_op_to_registry(meta_table, op, fn)
- tree_map(register, op)
- return fn
- return wrapper
- def toRealValueType(dtype):
- from_complex = {
- torch.complex32: torch.half,
- torch.cfloat: torch.float,
- torch.cdouble: torch.double,
- }
- return from_complex.get(dtype, dtype)
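- # For example, toRealValueType(torch.cfloat) is torch.float, and a dtype
- # without a complex-to-real mapping (e.g. torch.int64) is returned unchanged.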
- @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
- @out_wrapper()
- def meta_fft_c2c(self, dim, normalization, forward):
- assert self.dtype.is_complex
- return self.new_empty(self.size())
- @register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
- @out_wrapper()
- def meta_fft_r2c(self, dim, normalization, onesided):
- assert self.dtype.is_floating_point
- output_sizes = list(self.size())
- if onesided:
- last_dim = dim[-1]
- last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
- output_sizes[last_dim] = last_dim_halfsize
- return self.new_empty(
- output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
- )
- @register_meta(aten.randperm.generator_out)
- def meta_randperm(n, *, generator=None, out):
- assert out.ndim == 1 and out.size(0) == n
- return out
- @register_meta(aten.randint.default)
- def meta_randint(
- high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
- ):
- return torch.empty(
- size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- @register_meta(aten.randint.low)
- def meta_randint_low(
- low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
- ):
- return torch.empty(
- size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- @register_meta(aten.rand.default)
- def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
- return torch.empty(
- size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
- )
- @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
- @out_wrapper()
- def meta_fft_c2r(self, dim, normalization, lastdim):
- assert self.dtype.is_complex
- output_sizes = list(self.size())
- output_sizes[dim[-1]] = lastdim
- return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
- @register_meta(aten.copy_.default)
- def meta_copy_(self, src, non_blocking=False):
- return self
- def inferUnsqueezeGeometry(tensor, dim):
- result_sizes = list(tensor.size())
- result_strides = list(tensor.stride())
- new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
- result_sizes.insert(dim, 1)
- result_strides.insert(dim, new_stride)
- return result_sizes, result_strides
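- # Worked example: a contiguous tensor of size (2, 3) has strides (3, 1);
- # unsqueezing at dim=1 yields sizes (2, 1, 3) and strides (3, 3, 1), where the
- # inserted stride is result_sizes[1] * result_strides[1] = 3 * 1 = 3.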
- @register_meta(aten.unsqueeze_.default)
- def meta_unsqueeze_(self, dim):
- dim = maybe_wrap_dim(dim, self.dim() + 1)
- g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
- self.as_strided_(g_sizes, g_strides)
- return self
- # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
- @register_meta(aten.index_select.default)
- def meta_index_select(self, dim, index):
- result_size = list(self.size())
- if self.dim() > 0:
- result_size[dim] = index.numel()
- return self.new_empty(result_size)
- @register_meta(aten.index_select.out)
- def meta_index_select_out(self, dim, index, out):
- torch._resize_output_(out, self.size(), self.device)
- return out.copy_(torch.index_select(self, dim, index))
- @register_meta([aten.max.default, aten.max.unary_out])
- @out_wrapper()
- def meta_max(self):
- return self.new_empty(())
- @register_meta(aten.max.dim)
- def meta_max_dim(self, dim, keepdim=False):
- dim = utils.reduction_dims(self.shape, (dim,))
- output_shape = _compute_reduction_shape(self, dim, keepdim)
- return (
- self.new_empty(output_shape),
- self.new_empty(output_shape, dtype=torch.long),
- )
- @register_meta([aten.min.default])
- def meta_min(self):
- return self.new_empty(())
- @register_meta(aten.angle.default)
- def meta_angle(self):
- if self.is_complex():
- result_dtype = corresponding_real_dtype(self.dtype)
- else:
- _, result_dtype = elementwise_dtypes(
- self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
- )
- return torch.empty_like(self, dtype=result_dtype)
- @register_meta(aten.angle.out)
- def meta_angle_out(self, out):
- torch._resize_output_(out, self.size(), self.device)
- return out.copy_(torch.angle(self))
- # From aten/src/ATen/native/LinearAlgebraUtils.h
- def squareCheckInputs(self: Tensor, f_name: str):
- assert (
- self.dim() >= 2
- ), f"{f_name}: The input tensor must have at least 2 dimensions."
- assert self.size(-1) == self.size(
- -2
- ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
- # From aten/src/ATen/native/LinearAlgebraUtils.h
- def checkFloatingOrComplex(
- t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
- ):
- dtype = t.dtype
- check(
- t.is_floating_point() or t.is_complex(),
- lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
- )
- if not allow_low_precision_dtypes:
- check(
- dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
- lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
- )
- # From aten/src/ATen/native/LinearAlgebraUtils.h
- def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
- check(
- A.dim() >= 2,
- lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
- )
- def checkUplo(uplo: str):
- uplo_uppercase = uplo.upper()
- assert (
- len(uplo) == 1 and (uplo_uppercase == "U" or uplo_uppercase == "L")
- ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
- # @register_meta(aten.linalg_eigh.default)
- def meta_linalg_eigh(self, uplo="L"):
- squareCheckInputs(self, "linalg_eigh")
- checkUplo(uplo)
- real_dtype = toRealValueType(self.dtype)
- assert self.dim() >= 2
- # Eigenvalues are real and drop the last matrix dimension; eigenvectors
- # keep the input's shape and are laid out column-major.
- values = self.new_empty(self.shape[:-1], dtype=real_dtype)
- vectors = self.new_empty(self.shape)
- vectors.transpose_(-2, -1)
- return (values, vectors)
- # From aten/src/ATen/native/BatchLinearAlgebra.cpp
- @register_meta(aten.linalg_cholesky_ex.default)
- def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
- squareCheckInputs(A, "linalg.cholesky")
- checkFloatingOrComplex(A, "linalg.cholesky")
- A_shape = A.shape
- ndim = len(A_shape)
- # L
- L_strides = make_contiguous_strides_for(A_shape, False)
- L = A.new_empty(A_shape)
- L.as_strided_(A_shape, L_strides)
- # infos
- infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
- return L, infos
- # From aten/src/ATen/native/BatchLinearAlgebra.cpp
- @register_meta(aten.linalg_inv_ex.default)
- def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
- squareCheckInputs(A, "linalg.inv_ex")
- checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
- L = A.new_empty(A.shape)
- L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
- infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
- return L, infos
- # From aten/src/ATen/native/BatchLinearAlgebra.cpp
- # NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
- @register_meta(aten._linalg_svd.default)
- def _linalg_svd_meta(
- A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: Optional[str] = None
- ):
- checkIsMatrix(A, "linalg.svd")
- checkFloatingOrComplex(A, "linalg.svd")
- batch_dims = list(A.shape[:-2])
- m = A.shape[-2]
- n = A.shape[-1]
- k = min(m, n)
- if compute_uv:
- U_shape = batch_dims + [m, m if full_matrices else k]
- U = A.new_empty(U_shape)
- U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
- V_shape = batch_dims + [n if full_matrices else k, n]
- V = A.new_empty(V_shape)
- # TODO: need to distinguish cuSOLVER case? (see original code)
- V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=False))
- else:
- # doesn't matter
- U = A.new_empty([0])
- V = A.new_empty([0])
- # S is always real, even when A is complex.
- S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
- return U, S, V
- # From aten/src/ATen/native/LinearAlgebra.cpp
- @register_meta(aten._linalg_det.default)
- def _linalg_det_meta(A):
- squareCheckInputs(A, "linalg.det")
- checkFloatingOrComplex(A, "linalg.det")
- det = A.new_empty(A.shape[:-2])
- LU = A.new_empty(A.shape)
- LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
- pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
- return det, LU, pivots
- # From aten/src/ATen/native/ReflectionPad.cpp
- @register_meta(
- [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
- )
- def meta_pad2d_backward(grad_output, self, padding):
- dim_w = 2
- dim_h = 1
- dim_plane = 0
- nbatch = 1
- self_shape = self.shape
- if self.dim() == 4:
- nbatch = self_shape[0]
- dim_w += 1
- dim_h += 1
- dim_plane += 1
- pad_l = padding[0]
- pad_r = padding[1]
- pad_t = padding[2]
- pad_b = padding[3]
- nplane = self_shape[dim_plane]
- input_h = self_shape[dim_h]
- input_w = self_shape[dim_w]
- output_h = input_h + pad_t + pad_b
- output_w = input_w + pad_l + pad_r
- check(
- output_w == grad_output.shape[dim_w],
- lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
- )
- check(
- output_h == grad_output.shape[dim_h],
- lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
- )
- return self.new_empty(self.shape)
- @register_meta(aten.reflection_pad2d.default)
- def meta_pad2d(self, padding):
- valid_dims = self.size(1) != 0 and self.size(2) != 0
- check(
- (self.ndim == 3 and valid_dims)
- or (self.ndim == 4 and valid_dims and self.size(3) != 0),
- lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
- )
- if self.ndim == 4:
- nbatch, nplane, input_h, input_w = self.shape
- else:
- nbatch = 1
- nplane, input_h, input_w = self.shape
- pad_l, pad_r, pad_t, pad_b = padding
- output_h = input_h + pad_t + pad_b
- output_w = input_w + pad_l + pad_r
- if self.ndim == 3:
- return self.new_empty((nplane, output_h, output_w))
- else:
- return self.new_empty((nbatch, nplane, output_h, output_w))
- @register_meta([aten.bernoulli.default, aten.bernoulli.out])
- @out_wrapper()
- def meta_bernoulli(self, *, generator=None):
- # https://github.com/pytorch/pytorch/issues/88612
- return torch.empty_like(self).contiguous()
- @register_meta(aten.bernoulli_.float)
- def meta_bernoulli_(self, p=0.5, generator=None):
- return self
- @register_meta(aten.bernoulli.p)
- def meta_bernoulli_p(self, p=0.5, generator=None):
- # https://github.com/pytorch/pytorch/issues/88612
- return torch.empty_like(self).contiguous()
- @register_meta(aten._fused_moving_avg_obs_fq_helper.default)
- def meta__fused_moving_avg_obs_fq_helper(
- self,
- observer_on,
- fake_quant_on,
- running_min,
- running_max,
- scale,
- zero_point,
- averaging_const,
- quant_min,
- quant_max,
- ch_axis,
- per_row_fake_quant=False,
- symmetric_quant=False,
- ):
- check(
- ch_axis < self.dim(),
- lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
- )
- mask = torch.empty_like(self, dtype=torch.bool)
- return (torch.empty_like(self), mask)
- def dot_check(self, other):
- check(
- self.dim() == 1 and other.dim() == 1,
- lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
- )
- @register_meta(aten.dot.default)
- def meta_dot(self, tensor):
- dot_check(self, tensor)
- return self.new_empty(())
- @register_meta([aten.mm.default])
- def meta_mm(a, b):
- check(a.dim() == 2, lambda: "a must be 2D")
- check(b.dim() == 2, lambda: "b must be 2D")
- N, M1 = a.shape
- M2, P = b.shape
- check(M1 == M2, lambda: "a and b must have same reduction dim")
- return a.new_empty(N, P)
- def _compute_reduction_shape(self, dims, keepdim):
- if keepdim:
- return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
- return utils.compute_reduction_output_shape(self.shape, dims)
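- # For example, with self.shape == (2, 3, 4) and dims == (1,), this gives
- # (2, 1, 4) when keepdim is True and (2, 4) otherwise.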
- # FakeTensors (meta tensors with a device) will report their device as meta
- # when running meta kernels. Here, access the "fake device" of a FakeTensor if
- # it exists, so meta kernels whose behavior diverges per device are more
- # accurate when run with FakeTensors.
- def device_hint(tensor) -> str:
- if isinstance(tensor, torch._subclasses.FakeTensor):
- return tensor.fake_device.type
- else:
- return "cuda" # default to cuda
- def calc_conv_nd_return_shape(
- input_tensor: torch.Tensor,
- weight: torch.Tensor,
- stride: Union[List[int], int],
- padding: Union[List[int], int],
- dilation: Union[List[int], int],
- is_transposed: bool,
- groups: int,
- output_padding: Optional[Union[List[int], int]] = None,
- ):
- def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
- """
- Formula to apply to calculate the length of some dimension of the output
- See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- Args:
- ln: length of the dimension
- p: padding in that dim
- d: dilation in that dim
- k: kernel size in that dim
- s: stride in that dim
- Returns:
- The output length
- """
- return (ln + 2 * p - d * (k - 1) - 1) // s + 1
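- # Worked example: ln=5, p=1, d=1, k=3, s=2 gives (5 + 2 - 2 - 1) // 2 + 1 = 3,
- # i.e. a 3-wide, stride-2, padding-1 convolution maps length 5 to length 3.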
- def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
- """
- Formula to apply to calculate the length of some dimension of the output
- if transposed convolution is used.
- See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
- Args:
- ln: length of the dimension
- p: padding in that dim
- d: dilation in that dim
- k: kernel size in that dim
- s: stride in that dim
- op: output padding in that dim
- Returns:
- The output length
- """
- return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
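- # Worked example: ln=3, p=1, d=1, k=3, s=2, op=0 gives
- # (3 - 1) * 2 - 2 + 2 + 0 + 1 = 5, inverting the forward example above.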
- kernel_size = weight.shape[2:]
- dims = input_tensor.shape[2:]
- if is_transposed:
- out_channels = groups * weight.shape[1]
- else:
- out_channels = weight.shape[0]
- if weight.shape[1] * groups != input_tensor.shape[1]:
- raise RuntimeError("Invalid channel dimensions")
- ret_shape = [input_tensor.shape[0], out_channels]
- if isinstance(stride, IntLike):
- stride = [stride] * len(dims)
- elif len(stride) == 1:
- stride = [stride[0]] * len(dims)
- if isinstance(padding, IntLike):
- padding = [padding] * len(dims)
- elif len(padding) == 1:
- padding = [padding[0]] * len(dims)
- if isinstance(dilation, IntLike):
- dilation = [dilation] * len(dims)
- elif len(dilation) == 1:
- dilation = [dilation[0]] * len(dims)
- output_padding_list: Optional[List[int]] = None
- if output_padding:
- if isinstance(output_padding, IntLike):
- output_padding_list = [output_padding] * len(dims)
- elif len(output_padding) == 1:
- output_padding_list = [output_padding[0]] * len(dims)
- else:
- output_padding_list = output_padding
- for i in range(len(dims)):
- # If output_padding is present, we are dealing with a transposed convolution
- if output_padding_list:
- ret_shape.append(
- _formula_transposed(
- dims[i],
- padding[i],
- dilation[i],
- kernel_size[i],
- stride[i],
- output_padding_list[i],
- )
- )
- else:
- ret_shape.append(
- _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
- )
- return ret_shape
- def is_channels_last(ten):
- return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
- @register_meta(aten.convolution.default)
- def meta_conv(
- input_tensor: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- is_transposed: bool,
- output_padding: List[int],
- groups: int,
- ):
- def pick_memory_format():
- if device_hint(input_tensor) == "cuda":
- if is_channels_last(input_tensor) or is_channels_last(weight):
- return torch.channels_last
- else:
- if is_channels_last(input_tensor):
- return torch.channels_last
- if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
- return torch.contiguous_format
- elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
- return torch.preserve_format
- shape_out = calc_conv_nd_return_shape(
- input_tensor,
- weight,
- stride,
- padding,
- dilation,
- is_transposed,
- groups,
- output_padding if is_transposed else None,
- )
- out = input_tensor.new_empty(shape_out)
- out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload]
- return out
- if torch._C.has_mkldnn:
- _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
- "mkldnn", "IMPL", "Meta"
- )
- def pick_mkldnn_conv_memory_format(input_tensor, weight):
- if weight.is_mkldnn:
- return torch.channels_last
- if is_channels_last(input_tensor) or is_channels_last(weight):
- return torch.channels_last
- if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
- return torch.contiguous_format
- elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
- return torch.preserve_format
- @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
- def meta_mkldnn_convolution_default(
- input_tensor,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- attr,
- scalars,
- algorithm,
- ):
- shape_out = calc_conv_nd_return_shape(
- input_tensor, weight, stride, padding, dilation, False, groups, []
- )
- out = input_tensor.new_empty(shape_out)
- out_memory_format = torch.channels_last
- out = out.to(memory_format=out_memory_format) # type: ignore[call-overload]
- return out
- @register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
- def meta_mkldnn_convolution_binary(
- input_tensor,
- other,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- ):
- out = input_tensor.new_empty(other.size())
- out = out.to(memory_format=torch.channels_last) # type: ignore[call-overload]
- return out
- @register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
- def meta_mkldnn_convolution_binary_inplace(
- input_tensor,
- other,
- weight,
- bias,
- padding,
- stride,
- dilation,
- groups,
- binary_attr,
- alpha,
- unary_attr,
- unary_scalars,
- unary_algorithm,
- ):
- return other
- @register_meta(torch.ops.mkldnn._linear_pointwise.default)
- def meta_linear_pointwise_default(
- input_tensor, weight, bias, attr, scalars, algorithm
- ):
- return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
- @register_meta(torch.ops.mkldnn._linear_pointwise.binary)
- def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
- out = input_tensor.new_empty(other.size())
- return out
- if torch._C.has_mkl:
- _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
- "mkl", "IMPL", "Meta"
- )
- @register_meta(torch.ops.mkl._mkl_linear)
- def meta_mkl_linear(
- input_tensor,
- packed_weight,
- orig_weight,
- bias,
- batch_size,
- ):
- return input_tensor.new_empty(
- (*input_tensor.shape[:-1], orig_weight.shape[0])
- )
- # from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
- def check_dim_size(tensor, dim, dim_size, size):
- check(
- tensor.dim() == dim and tensor.shape[dim_size] == size,
- lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
- + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
- )
- @register_meta(aten.avg_pool2d.default)
- def meta_avg_pool2d(
- input,
- kernel_size,
- stride=(),
- padding=(0,),
- ceil_mode=False,
- count_include_pad=True,
- divisor_override=None,
- ):
- def unpack(name, val):
- check(
- len(val) in [1, 2],
- lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
- )
- H = val[0]
- W = H if len(val) == 1 else val[1]
- return H, W
- kH, kW = unpack("kernel_size", kernel_size)
- check(
- len(stride) in [0, 1, 2],
- lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
- )
- if len(stride) == 0:
- dH, dW = kH, kW
- elif len(stride) == 1:
- dH, dW = stride[0], stride[0]
- else:
- dH, dW = unpack("stride", stride)
- padH, padW = unpack("padding", padding)
- check(
- divisor_override is None or divisor_override != 0,
- lambda: "divisor must be not zero",
- )
- nbatch = input.size(-4) if input.dim() == 4 else 1
- nInputPlane = input.size(-3)
- inputHeight = input.size(-2)
- inputWidth = input.size(-1)
- outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
- outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
- memory_format = utils.suggest_memory_format(input)
- pool2d_shape_check(
- input,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- 1,
- 1,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- memory_format,
- )
- if input.dim() == 3:
- size = [nInputPlane, outputHeight, outputWidth]
- else:
- size = [nbatch, nInputPlane, outputHeight, outputWidth]
- return torch.empty(
- size, dtype=input.dtype, device=input.device, memory_format=memory_format
- )
- # from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
- def avg_pool2d_backward_shape_check(
- input,
- gradOutput,
- nbatch,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- mem_format,
- ):
- pool2d_shape_check(
- input,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- 1,
- 1,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- mem_format,
- )
- ndim = input.dim()
- nOutputPlane = nInputPlane
- check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
- check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
- check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
- # Don't override the C++ registration.
- @register_meta(aten.avg_pool2d_backward.default)
- def meta_avg_pool2d_backward(
- gradOutput_,
- input,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override,
- ):
- # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
- check(
- len(kernel_size) == 1 or len(kernel_size) == 2,
- lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
- )
- kH = kernel_size[0]
- kW = kH if len(kernel_size) == 1 else kernel_size[1]
- check(
- len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
- lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
- )
- dH = kH if len(stride) == 0 else stride[0]
- dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
- check(
- len(padding) == 1 or len(padding) == 2,
- lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
- )
- padH = padding[0]
- padW = padH if len(padding) == 1 else padding[1]
- check(
- divisor_override is None or divisor_override != 0,
- lambda: "divisor must be not zero",
- )
- input_size = input.shape
- nbatch = input_size[-4] if input.dim() == 4 else 1
- nInputPlane = input_size[-3]
- inputHeight = input_size[-2]
- inputWidth = input_size[-1]
- outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
- outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
- mem_format = utils.suggest_memory_format(input)
- avg_pool2d_backward_shape_check(
- input,
- gradOutput_,
- nbatch,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- mem_format,
- )
- return torch.empty(
- input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
- )
- @register_meta(aten._adaptive_avg_pool2d.default)
- def meta_adaptive_avg_pool2d(self, output_size):
- check(
- self.ndim == 3 or self.ndim == 4,
- lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
- )
- output_shape = self.shape[:-2] + tuple(output_size)
- memory_format = utils.suggest_memory_format(self)
- # need to set memory_format to preserve the memory format of the input
- # channel last input should have channel last output
- return torch.empty(
- output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
- )
- @register_meta(aten._adaptive_avg_pool3d.default)
- def meta_adaptive_avg_pool3d(self, output_size):
- check(
- self.ndim == 4 or self.ndim == 5,
- lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
- )
- return self.new_empty(self.shape[:-3] + tuple(output_size))
- @register_meta(aten._adaptive_avg_pool2d_backward.default)
- def meta__adaptive_avg_pool2d_backward(grad_out, self):
- ndim = grad_out.ndim
- for i in range(1, ndim):
- check(
- grad_out.size(i) > 0,
- lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
- size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
- )
- check(
- ndim == 3 or ndim == 4,
- lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
- )
- check(
- self.dtype == grad_out.dtype,
- lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
- )
- return self.new_empty(self.shape)
- @register_meta(aten.repeat_interleave.Tensor)
- def meta_repeat_interleave_Tensor(repeats, output_size=None):
- if output_size is None:
- raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
- return repeats.new_empty(output_size)
- @register_meta([aten.complex.default, aten.complex.out])
- @out_wrapper()
- def meta_complex(real, imag):
- assert real.dtype.is_floating_point
- assert imag.dtype.is_floating_point
- out_shape = _broadcast_shapes(real.shape, imag.shape)
- return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
- @register_meta(aten.vdot.default)
- def vdot(self, other):
- if not self.is_complex():
- return torch.dot(self, other)
- if self.is_conj():
- if other.is_conj():
- return torch.vdot(other.conj(), self.conj())
- else:
- return torch.dot(self.conj(), other)
- elif other.is_conj():
- return torch.dot(self, other.conj()).conj()
- dot_check(self, other)
- return self.new_empty(())
- # A python implementation of indexing shape inference is worth keeping
- # around in its own right, even though structured kernels also provide
- # shape inference for this op.
- @register_meta(aten.index.Tensor)
- def meta_index_Tensor(self, indices):
- check_no_bool_index_tensors(aten.index.Tensor, self, indices)
- check(indices, lambda: "at least one index must be provided")
- # aten::index is the internal advanced indexing implementation
- # checkIndexTensorTypes and expandTensors
- result: List[Optional[Tensor]] = []
- for i, index in enumerate(indices):
- if index is not None:
- check(
- index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
- lambda: "tensors used as indices must be long, int, byte or bool tensors",
- )
- if index.dtype in [torch.int8, torch.bool]:
- nonzero = index.nonzero()
- k = len(result)
- check(
- k + index.ndim <= self.ndim,
- lambda: f"too many indices for tensor of dimension {self.ndim}",
- IndexError,
- )
- for j in range(index.ndim):
- check(
- index.shape[j] == self.shape[k + j],
- lambda: f"The shape of the mask {index.shape} at index {i} "
- f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
- IndexError,
- )
- result.append(nonzero.select(1, j))
- else:
- result.append(index)
- else:
- result.append(index)
- indices = result
- check(
- len(indices) <= self.ndim,
- lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
- )
- # expand_outplace
- import torch._refs as refs # avoid import cycle in mypy
- indices = list(refs._maybe_broadcast(*indices))
- # add missing null tensors
- while len(indices) < self.ndim:
- indices.append(None)
- # hasContiguousSubspace
- # true if all non-null tensors are adjacent
- # See:
- # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
- # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
- state = 0
- has_contiguous_subspace = False
- for index in indices:
- if state == 0:
- if index is not None:
- state = 1
- elif state == 1:
- if index is None:
- state = 2
- else:
- if index is not None:
- break
- else:
- has_contiguous_subspace = True
- # transposeToFront
- # This is the logic that causes the newly inserted dimensions to show up
- # at the beginning of the tensor, if they're not contiguous
- if not has_contiguous_subspace:
- dims = []
- transposed_indices = []
- for i, index in enumerate(indices):
- if index is not None:
- dims.append(i)
- transposed_indices.append(index)
- for i, index in enumerate(indices):
- if index is None:
- dims.append(i)
- transposed_indices.append(index)
- self = self.permute(dims)
- indices = transposed_indices
- # AdvancedIndex::AdvancedIndex
- # Now we can assume the indices have contiguous subspace
- # This is simplified from AdvancedIndex which goes to more effort
- # to put the input and indices in a form so that TensorIterator can
- # take them. If we write a ref for this, probably that logic should
- # get implemented
- before_shape: List[int] = []
- after_shape: List[int] = []
- replacement_shape: List[int] = []
- for dim, index in enumerate(indices):
- if index is None:
- if replacement_shape:
- after_shape.append(self.shape[dim])
- else:
- before_shape.append(self.shape[dim])
- else:
- replacement_shape = list(index.shape)
- return self.new_empty(before_shape + replacement_shape + after_shape)
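- # Shape-inference example: indexing a (5, 5) tensor with
- # indices=[tensor([0, 1])] produces before_shape=[], replacement_shape=[2],
- # after_shape=[5], so the result has shape (2, 5).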
- @register_meta([aten.convolution_backward.default])
- def meta_convolution_backward(
- grad_output_,
- input_,
- weight_,
- bias_sizes_opt,
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- output_mask,
- ):
- # High level logic taken from slow_conv3d_backward_cpu which should
- # be representative of all convolution_backward impls
- backend_grad_input = None
- backend_grad_weight = None
- backend_grad_bias = None
- if output_mask[0]:
- backend_grad_input = grad_output_.new_empty(input_.size())
- if output_mask[1]:
- backend_grad_weight = grad_output_.new_empty(weight_.size())
- if output_mask[2]:
- backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
- return (backend_grad_input, backend_grad_weight, backend_grad_bias)
- @register_meta([aten.addbmm.default, aten.addbmm.out])
- @out_wrapper()
- def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
- dim1 = batch1.size(1)
- dim2 = batch2.size(2)
- self = self.expand((dim1, dim2))
- check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
- check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
- check(
- batch1.size(0) == batch2.size(0),
- lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
- )
- check(
- batch1.size(2) == batch2.size(1),
- lambda: (
- f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
- f"and {batch2.size(1)}x{batch2.size(2)})"
- ),
- )
- check(
- self.size(0) == dim1 and self.size(1) == dim2,
- lambda: "self tensor does not match matmul output shape",
- )
- return self.new_empty(self.size())
- @register_meta(aten._cdist_forward.default)
- def meta_cdist_forward(x1, x2, p, compute_mode):
- check(
- x1.dim() >= 2,
- lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
- )
- check(
- x2.dim() >= 2,
- lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
- )
- check(
- x1.size(-1) == x2.size(-1),
- lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
- )
- check(
- utils.is_float_dtype(x1.dtype),
- lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
- )
- check(
- utils.is_float_dtype(x2.dtype),
- lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
- )
- check(p >= 0, lambda: "cdist only supports non-negative p values")
- check(
- compute_mode in (None, 1, 2),
- lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
- )
- r1 = x1.size(-2)
- r2 = x2.size(-2)
- batch_tensor1 = x1.shape[:-2]
- batch_tensor2 = x2.shape[:-2]
- output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
- output_shape.extend([r1, r2])
- return x1.new_empty(output_shape)
- @register_meta(aten._embedding_bag.default)
- def meta_embedding_bag(
- weight,
- indices,
- offsets,
- scale_grad_by_freq=False,
- mode=0,
- sparse=False,
- per_sample_weights=None,
- include_last_offset=False,
- padding_idx=-1,
- ):
- check(
- indices.dtype in (torch.long, torch.int),
- lambda: f"expected indices to be long or int, got {indices.dtype}",
- )
- check(
- offsets.dtype in (torch.long, torch.int),
- lambda: f"expected offsets to be long or int, got {offsets.dtype}",
- )
- check(
- utils.is_float_dtype(weight.dtype),
- lambda: f"expected weight to be floating point type, got {weight.dtype}",
- )
- num_bags = offsets.size(0)
- if include_last_offset:
- check(
- num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
- )
- num_bags -= 1
- output = weight.new_empty(num_bags, weight.size(1))
- MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
- if per_sample_weights is not None:
- check(
- mode == MODE_SUM,
- lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
- )
- check(
- per_sample_weights.dtype == weight.dtype,
- lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
- )
- check(
- per_sample_weights.ndim == 1,
- lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
- )
- check(
- per_sample_weights.numel() == indices.numel(),
- lambda: (
- f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
- f"to be the same as indices.numel() ({indices.numel()})"
- ),
- )
- def is_fast_path_index_select_scale(src, scale, output, padding_idx):
- return (
- is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
- )
- def is_fast_path_index_select(src, output, padding_idx):
- return (
- (src.dtype == torch.float or src.dtype == torch.half)
- and src.stride(1) == 1
- and output.stride(1) == 1
- and padding_idx < 0
- )
- def is_fast_path(src, scale, output, padding_idx):
- if scale is not None:
- return is_fast_path_index_select_scale(src, scale, output, padding_idx)
- else:
- return is_fast_path_index_select(src, output, padding_idx)
- if device_hint(offsets) != "cpu":
- offset2bag = indices.new_empty(indices.size(0))
- bag_size = indices.new_empty(offsets.size())
- if mode == MODE_MAX:
- max_indices = indices.new_empty(num_bags, weight.size(1))
- else:
- max_indices = indices.new_empty(0)
- else:
- fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
- if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
- offset2bag = offsets.new_empty(indices.size(0))
- else:
- offset2bag = offsets.new_empty(0)
- bag_size = offsets.new_empty(num_bags)
- # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
- numBags = offsets.shape[0]
- if mode == MODE_MAX:
- if include_last_offset:
- check(
- numBags >= 1,
- lambda: "include_last_offset: numBags should be at least 1",
- )
- numBags -= 1
- max_indices = offsets.new_empty(numBags, weight.shape[1])
- else:
- max_indices = offsets.new_empty(bag_size.size())
- return output, offset2bag, bag_size, max_indices
- @register_meta(aten._embedding_bag_forward_only.default)
- def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
- output, offset2bag, bag_size, max_indices = meta_embedding_bag(
- weight, indices, offsets, *args
- )
- if device_hint(offsets) == "cpu":
- bag_size = offsets.new_empty(offsets.size())
- return output, offset2bag, bag_size, max_indices
- def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
- # if specified, dtype takes precedence
- if dtype:
- return dtype
- if input.dtype.is_floating_point or input.dtype.is_complex:
- return input.dtype
- elif promote_int_to_long:
- return torch.long
- return input.dtype
- @register_meta([aten.nansum.default, aten.nansum.out])
- @out_wrapper()
- def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
- output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
- dims = utils.reduction_dims(input.shape, dims)
- output_shape = _compute_reduction_shape(input, dims, keepdim)
- return input.new_empty(output_shape, dtype=output_dtype)
- @register_meta(aten.nanmedian.default)
- def meta_nanmedian(input):
- output_shape = utils.compute_reduction_output_shape(
- input.shape, tuple(range(input.dim()))
- )
- return input.new_empty(output_shape)
- @register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
- @out_wrapper("values", "indices")
- def meta_nanmedian_dim(input, dim=-1, keepdim=False):
- dim = utils.reduction_dims(input.shape, (dim,))
- output_shape = _compute_reduction_shape(input, dim, keepdim)
- return (
- input.new_empty(output_shape),
- input.new_empty(output_shape, dtype=torch.long),
- )
- @register_meta(aten.logical_not_.default)
- def meta_logical_not_(self):
- return self
- @register_meta(aten.repeat.default)
- def meta_repeat(self, repeats):
- check(
- len(repeats) >= self.dim(),
- lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
- )
- # Add new leading dimensions to the tensor if the
- # number of target dimensions is larger than the
- # number of source dimensions.
- num_new_dimensions = len(repeats) - self.dim()
- padded_size = (1,) * num_new_dimensions + tuple(self.shape)
- target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
- return self.new_empty(target_size)
- @register_meta(aten.zero_.default)
- def meta_zero_(self):
- return self
- @register_meta(
- [
- aten.mul_.Scalar,
- aten.div_.Scalar,
- aten.mul_.Tensor,
- aten.div_.Tensor,
- aten.logical_and_.default,
- aten.logical_or_.default,
- aten.logical_xor_.default,
- ],
- )
- def meta_binop_inplace(self, other):
- return self
- @register_meta(
- [
- aten.add_.Scalar,
- aten.sub_.Scalar,
- aten.add_.Tensor,
- aten.sub_.Tensor,
- ],
- )
- def meta_binop_inplace_alpha(self, other, alpha=1):
- return self
- @register_meta([aten.round.default, aten.round.decimals])
- def meta_round(self, **kwargs):
- return _elementwise_meta(
- self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
- )
- @register_meta(aten.zero.default)
- def meta_zero(self):
- return self.new_empty(self.shape)
- @register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
- def meta_fill_(self, val):
- return self
- @register_meta([aten.fill.Tensor, aten.fill.Scalar])
- def meta_fill(self, val):
- return torch.empty_like(self)
- @register_meta(aten.relu_.default)
- def meta_relu_(self):
- return self
- @register_meta(aten.index_put.default)
- def meta_index_put(self, indices, values, accumulate=False):
- return torch.empty_like(self)
- @register_meta(aten.masked_fill_.Scalar)
- def meta_masked_fill_(self, mask, value):
- return self
- @register_meta(aten.index_put_.default)
- def meta_index_put_(self, indices, values, accumulate=False):
- return self
- @register_meta(aten.alias.default)
- def meta_alias(self):
- return self.view(self.shape)
- def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
- check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
- check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
- batch1_sizes = batch1.size()
- batch2_sizes = batch2.size()
- bs = batch1_sizes[0]
- contraction_size = batch1_sizes[2]
- res_rows = batch1_sizes[1]
- res_cols = batch2_sizes[2]
- output_size = (bs, res_rows, res_cols)
- check(
- batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
- lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
- f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
- )
- # TODO: handle out
- output = batch2.new_empty(output_size)
- if not is_bmm and self_baddbmm is not None:
- check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
- check(
- self_baddbmm.size() == output_size,
- lambda: "Expected an input tensor shape with shape {output_size} but got shape: {self.size()}",
- )
- return output
- @register_meta(aten.bmm.default)
- def meta_bmm(self, mat2):
- return common_meta_baddbmm_bmm(self, mat2, True)
- def div_rtn(x, y):
- q = x // y
- r = x % y
- # WARNING: explicit bool conversion here is necessary;
- # would be fixed by SymBool
- if r != 0 and (bool(r < 0) != bool(y < 0)):
- q -= 1
- return q
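- # div_rtn rounds the quotient toward negative infinity: div_rtn(7, 2) == 3
- # and div_rtn(-7, 2) == -4, where truncating division would give -3.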
- def pooling_output_shape_pad_lr(
- inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
- ):
- outputSize = (
- div_rtn(
- inputSize
- + pad_l
- + pad_r
- - dilation * (kernelSize - 1)
- - 1
- + (stride - 1 if ceil_mode else 0),
- stride,
- )
- + 1
- )
- if ceil_mode:
- if (outputSize - 1) * stride >= inputSize + pad_l:
- outputSize -= 1
- return outputSize
- def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
- check(stride != 0, lambda: "stride should not be zero")
- check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
- check(
- pad <= kernelSize // 2,
- lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
- )
- return pooling_output_shape_pad_lr(
- inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
- )
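- # Worked example: pooling_output_shape(7, 3, 1, 2, 1, False) computes
- # div_rtn(7 + 1 + 1 - 2 - 1, 2) + 1 = 3 + 1 = 4, matching a 3-wide,
- # stride-2, padding-1 pooling window over a length-7 input.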
- def pool2d_shape_check(
- input,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- dilationH,
- dilationW,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- memory_format,
- ):
- ndim = input.dim()
- nOutputPlane = nInputPlane
- check(
- kW > 0 and kH > 0,
- lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
- )
- check(
- dW > 0 and dH > 0,
- lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
- )
- check(
- dilationH > 0 and dilationW > 0,
- lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
- )
- valid_dims = input.size(1) != 0 and input.size(2) != 0
- if memory_format == torch.channels_last:
- check(
- ndim == 4 and valid_dims and input.size(3) != 0,
- lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
- " with optional 0 dim batch size for input, but got: {input.size()}",
- )
- else:
- check(
- (ndim == 3 and input.size(0) != 0 and valid_dims)
- or (ndim == 4 and valid_dims and input.size(3) != 0),
- lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
- )
- check(
- kW // 2 >= padW and kH // 2 >= padH,
- lambda: "pad should be smaller than or equal to half of kernel size, but got "
- f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
- )
- check(
- outputWidth >= 1 and outputHeight >= 1,
- lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
- f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
- "Output size is too small",
- )
- def max_pool2d_checks_and_compute_shape(
- input, kernel_size, stride, padding, dilation, ceil_mode
- ):
- # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
- def unpack(name, val):
- check(
- len(val) in [1, 2],
- lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
- )
- H = val[0]
- W = H if len(val) == 1 else val[1]
- return H, W
- kH, kW = unpack("kernel_size", kernel_size)
- check(
- len(stride) in [0, 1, 2],
- lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
- )
- if len(stride) == 0:
- dH, dW = kH, kW
- else:
- dH, dW = unpack("stride", stride)
- padH, padW = unpack("padding", padding)
- dilationH, dilationW = unpack("dilation", dilation)
- nInputPlane = input.size(-3)
- inputHeight = input.size(-2)
- inputWidth = input.size(-1)
- memory_format = utils.suggest_memory_format(input)
- if memory_format == torch.channels_last:
- check(
- input.dim() == 4,
- lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
- )
- elif memory_format == torch.contiguous_format:
- check(
- input.dim() in [3, 4],
- lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
- )
- else:
- check(
- False,
- lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
- )
- outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
- outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
- pool2d_shape_check(
- input,
- kH,
- kW,
- dH,
- dW,
- padH,
- padW,
- dilationH,
- dilationW,
- nInputPlane,
- inputHeight,
- inputWidth,
- outputHeight,
- outputWidth,
- memory_format,
- )
- return nInputPlane, outputHeight, outputWidth
- @register_meta(aten.max_pool2d_with_indices_backward.default)
- def meta_max_pool2d_with_indices_backward(
- grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
- ):
- nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
- self, kernel_size, stride, padding, dilation, ceil_mode
- )
- check(
- self.dtype == grad_output.dtype,
- lambda: "expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
- )
- nOutputPlane = nInputPlane
- ndim = self.ndim
- def _check_dim_size(t):
- check_dim_size(t, ndim, ndim - 3, nOutputPlane)
- check_dim_size(t, ndim, ndim - 2, outputHeight)
- check_dim_size(t, ndim, ndim - 1, outputWidth)
- _check_dim_size(grad_output)
- _check_dim_size(indices)
- memory_format = utils.suggest_memory_format(self)
- return torch.empty(
- self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
- )
- @register_meta(aten.max_pool2d_with_indices.default)
- def meta_max_pool2d_with_indices(
- input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
- ):
- nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
- input, kernel_size, stride, padding, dilation, ceil_mode
- )
- nbatch = input.size(-4) if input.dim() == 4 else 1
- memory_format = utils.suggest_memory_format(input)
- if input.dim() == 3:
- size = [nInputPlane, outputHeight, outputWidth]
- else:
- size = [nbatch, nInputPlane, outputHeight, outputWidth]
- return (
- torch.empty(
- size, dtype=input.dtype, device=input.device, memory_format=memory_format
- ),
- torch.empty(
- size, dtype=torch.int64, device=input.device, memory_format=memory_format
- ),
- )

@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    input_requires_grad = output_mask[0]
    if input_requires_grad:
        grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
    else:
        grad_input = None
    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
    return (grad_input, grad_grid)

@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    # fill_value only affects the values, not the metadata, so an empty
    # tensor of the requested size/dtype/device is sufficient here.
    return torch.empty(size, *args, **kwargs)


@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    return aten.empty_like.default(self, **kwargs)

# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    if layout == torch.sparse_coo:
        check(
            memory_format is None,
            lambda: "memory format option is only supported by strided tensors",
        )

        res = torch.empty(
            0,
            dtype=self.dtype if dtype is None else dtype,
            layout=layout,
            device=self.device if device is None else device,
            pin_memory=pin_memory,
        )

        if self.is_sparse:
            res.sparse_resize_and_clear_(
                self.size(), self.sparse_dim(), self.dense_dim()
            )
        else:
            res.sparse_resize_and_clear_(self.size(), self.dim(), 0)

        res._coalesced_(True)
        return res
    return aten.empty_like.default(
        self,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        memory_format=memory_format,
    )

@register_meta(aten.select.int)
def meta_select(self, dim, index):
    ndim = self.dim()
    check(
        ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
    )

    dim = dim if dim >= 0 else dim + ndim
    size = self.size(dim)

    check(
        not (-index > size or index >= size),
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
        IndexError,
    )

    index = index if index >= 0 else index + size

    new_size = list(self.size())
    new_stride = list(self.stride())
    new_storage_offset = self.storage_offset() + index * new_stride[dim]
    del new_size[dim]
    del new_stride[dim]

    return self.as_strided(new_size, new_stride, new_storage_offset)
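
# Illustrative example (hypothetical sizes): for a contiguous tensor of shape
# (4, 5, 6) (strides (30, 6, 1)), meta_select(x, 1, -2) wraps the index to 3,
# drops dim 1, and returns a view of shape (4, 6) with strides (30, 1) and
# storage_offset 3 * 6 = 18.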

@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    return utils.clone_preserve_strides(self)


@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    return utils.clone_preserve_strides(self)

# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    min = -dim_post_expr
    max = dim_post_expr - 1
    assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
    if dim < 0:
        dim += dim_post_expr
    return dim
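
# Illustrative behavior (derived from the code above):
#   maybe_wrap_dim(-1, 4) == 3
#   maybe_wrap_dim(3, 4) == 3
#   maybe_wrap_dim(-1, 0) == 0  # scalars are wrapped as if they were 1-D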

def ensure_nonempty_size(t, dim):
    return 1 if t.dim() == 0 else t.shape[dim]

# From aten/src/ATen/native/ScatterGatherChecks.h
def gather_shape_check(self, dim, index):
    self_dims = max(self.dim(), 1)
    index_dims = max(index.dim(), 1)
    check(
        self_dims == index_dims,
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for i in range(self_dims):
        if i != dim:
            check(
                ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
                lambda: f"Size does not match at dimension {i} expected index {index.shape}"
                + f" to be smaller than self {self.shape} apart from dimension {dim}",
            )


@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    is_index_empty = index.numel() == 0
    if not is_index_empty:
        check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
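
# Illustrative example (hypothetical sizes): the output always takes the shape
# of `index`, e.g.
#
#   x = torch.empty(4, 5, device="meta")
#   idx = torch.zeros(4, 2, dtype=torch.int64, device="meta")
#   meta_gather(x, 1, idx).shape  # torch.Size([4, 2])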

# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def get_operator_enum(reduce_, use_new_options=False):
    if use_new_options:
        if reduce_ == "sum":
            return "REDUCE_ADD"
        elif reduce_ == "prod":
            return "REDUCE_MULTIPLY"
        elif reduce_ == "mean":
            return "REDUCE_MEAN"
        elif reduce_ == "amax":
            return "REDUCE_MAXIMUM"
        elif reduce_ == "amin":
            return "REDUCE_MINIMUM"
        check(
            False,
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
        )
        return
    else:
        if reduce_ == "add":
            return "REDUCE_ADD"
        elif reduce_ == "multiply":
            return "REDUCE_MULTIPLY"
        check(False, lambda: "reduce argument must be either add or multiply.")
        return

# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
    if index.numel() != 0:
        check(
            index.dtype == torch.long,
            lambda: f"{method_name}(): Expected dtype int64 for index",
        )

    if src_opt is not None:
        check(
            self.dtype == src_opt.dtype,
            lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
        )


def ensure_nonempty_dim(dim):
    return max(dim, 1)


# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    if index.numel() == 0:
        return
    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    is_wrong_shape = False
    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    for d in range(self_dims):
        index_d_size = ensure_nonempty_size(index, d)
        if d == dim:
            continue
        if index_d_size > ensure_nonempty_size(self, d):
            is_wrong_shape = True
            break

    # Check: index.size(d) <= src.size(d) for all d if src is a Tensor
    if not is_wrong_shape and src_opt is not None:
        for d in range(self_dims):
            index_d_size = ensure_nonempty_size(index, d)
            if index_d_size > ensure_nonempty_size(src_opt, d):
                is_wrong_shape = True
                break

    if src_opt is not None:
        check(
            ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as self tensor",
        )
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
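
# Illustrative example (hypothetical sizes): with self of shape (4, 3) and
# dim=0, an index of shape (2, 3) passes (the size at dim is ignored, and
# 3 <= 3 elsewhere), while an index of shape (2, 5) fails because 5 > 3 at d=1.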

# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped_dim, index, src)
    if reduce_ is not None:
        # Check if we have a valid reduce operator.
        get_operator_enum(reduce_, use_new_options)


@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    scatter_meta_impl(self, dim, index, src, "add")
    return self.new_empty(self.shape)


@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    scatter_meta_impl(self, dim, index, src, "add")
    return self


@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)


@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self
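
# Illustrative example (hypothetical sizes): every scatter overload above
# validates through scatter_meta_impl; the out-of-place variants allocate a
# fresh tensor shaped like `self`, the in-place ones return `self`, e.g.
#
#   x = torch.empty(4, 3, device="meta")
#   idx = torch.zeros(2, 3, dtype=torch.int64, device="meta")
#   src = torch.empty(2, 3, device="meta")
#   meta_scatter(x, 0, idx, src).shape  # torch.Size([4, 3])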

@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
):
    # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset:
    # We have added logic to torch/_dynamo/variables/torch.py
    # We need to check if scaled_dot_product_attention will run the flash attention
    # kernel and if dropout is != 0.0. If that is the case then we want dynamo
    # to graph break. The derivative calculation for _scaled_dot_product_flash_attention
    # does not function correctly with cuda graphs because the full philox state is not
    # captured in the forward's return values. Another reason to graph break is that the
    # meta function returns the wrong outputs for philox seed and offset, and these
    # values get baked into the inductor fallback calls to the eager kernels.
    check(
        dropout_p == 0.0,
        lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.",
    )
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    Nnz_q = batch_size * max_seqlen_batch_q

    output = torch.empty(
        (Nnz_q, num_heads, head_dim), dtype=query.dtype, device=query.device
    )
    output = output.view(batch_size, max_seqlen_batch_q, num_heads, head_dim).transpose(
        1, 2
    )
    max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_q),
        dtype=torch.float,
        device=query.device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    return (
        output,
        logsumexp,
        cumulative_sequence_length_q,
        cumulative_sequence_length_k,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        1,  # Philox Seed will not be used, see note at top.
        1,  # Philox Offset will not be used, see note at top.
        debug_mask,
    )
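
# Shape summary for the tuple above (derived from the code, with B = batch
# size, H = num_heads, Sq = query sequence length, D = head_dim):
#   output:    (B, H, Sq, D)
#   logsumexp: (B, H, ceil(Sq / 16) * 16), float32
#   cumulative sequence lengths: (B + 1,), int32 on the meta device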

@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
    head_dim = query.size(3)

    Nnz_q = batch_size * max_q
    Nnz_kv = batch_size * max_k

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    query_reshaped = query.reshape(Nnz_q, num_heads, head_dim)
    key_reshaped = key.reshape(Nnz_kv, num_heads, head_dim)
    value_reshaped = value.reshape(Nnz_kv, num_heads, head_dim)

    grad_q = torch.empty_like(query_reshaped)
    grad_k = torch.empty_like(key_reshaped)
    grad_v = torch.empty_like(value_reshaped)

    grad_q = grad_q.view(batch_size, max_q, num_heads, head_dim).transpose(1, 2)
    grad_k = grad_k.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)
    grad_v = grad_v.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)

    return grad_q, grad_k, grad_v

@register_meta(
    [
        aten._scaled_dot_product_efficient_attention,
    ]
)
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    compute_log_sumexp: bool,
    is_causal: bool = False,
):
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    res = res.transpose(1, 2)

    return res, logsum_exp
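
# Shape summary (derived from the code): with query (B, H, M, K) and value
# (B, H, N, Kv) on entry, `res` comes back as (B, H, M, Kv) and `logsum_exp`
# as (B, H, ceil(M / 32) * 32) when compute_log_sumexp, else (B, H, 0).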

@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
    ]
)
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    is_causal: bool = False,
    chunk_grad_outputs=False,
):
    grad_out = grad_out.transpose(1, 2)
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    nH = query.size(2)
    K = query.size(3)

    grad_kv_needs_init = is_causal and N > M

    if chunk_grad_outputs:
        chunk = torch.empty((B, M, 3, nH, K), dtype=query.dtype, device=query.device)
        grad_q = chunk.select(2, 0)
        grad_k = chunk.select(2, 1)
        grad_v = chunk.select(2, 2)
    else:
        grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
        grad_k = (
            torch.zeros(key.shape, dtype=key.dtype, device=key.device)
            if grad_kv_needs_init
            else torch.empty(key.shape, dtype=key.dtype, device=key.device)
        )
        grad_v = (
            torch.zeros(value.shape, dtype=value.dtype, device=value.device)
            if grad_kv_needs_init
            else torch.empty(value.shape, dtype=value.dtype, device=value.device)
        )
    return grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)

@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self.new_empty(self.shape)


@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self


def multiply_integers(vs):
    r = 1
    for v in vs:
        r *= v
    return r
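
# Illustrative behavior: multiply_integers([2, 3, 4]) == 24 and
# multiply_integers([]) == 1 (the empty product), which is what the
# non-emptiness checks in the upsample metas below rely on.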

def upsample_common_check(input_size, output_size, num_spatial_dims):
    check(
        len(output_size) == num_spatial_dims,
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
    )
    expected_input_dims = num_spatial_dims + 2  # N, C, ...
    check(
        len(input_size) == expected_input_dims,
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
    )

    check(
        all([s > 0 for s in input_size[2:]]) and all([s > 0 for s in output_size]),
        lambda: f"Input and output sizes should be greater than 0, but got "
        f"input size {input_size} and output size {output_size}",
    )

    nbatch, channels = input_size[:2]
    return (nbatch, channels, *output_size)
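
# Illustrative example (hypothetical sizes):
#   upsample_common_check((2, 3, 8, 8), (16, 16), num_spatial_dims=2)
#   # -> (2, 3, 16, 16)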

@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output
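
# Illustrative consequence of the heuristic above (hypothetical sizes): a
# channels_last CUDA input of shape (1, 3, 8, 8) has n_channels < 4, so the
# output is produced in contiguous_format rather than channels_last.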

@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )

@register_meta([aten.sort.default, aten.sort.stable])
def meta_sort(self, stable=None, dim=-1, descending=False):
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)


def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    check(
        all(
            x.device == input_gates.device
            # input_bias / hidden_bias may legitimately be None; skip them.
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
            if x is not None
        ),
        lambda: "expected all inputs to be same device",
    )

@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
):
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    return (hy, cy, workspace)


@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    reserve_shape = 0 if train else 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
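
# Illustrative example (hypothetical sizes, unpacked input, batch_first=True):
# input (5, 7, 10), hidden_size=16, proj_size=0, num_layers=2,
# bidirectional=True gives output (5, 7, 32), hy and cy (4, 5, 16), and an
# empty uint8 reserve buffer.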

@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    seq_length = input.shape[1] if batch_first else input.shape[0]
    mini_batch = input.shape[0] if batch_first else input.shape[1]
    output_channels = hidden_size
    out_shape = (
        [mini_batch, seq_length, output_channels]
        if batch_first
        else [seq_length, mini_batch, output_channels]
    )
    output = input.new_empty(out_shape)
    if hx_ is None:
        hy = torch.empty(0, device=input.device)
    else:
        hy = hx_.new_empty(hx_.shape)
    if cx_ is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx_.new_empty(cx_.shape)
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace

def zero_numel_check_dims(self, dim, fn_name):
    if self.ndim == 0:
        check(
            dim == 0 or dim == -1,
            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
            IndexError,
        )
    else:
        check(
            self.size(dim) != 0,
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
            IndexError,
        )


# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    if dim is not None:
        dim = maybe_wrap_dim(dim, self.dim())
        zero_numel_check_dims(self, dim, name)
    else:
        check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )


@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    # The name is only used in error messages; argmax and argmin share the
    # same shape checks.
    check_argmax_argmin("argmax", self, dim)
    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
    shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(shape, dtype=torch.int64)
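
# Illustrative example (hypothetical sizes): for self of shape (2, 3), dim=1,
# keepdim=False gives an int64 tensor of shape (2,); with dim=None all dims
# are reduced and the result is a 0-dim int64 tensor.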

@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )

@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    check(
        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
        lambda: "selected index k out of range",
    )
    sliceSize = 1 if self.dim() == 0 else self.size(dim)
    check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")

    topKSize = list(self.shape)
    if len(topKSize) > 0:
        topKSize[dim] = k
    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
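
# Illustrative example (hypothetical sizes):
#
#   x = torch.empty(4, 5, device="meta")
#   values, indices = topk_meta(x, 2, dim=1)
#   values.shape, indices.dtype  # torch.Size([4, 2]), torch.int64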

legacy_contiguous_memory_format = torch.contiguous_format


# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(defined_grad.dim() == 2, lambda: "")
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(grad_hy.size() == exp_size, lambda: "")
    if grad_cy is not None:
        check(grad_cy.size() == exp_size, lambda: "")
    check(cx.size() == exp_size, lambda: "")
    check(cy.size() == exp_size, lambda: "")
    check(workspace.dim() == 2, lambda: "")
    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")


# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    grad_gates = torch.empty_like(
        workspace, memory_format=legacy_contiguous_memory_format
    )
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
    return grad_gates, grad_cx, grad_bias

@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out
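
# Illustrative example (hypothetical sizes): pixel_shuffle on (1, 9, 4, 4)
# with upscale_factor=3 folds the 9 channels into one 3x upscaled plane,
# returning shape (1, 1, 12, 12).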

@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    diff_x = input.new_empty(input.shape)
    diff_hx = hx_.new_empty(hx_.shape)
    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
    diff_w1 = weight0.new_empty(weight0.shape)
    diff_w2 = weight1.new_empty(weight1.shape)
    diff_b = weight2.new_empty(weight2.shape)
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx


@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    return torch.empty_like(
        self, dtype=torch.int32 if out_int32 else torch.int64
    ).contiguous()
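
# Illustrative example (hypothetical sizes): bucketize preserves the shape of
# `self` and only picks the index dtype, e.g. a (3, 4) input with a 1-D
# `boundaries` yields a contiguous (3, 4) int64 tensor (int32 if out_int32).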

# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special


def activate_meta():
    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for type in ["meta", "post_autograd", "pre_autograd"]:
        registry = global_decomposition_table[type]

        for opo in registry:
            if opo not in activate_meta_table:
                activate_meta_table[opo] = registry[opo]

    for op_overload, fn in activate_meta_table.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_overload.name(), "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
        elif op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            pass
        elif op_overload.name() in {
            "aten::empty_strided",  # causing infinite recursion, test_meta.py
            "aten::clone",  # causing infinite recursion
            "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
            "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
            "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
            "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
        }:
            pass
        else:
            if "mkldnn::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
            elif "mkl::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
            else:
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


activate_meta()