_meta_registrations.py

import math
from typing import List, Optional, Union

import torch
import torch._prims_common as utils
from torch import Tensor
from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table
from torch._ops import OpOverload
from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
    check,
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    make_contiguous_strides_for,
)
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
from torch._subclasses.fake_tensor import check_no_bool_index_tensors
from torch.utils._pytree import tree_map

aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")


def register_meta(op):
    def wrapper(fn):
        def register(op):
            _add_op_to_registry(meta_table, op, fn)

        tree_map(register, op)
        return fn

    return wrapper


def toRealValueType(dtype):
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    return from_complex.get(dtype, dtype)


@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    assert self.dtype.is_complex
    return self.new_empty(self.size())


@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )


@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    assert out.ndim == 1 and out.size(0) == n
    return out


@register_meta(aten.randint.default)
def meta_randint(
    high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta(aten.randint.low)
def meta_randint_low(
    low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta(aten.rand.default)
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    return self


def inferUnsqueezeGeometry(tensor, dim):
    result_sizes = list(tensor.size())
    result_strides = list(tensor.stride())
    new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
    result_sizes.insert(dim, 1)
    result_strides.insert(dim, new_stride)

    return result_sizes, result_strides
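# Worked example for the helper above: unsqueezing a contiguous (2, 3) tensor
# (strides (3, 1)) at dim=1 yields sizes [2, 1, 3] and strides [3, 3, 1]; the
# inserted stride equals result_sizes[dim] * result_strides[dim] of the
# original tensor, so the new size-1 axis aliases the dim it is inserted before.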
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    dim = maybe_wrap_dim(dim, self.dim() + 1)
    g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
    self.as_strided_(g_sizes, g_strides)
    return self


# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)


@register_meta(aten.index_select.out)
def meta_index_select_out(self, dim, index, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.index_select(self, dim, index))


@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta([aten.min.default])
def meta_min(self):
    return self.new_empty(())


@register_meta(aten.angle.default)
def meta_angle(self):
    if self.is_complex():
        result_dtype = corresponding_real_dtype(self.dtype)
    else:
        _, result_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))


# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert self.size(-1) == self.size(
        -2
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
    t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
):
    dtype = t.dtype
    check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    if not allow_low_precision_dtypes:
        check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    check(
        A.dim() >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def checkUplo(uplo: str):
    uplo_uppercase = uplo.upper()
    assert len(uplo) == 1 and (
        uplo_uppercase == "U" or uplo_uppercase == "L"
    ), f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
# @register_meta(aten.linalg_eigh.default)
def meta_linalg_eigh(self, uplo="L"):
    squareCheckInputs(self, "linalg_eigh")
    checkUplo(uplo)
    real_dtype = toRealValueType(self.dtype)
    assert self.dim() >= 2
    values = self.new_empty(self.shape, dtype=real_dtype)
    values.transpose_(-2, -1)
    vectors = self.new_empty(self.shape[:-1])
    return (values, vectors)


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    A_shape = A.shape
    ndim = len(A_shape)

    # L
    L_strides = make_contiguous_strides_for(A_shape, False)
    L = A.new_empty(A_shape)
    L.as_strided_(A_shape, L_strides)

    # infos
    infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
    return L, infos


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    L = A.new_empty(A.shape)
    L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
    infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
    return L, infos


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: str = None
):
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_uv:
        U_shape = batch_dims + [m, m if full_matrices else k]
        U = A.new_empty(U_shape)
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))

        V_shape = batch_dims + [n if full_matrices else k, n]
        V = A.new_empty(V_shape)
        # TODO: need to distinguish cuSOLVER case? (see original code)
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=False))
    else:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])

    # S is always real, even when A is complex.
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
    return U, S, V


# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    det = A.new_empty(A.shape[:-2])

    LU = A.new_empty(A.shape)
    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
    return det, LU, pivots


# From aten/src/ATen/native/ReflectionPad.cpp
@register_meta(
    [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
)
def meta_pad2d_backward(grad_output, self, padding):
    dim_w = 2
    dim_h = 1
    dim_plane = 0
    nbatch = 1

    self_shape = self.shape
    if self.dim() == 4:
        nbatch = self_shape[0]
        dim_w += 1
        dim_h += 1
        dim_plane += 1

    pad_l = padding[0]
    pad_r = padding[1]
    pad_t = padding[2]
    pad_b = padding[3]

    nplane = self_shape[dim_plane]
    input_h = self_shape[dim_h]
    input_w = self_shape[dim_w]
    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    check(
        output_w == grad_output.shape[dim_w],
        lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
    )
    check(
        output_h == grad_output.shape[dim_h],
        lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
    )
    return self.new_empty(self.shape)


@register_meta(aten.reflection_pad2d.default)
def meta_pad2d(self, padding):
    valid_dims = self.size(1) != 0 and self.size(2) != 0
    check(
        (self.ndim == 3 and valid_dims)
        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )
    if self.ndim == 4:
        nbatch, nplane, input_h, input_w = self.shape
    else:
        nbatch = 1
        nplane, input_h, input_w = self.shape

    pad_l, pad_r, pad_t, pad_b = padding

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    if self.ndim == 3:
        return self.new_empty((nplane, output_h, output_w))
    else:
        return self.new_empty((nbatch, nplane, output_h, output_w))


@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    # https://github.com/pytorch/pytorch/issues/88612
    return torch.empty_like(self).contiguous()


@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    return self


@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    # https://github.com/pytorch/pytorch/issues/88612
    return torch.empty_like(self).contiguous()


@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    check(
        ch_axis < self.dim(),
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    mask = torch.empty_like(self, dtype=torch.bool)
    return (torch.empty_like(self), mask)


def dot_check(self, other):
    check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )


@register_meta(aten.dot.default)
def meta_dot(self, tensor):
    dot_check(self, tensor)
    return self.new_empty(())


@register_meta([aten.mm.default])
def meta_mm(a, b):
    check(a.dim() == 2, lambda: "a must be 2D")
    check(b.dim() == 2, lambda: "b must be 2D")
    N, M1 = a.shape
    M2, P = b.shape
    check(M1 == M2, lambda: "a and b must have same reduction dim")
    return a.new_empty(N, P)


def _compute_reduction_shape(self, dims, keepdim):
    if keepdim:
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))

    return utils.compute_reduction_output_shape(self.shape, dims)
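# e.g. for an input of shape (2, 3, 4) reduced over dims=(1,):
#   keepdim=True  -> (2, 1, 4)
#   keepdim=False -> (2, 4)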
# FakeTensors (meta tensors with a device) will report device as meta
# when running meta kernels. Here, access the "fake device" of FakeTensor if it
# exists so meta kernels which diverge per device will be more
# accurate when run with FakeTensors
def device_hint(tensor) -> "str":
    if isinstance(tensor, torch._subclasses.FakeTensor):
        return tensor.fake_device.type
    else:
        return "cuda"  # default to cuda


def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim
        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    if is_transposed:
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    ret_shape = [input_tensor.shape[0], out_channels]
    if isinstance(stride, IntLike):
        stride = [stride] * len(dims)
    elif len(stride) == 1:
        stride = [stride[0]] * len(dims)

    if isinstance(padding, IntLike):
        padding = [padding] * len(dims)
    elif len(padding) == 1:
        padding = [padding[0]] * len(dims)

    if isinstance(dilation, IntLike):
        dilation = [dilation] * len(dims)
    elif len(dilation) == 1:
        dilation = [dilation[0]] * len(dims)

    output_padding_list: Optional[List[int]] = None
    if output_padding:
        if isinstance(output_padding, IntLike):
            output_padding_list = [output_padding] * len(dims)
        elif len(output_padding) == 1:
            output_padding_list = [output_padding[0]] * len(dims)
        else:
            output_padding_list = output_padding

    for i in range(len(dims)):
        # If output_padding is present, we are dealing with a transposed convolution
        if output_padding_list:
            ret_shape.append(
                _formula_transposed(
                    dims[i],
                    padding[i],
                    dilation[i],
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
                )
            )
        else:
            ret_shape.append(
                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            )

    return ret_shape
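# Illustrative shape computation for the helper above: a non-transposed conv on
# an input of shape (1, 3, 8, 8) with a (16, 3, 3, 3) weight, stride=2,
# padding=1, dilation=1 gives (8 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 4 per
# spatial dim, so the returned shape is [1, 16, 4, 4].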
def is_channels_last(ten):
    return torch._prims_common.suggest_memory_format(ten) == torch.channels_last


@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    def pick_memory_format():
        if device_hint(input_tensor) == "cuda":
            if is_channels_last(input_tensor) or is_channels_last(weight):
                return torch.channels_last
        else:
            if is_channels_last(input_tensor):
                return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        output_padding if is_transposed else None,
    )

    out = input_tensor.new_empty(shape_out)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out


if torch._C.has_mkldnn:
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    def pick_mkldnn_conv_memory_format(input_tensor, weight):
        if weight.is_mkldnn:
            return torch.channels_last
        if is_channels_last(input_tensor) or is_channels_last(weight):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        out = input_tensor.new_empty(shape_out)
        out_memory_format = torch.channels_last
        out = out.to(memory_format=out_memory_format)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise.binary)
    def meta_mkldnn_convolution_binary(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        out = input_tensor.new_empty(other.size())
        out = out.to(memory_format=torch.channels_last)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._convolution_pointwise_.binary)
    def meta_mkldnn_convolution_binary_inplace(
        input_tensor,
        other,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        binary_attr,
        alpha,
        unary_attr,
        unary_scalars,
        unary_algorithm,
    ):
        return other

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))

    @register_meta(torch.ops.mkldnn._linear_pointwise.binary)
    def meta_linear_pointwise_binary(input_tensor, other, weight, bias, attr):
        out = input_tensor.new_empty(other.size())
        return out

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(
            input_tensor,
            packed_weight,
            orig_weight,
            bias,
            batch_size,
        ):
            return input_tensor.new_empty(
                (*input_tensor.shape[:-1], orig_weight.shape[0])
            )


# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
def check_dim_size(tensor, dim, dim_size, size):
    check(
        tensor.dim() == dim and tensor.shape[dim_size] == size,
        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
        + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
    )
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)
    check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    elif len(stride) == 1:
        dH, dW = stride[0], stride[0]
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return torch.empty(
        size, dtype=input.dtype, device=input.device, memory_format=memory_format
    )


# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    nOutputPlane = nInputPlane

    check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
    check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
    check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)


# Don't override the C++ registration.
@register_meta(aten.avg_pool2d_backward.default)
def meta_avg_pool2d_backward(
    gradOutput_,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
    check(
        len(kernel_size) == 1 or len(kernel_size) == 2,
        lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
    )
    kH = kernel_size[0]
    kW = kH if len(kernel_size) == 1 else kernel_size[1]
    check(
        len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    dH = kH if len(stride) == 0 else stride[0]
    dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
    check(
        len(padding) == 1 or len(padding) == 2,
        lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
    )
    padH = padding[0]
    padW = padH if len(padding) == 1 else padding[1]

    check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    input_size = input.shape
    nbatch = input_size[-4] if input.dim() == 4 else 1
    nInputPlane = input_size[-3]
    inputHeight = input_size[-2]
    inputWidth = input_size[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    mem_format = utils.suggest_memory_format(input)

    avg_pool2d_backward_shape_check(
        input,
        gradOutput_,
        nbatch,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    return torch.empty(
        input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
    )
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    check(
        self.ndim == 3 or self.ndim == 4,
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    output_shape = self.shape[:-2] + tuple(output_size)
    memory_format = utils.suggest_memory_format(self)
    # need to set memory_format to preserve the memory format of the input
    # channel last input should have channel last output
    return torch.empty(
        output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )


@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    check(
        self.ndim == 4 or self.ndim == 5,
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    return self.new_empty(self.shape[:-3] + tuple(output_size))


@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    ndim = grad_out.ndim
    for i in range(1, ndim):
        check(
            grad_out.size(i) > 0,
            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
        )
    check(
        ndim == 3 or ndim == 4,
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    return self.new_empty(self.shape)


@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    if output_size is None:
        raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
    return repeats.new_empty(output_size)


@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    assert real.dtype.is_floating_point
    assert imag.dtype.is_floating_point
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
@register_meta(aten.vdot.default)
def vdot(self, other):
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    dot_check(self, other)
    return self.new_empty(())
# Leaving this function around because a python implementation
# of indexing shape inference is useful,
# but not registering it to the dispatcher because we already
# get shape inference through structured kernels
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
    check_no_bool_index_tensors(aten.index.Tensor, self, indices)
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            check(
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)


@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    # High level logic taken from slow_conv3d_backward_cpu which should
    # be representative of all convolution_backward impls
    backend_grad_input = None
    backend_grad_weight = None
    backend_grad_bias = None

    if output_mask[0]:
        backend_grad_input = grad_output_.new_empty(input_.size())
    if output_mask[1]:
        backend_grad_weight = grad_output_.new_empty(weight_.size())
    if output_mask[2]:
        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)

    return (backend_grad_input, backend_grad_weight, backend_grad_bias)


@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode in (None, 1, 2),
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if device_hint(offsets) != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
        numBags = offsets.shape[0]
        if mode == MODE_MAX:
            if include_last_offset:
                check(
                    numBags >= 1,
                    lambda: "include_last_offset: numBags should be at least 1",
                )
                numBags -= 1
            max_indices = offsets.new_empty(numBags, weight.shape[1])
        else:
            max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices


@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
        weight, indices, offsets, *args
    )
    if device_hint(offsets) == "cpu":
        bag_size = offsets.new_empty(offsets.size())
    return output, offset2bag, bag_size, max_indices


def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
    # if specified, dtype takes precedence
    if dtype:
        return dtype

    if input.dtype.is_floating_point or input.dtype.is_complex:
        return input.dtype
    elif promote_int_to_long:
        return torch.long

    return input.dtype


@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    dims = utils.reduction_dims(input.shape, dims)
    output_shape = _compute_reduction_shape(input, dims, keepdim)
    return input.new_empty(output_shape, dtype=output_dtype)


@register_meta(aten.nanmedian.default)
def meta_nanmedian(input):
    output_shape = utils.compute_reduction_output_shape(
        input.shape, tuple(range(input.dim()))
    )
    return input.new_empty(output_shape)


@register_meta([aten.nanmedian.dim, aten.nanmedian.dim_values])
@out_wrapper("values", "indices")
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    dim = utils.reduction_dims(input.shape, (dim,))
    output_shape = _compute_reduction_shape(input, dim, keepdim)
    return (
        input.new_empty(output_shape),
        input.new_empty(output_shape, dtype=torch.long),
    )
@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    return self


@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    check(
        len(repeats) >= self.dim(),
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # Add new leading dimensions to the tensor if the
    # number of target dimensions is larger than the
    # number of source dimensions.
    num_new_dimensions = len(repeats) - self.dim()
    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
    return self.new_empty(target_size)


@register_meta(aten.zero_.default)
def meta_zero_(self):
    return self


@register_meta(
    [
        aten.mul_.Scalar,
        aten.div_.Scalar,
        aten.mul_.Tensor,
        aten.div_.Tensor,
        aten.logical_and_.default,
        aten.logical_or_.default,
        aten.logical_xor_.default,
    ],
)
def meta_binop_inplace(self, other):
    return self


@register_meta(
    [
        aten.add_.Scalar,
        aten.sub_.Scalar,
        aten.add_.Tensor,
        aten.sub_.Tensor,
    ],
)
def meta_binop_inplace_alpha(self, other, alpha=1):
    return self


@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    return _elementwise_meta(
        self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )


@register_meta(aten.zero.default)
def meta_zero(self):
    return self.new_empty(self.shape)


@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
def meta_fill_(self, val):
    return self


@register_meta([aten.fill.Tensor, aten.fill.Scalar])
def meta_fill(self, val):
    return torch.empty_like(self)


@register_meta(aten.relu_.default)
def meta_relu_(self):
    return self


@register_meta(aten.index_put.default)
def meta_index_put(self, indices, values, accumulate=False):
    return torch.empty_like(self)


@register_meta(aten.masked_fill_.Scalar)
def meta_masked_fill_(self, mask, value):
    return self


@register_meta(aten.index_put_.default)
def meta_index_put_(self, indices, values, accumulate=False):
    return self


@register_meta(aten.alias.default)
def meta_alias(self):
    return self.view(self.shape)
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    batch1_sizes = batch1.size()
    batch2_sizes = batch2.size()

    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    res_rows = batch1_sizes[1]
    res_cols = batch2_sizes[2]
    output_size = (bs, res_rows, res_cols)

    check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
    )

    # TODO: handle out

    output = batch2.new_empty(output_size)

    if not is_bmm and self_baddbmm is not None:
        check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
        check(
            self_baddbmm.size() == output_size,
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
        )

    return output
  1230. @register_meta(aten.bmm.default)
  1231. def meta_bmm(self, mat2):
  1232. return common_meta_baddbmm_bmm(self, mat2, True)


def div_rtn(x, y):
    q = x // y
    r = x % y
    # WARNING: explicit bool conversion here is necessary;
    # would be fixed by SymBool
    if r != 0 and (bool(r < 0) != bool(y < 0)):
        q -= 1
    return q
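
# Illustrative note: for concrete (non-symbolic) integers div_rtn behaves like
# floor division, e.g. div_rtn(7, 2) == 3 and div_rtn(-7, 2) == -4.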


def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            - dilation * (kernelSize - 1)
            - 1
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize -= 1
    return outputSize


def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    check(stride != 0, lambda: "stride should not be zero")
    check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    check(
        pad <= kernelSize // 2,
        lambda: f"pad should be at most half of kernel size, but got pad={pad} and kernel_size={kernelSize}",
    )
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
    )
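
# Worked example (illustrative): pooling_output_shape(32, 3, 1, 2, 1, False)
# evaluates to div_rtn(32 + 2*1 - 1*(3 - 1) - 1, 2) + 1 == 16.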


def pool2d_shape_check(
    input,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    dilationH,
    dilationW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    memory_format,
):
    ndim = input.dim()
    nOutputPlane = nInputPlane

    check(
        kW > 0 and kH > 0,
        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
    )
    check(
        dW > 0 and dH > 0,
        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
    )
    check(
        dilationH > 0 and dilationW > 0,
        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
    )

    valid_dims = input.size(1) != 0 and input.size(2) != 0

    if memory_format == torch.channels_last:
        check(
            ndim == 4 and valid_dims and input.size(3) != 0,
            lambda: "Expected 4D (batch mode) tensor for input with channels_last layout"
            f" with optional 0 dim batch size for input, but got: {input.size()}",
        )
    else:
        check(
            (ndim == 3 and input.size(0) != 0 and valid_dims)
            or (ndim == 4 and valid_dims and input.size(3) != 0),
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
        )

    check(
        kW // 2 >= padW and kH // 2 >= padH,
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
    )
    check(
        outputWidth >= 1 and outputHeight >= 1,
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
        "Output size is too small",
    )


def max_pool2d_checks_and_compute_shape(
    input, kernel_size, stride, padding, dilation, ceil_mode
):
    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
    def unpack(name, val):
        check(
            len(val) in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        H = val[0]
        W = H if len(val) == 1 else val[1]
        return H, W

    kH, kW = unpack("kernel_size", kernel_size)

    check(
        len(stride) in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    else:
        dH, dW = unpack("stride", stride)

    padH, padW = unpack("padding", padding)
    dilationH, dilationW = unpack("dilation", dilation)
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    memory_format = utils.suggest_memory_format(input)
    if memory_format == torch.channels_last:
        check(
            input.dim() == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        check(
            input.dim() in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        check(
            False,
            lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous",
        )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        dilationH,
        dilationW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    return nInputPlane, outputHeight, outputWidth


@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
    grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    check(
        self.dtype == grad_output.dtype,
        lambda: f"expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
    )

    nOutputPlane = nInputPlane
    ndim = self.ndim

    def _check_dim_size(t):
        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
        check_dim_size(t, ndim, ndim - 2, outputHeight)
        check_dim_size(t, ndim, ndim - 1, outputWidth)

    _check_dim_size(grad_output)
    _check_dim_size(indices)

    memory_format = utils.suggest_memory_format(self)
    return torch.empty(
        self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
    )


@register_meta(aten.max_pool2d_with_indices.default)
def meta_max_pool2d_with_indices(
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )
    nbatch = input.size(-4) if input.dim() == 4 else 1
    memory_format = utils.suggest_memory_format(input)
    if input.dim() == 3:
        size = [nInputPlane, outputHeight, outputWidth]
    else:
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
    return (
        torch.empty(
            size, dtype=input.dtype, device=input.device, memory_format=memory_format
        ),
        torch.empty(
            size, dtype=torch.int64, device=input.device, memory_format=memory_format
        ),
    )
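
# Shape sanity check (illustrative): a (2, 3, 32, 32) input with
# kernel_size=(2, 2) and default stride/padding/dilation yields a value tensor
# and an int64 index tensor, both of shape (2, 3, 16, 16).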


@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    input_requires_grad = output_mask[0]
    if input_requires_grad:
        grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
    else:
        grad_input = None
    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
    return (grad_input, grad_grid)


@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    return torch.empty(size, *args, **kwargs)


@register_meta(
    [
        aten.randint_like.default,
        aten.randint_like.low_dtype,
        aten.randn_like.default,
        aten.rand_like.default,
        aten.full_like.default,
        aten.ones_like.default,
    ]
)
def meta_like(self, *args, **kwargs):
    return aten.empty_like.default(self, **kwargs)


# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
):
    if layout == torch.sparse_coo:
        check(
            memory_format is None,
            lambda: "memory format option is only supported by strided tensors",
        )
        res = torch.empty(
            0,
            dtype=self.dtype if dtype is None else dtype,
            layout=layout,
            device=self.device if device is None else device,
            pin_memory=pin_memory,
        )

        if self.is_sparse:
            res.sparse_resize_and_clear_(
                self.size(), self.sparse_dim(), self.dense_dim()
            )
        else:
            res.sparse_resize_and_clear_(self.size(), self.dim(), 0)

        res._coalesced_(True)
        return res
    return aten.empty_like.default(
        self,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        memory_format=memory_format,
    )


@register_meta(aten.select.int)
def meta_select(self, dim, index):
    ndim = self.dim()
    check(
        ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
    )

    dim = dim if dim >= 0 else dim + ndim
    size = self.size(dim)

    check(
        not (-index > size or index >= size),
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
        IndexError,
    )

    index = index if index >= 0 else index + size

    new_size = list(self.size())
    new_stride = list(self.stride())

    new_storage_offset = self.storage_offset() + index * new_stride[dim]
    del new_size[dim]
    del new_stride[dim]

    return self.as_strided(new_size, new_stride, new_storage_offset)
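
# Example (illustrative): selecting index 1 along dim 0 of a contiguous (3, 4)
# tensor yields a view with size (4,), stride (1,), and storage_offset 4.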


@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    return utils.clone_preserve_strides(self)


@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    return utils.clone_preserve_strides(self)


# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    min = -dim_post_expr
    max = dim_post_expr - 1
    assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
    if dim < 0:
        dim += dim_post_expr
    return dim
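
# Example (illustrative): maybe_wrap_dim(-1, 4) == 3 and maybe_wrap_dim(2, 4) == 2;
# an out-of-range dim such as maybe_wrap_dim(4, 4) trips the assertion.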


def ensure_nonempty_size(t, dim):
    return 1 if t.dim() == 0 else t.shape[dim]


# From aten/src/ATen/native/ScatterGatherChecks.h
def gather_shape_check(self, dim, index):
    self_dims = max(self.dim(), 1)
    index_dims = max(index.dim(), 1)
    check(
        self_dims == index_dims,
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for i in range(self_dims):
        if i != dim:
            check(
                ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
                lambda: f"Size does not match at dimension {i} expected index {index.shape}"
                + f" to be smaller than self {self.shape} apart from dimension {dim}",
            )


@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    is_index_empty = index.numel() == 0
    if not is_index_empty:
        check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
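
# Shape note (illustrative): gather's meta output always matches the index
# shape, e.g. a (4, 5) self gathered along dim=1 with a (4, 2) int64 index
# yields a (4, 2) tensor.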


# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def get_operator_enum(reduce_, use_new_options=False):
    if use_new_options:
        if reduce_ == "sum":
            return "REDUCE_ADD"
        elif reduce_ == "prod":
            return "REDUCE_MULTIPLY"
        elif reduce_ == "mean":
            return "REDUCE_MEAN"
        elif reduce_ == "amax":
            return "REDUCE_MAXIMUM"
        elif reduce_ == "amin":
            return "REDUCE_MINIMUM"
        check(
            False,
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
        )
        return
    else:
        if reduce_ == "add":
            return "REDUCE_ADD"
        elif reduce_ == "multiply":
            return "REDUCE_MULTIPLY"
        check(False, lambda: "reduce argument must be either add or multiply.")
        return


# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
    if index.numel() != 0:
        check(
            index.dtype == torch.long,
            lambda: f"{method_name}(): Expected dtype int64 for index",
        )

    if src_opt is not None:
        check(
            self.dtype == src_opt.dtype,
            lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
        )


def ensure_nonempty_dim(dim):
    return max(dim, 1)


# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    if index.numel() == 0:
        return
    check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    is_wrong_shape = False
    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    for d in range(self_dims):
        index_d_size = ensure_nonempty_size(index, d)
        if d == dim:
            continue
        if index_d_size > ensure_nonempty_size(self, d):
            is_wrong_shape = True
            break

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if not is_wrong_shape and src_opt is not None:
        for d in range(self_dims):
            index_d_size = ensure_nonempty_size(index, d)
            if index_d_size > ensure_nonempty_size(src_opt, d):
                is_wrong_shape = True
                break

    if src_opt is not None:
        check(
            ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as self tensor",
        )
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )


# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped_dim, index, src)
    if reduce_ is not None:
        # Check if we have a valid reduce operator.
        get_operator_enum(reduce_, use_new_options)


@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    scatter_meta_impl(self, dim, index, src, "add")
    return self.new_empty(self.shape)


@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    scatter_meta_impl(self, dim, index, src, "add")
    return self


@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)


@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self


@register_meta(
    [
        aten._scaled_dot_product_flash_attention,
    ]
)
def meta__scaled_dot_product_flash(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
):
    # [Note] SDPA_flash's meta function returns incorrect Philox seed and offset:
    # We have added logic to torch/_dynamo/variables/torch.py
    # We need to check if scaled_dot_product_attention will run the flash attention
    # kernel and if dropout is != 0.0. If that is the case then we want dynamo
    # to graph break. The derivative calculation for _scaled_dot_product_flash_attention
    # does not function correctly with cuda graphs because the full philox state is not captured
    # in the forward's return values. Another reason to graph break is that the meta function
    # returns the wrong outputs for philox seed and offset and these values get baked into the
    # inductor fallback calls to the eager kernels.
    check(
        dropout_p == 0.0,
        lambda: f"Can only trace _scaled_dot_product_flash_attention when dropout is set to 0 but got a dropout_p of {dropout_p}.",
    )
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)

    max_seqlen_batch_k = key.size(2)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    Nnz_q = batch_size * max_seqlen_batch_q

    output = torch.empty(
        (Nnz_q, num_heads, head_dim), dtype=query.dtype, device=query.device
    )
    output = output.view(batch_size, max_seqlen_batch_q, num_heads, head_dim).transpose(
        1, 2
    )
    max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_q),
        dtype=torch.float,
        device=query.device,
    )
    cumulative_sequence_length_q = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )
    cumulative_sequence_length_k = torch.empty(
        batch_size + 1, dtype=torch.int32, device="meta"
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    return (
        output,
        logsumexp,
        cumulative_sequence_length_q,
        cumulative_sequence_length_k,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        1,  # Philox Seed will not be used, see note at top.
        1,  # Philox Offset will not be used, see note at top.
        debug_mask,
    )
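
# Shape note (illustrative): for query/key/value of shape
# (batch=2, num_heads=8, seq_len=128, head_dim=64) the returned output is
# (2, 8, 128, 64) and logsumexp is (2, 8, 128), with the logsumexp sequence
# length rounded up to a multiple of 16.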


@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: int,
    philox_offset: int,
):
    batch_size = query.size(0)
    num_heads = query.size(1)
    head_dim = query.size(3)

    Nnz_q = batch_size * max_q
    Nnz_kv = batch_size * max_k

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    query_reshaped = query.reshape(Nnz_q, num_heads, head_dim)
    key_reshaped = key.reshape(Nnz_kv, num_heads, head_dim)
    value_reshaped = value.reshape(Nnz_kv, num_heads, head_dim)

    grad_q = torch.empty_like(query_reshaped)
    grad_k = torch.empty_like(key_reshaped)
    grad_v = torch.empty_like(value_reshaped)

    grad_q = grad_q.view(batch_size, max_q, num_heads, head_dim).transpose(1, 2)
    grad_k = grad_k.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)
    grad_v = grad_v.view(batch_size, max_k, num_heads, head_dim).transpose(1, 2)

    return grad_q, grad_k, grad_v


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention,
    ]
)
def meta__scaled_dot_product_efficient(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    compute_log_sumexp: bool,
    is_causal: bool = False,
):
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (B, num_heads, logsumexp_dim),
        dtype=torch.float,
        device=query.device,
    )

    res = res.transpose(1, 2)
    return res, logsum_exp


@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
    ]
)
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    is_causal: bool = False,
    chunk_grad_outputs=False,
):
    grad_out = grad_out.transpose(1, 2)
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    nH = query.size(2)
    K = query.size(3)

    grad_kv_needs_init = is_causal and N > M

    if chunk_grad_outputs:
        chunk = torch.empty((B, M, 3, nH, K), dtype=query.dtype, device=query.device)
        grad_q = chunk.select(2, 0)
        grad_k = chunk.select(2, 1)
        grad_v = chunk.select(2, 2)
    else:
        grad_q = torch.empty(query.shape, dtype=query.dtype, device=query.device)
        grad_k = (
            torch.zeros(key.shape, dtype=key.dtype, device=key.device)
            if grad_kv_needs_init
            else torch.empty(key.shape, dtype=key.dtype, device=key.device)
        )
        grad_v = (
            torch.zeros(value.shape, dtype=value.dtype, device=value.device)
            if grad_kv_needs_init
            else torch.empty(value.shape, dtype=value.dtype, device=value.device)
        )
    return grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)


@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self.new_empty(self.shape)


@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    return self


def multiply_integers(vs):
    r = 1
    for v in vs:
        r *= v
    return r


def upsample_common_check(input_size, output_size, num_spatial_dims):
    check(
        len(output_size) == num_spatial_dims,
        lambda: f"Expected output_size to have {num_spatial_dims} elements, but got {len(output_size)}",
    )
    expected_input_dims = num_spatial_dims + 2  # N, C, ...
    check(
        len(input_size) == expected_input_dims,
        lambda: f"Expected input_size to have {expected_input_dims} elements, but got {len(input_size)}",
    )

    check(
        all([s > 0 for s in input_size[2:]]) and all([s > 0 for s in output_size]),
        lambda: f"Input and output sizes should be greater than 0, but got "
        f"input size {input_size} and output size {output_size}",
    )

    nbatch, channels = input_size[:2]
    return (nbatch, channels, *output_size)
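
# Example (illustrative): upsample_common_check((2, 3, 8), (16,), num_spatial_dims=1)
# returns the full output size (2, 3, 16).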


@register_meta(aten.upsample_nearest1d.default)
def upsample_nearest1d(input, output_size, scales=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


@register_meta(aten.upsample_nearest2d.default)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output


@register_meta(aten.upsample_nearest3d.default)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )


@register_meta([aten.sort.default, aten.sort.stable])
def meta_sort(self, stable=None, dim=-1, descending=False):
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)


def rnn_cell_checkSizes(
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
):
    check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
    check(
        input_gates.shape == hidden_gates.shape,
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
    )
    gates_size = input_gates.size(1)
    if input_bias is not None:
        check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
        check(
            input_bias.numel() == gates_size,
            lambda: f"{input_bias.numel()} != {gates_size}",
        )
        check(
            input_bias.shape == hidden_bias.shape,
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
        )
    check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
    check(
        prev_hidden.numel() == expected_prev_hidden_numel,
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
    )
    check(
        all(
            x.device == input_gates.device
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
        ),
        lambda: "expected all inputs to be same device",
    )


@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
):
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
    return (hy, cy, workspace)


@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    reserve_shape = 0 if train else 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf


@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    seq_length = input.shape[1] if batch_first else input.shape[0]
    mini_batch = input.shape[0] if batch_first else input.shape[1]
    output_channels = hidden_size
    out_shape = (
        [mini_batch, seq_length, output_channels]
        if batch_first
        else [seq_length, mini_batch, output_channels]
    )
    output = input.new_empty(out_shape)
    if hx_ is None:
        hy = torch.empty(0, device=input.device)
    else:
        hy = hx_.new_empty(hx_.shape)
    if cx_ is None:
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx_.new_empty(cx_.shape)
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace


def zero_numel_check_dims(self, dim, fn_name):
    if self.ndim == 0:
        check(
            dim == 0 or dim == -1,
            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
            IndexError,
        )
    else:
        check(
            self.size(dim) != 0,
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
            IndexError,
        )


# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    if dim is not None:
        dim = maybe_wrap_dim(dim, self.dim())
        zero_numel_check_dims(self, dim, name)
    else:
        check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )


@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    check_argmax_argmin("argmax", self, dim)
    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
    shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(shape, dtype=torch.int64)


@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    check(
        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
        lambda: "selected index k out of range",
    )
    sliceSize = 1 if self.dim() == 0 else self.size(dim)
    check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")

    topKSize = list(self.shape)
    if len(topKSize) > 0:
        topKSize[dim] = k
    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
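
# Shape note (illustrative): topk on a (5, 10) tensor with k=3 along the last
# dim returns a values tensor and an int64 indices tensor, both of shape (5, 3).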


legacy_contiguous_memory_format = torch.contiguous_format


# From aten/src/ATen/native/cuda/RNN.cu
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
    defined_grad = grad_hy if grad_hy is not None else grad_cy
    check(defined_grad.dim() == 2, lambda: "")
    exp_size = defined_grad.size()
    if grad_hy is not None:
        check(grad_hy.size() == exp_size, lambda: "")
    if grad_cy is not None:
        check(grad_cy.size() == exp_size, lambda: "")
    check(cx.size() == exp_size, lambda: "")
    check(cy.size() == exp_size, lambda: "")
    check(workspace.dim() == 2, lambda: "")
    check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")


# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    grad_gates = torch.empty_like(
        workspace, memory_format=legacy_contiguous_memory_format
    )
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
    return grad_gates, grad_cx, grad_bias


@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out
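
# Shape note (illustrative): pixel_shuffle on a (1, 9, 4, 4) tensor with
# upscale_factor=3 produces a (1, 1, 12, 12) tensor.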


@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    diff_x = input.new_empty(input.shape)
    diff_hx = hx_.new_empty(hx_.shape)
    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
    diff_w1 = weight0.new_empty(weight0.shape)
    diff_w2 = weight1.new_empty(weight1.shape)
    diff_b = weight2.new_empty(weight2.shape)
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx


@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    return torch.empty_like(
        self, dtype=torch.int32 if out_int32 else torch.int64
    ).contiguous()


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
import torch._refs.nn.functional
import torch._refs.special


def activate_meta():
    activate_meta_table = {}

    # For a given op, we pick the most specific decomp function from
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
    for type in ["meta", "post_autograd", "pre_autograd"]:
        registry = global_decomposition_table[type]

        for opo in registry:
            if opo not in activate_meta_table:
                activate_meta_table[opo] = registry[opo]

    for op_overload, fn in activate_meta_table.items():
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_overload.name(), "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
            pass
        elif op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            pass
        elif op_overload.name() in {
            "aten::empty_strided",  # causing infinite recursion, test_meta.py
            "aten::clone",  # causing infinite recursion
            "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
            "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
            "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
            "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
        }:
            pass
        else:
            if "mkldnn::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
            elif "mkl::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
            else:
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


activate_meta()