__init__.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882
  1. import contextlib
  2. import itertools
  3. import math
  4. import operator
  5. import weakref
  6. from enum import Enum
  7. from functools import partial, reduce
  8. from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
  9. import torch
  10. import torch._prims_common as utils
  11. import torch.library
  12. from torch import sym_float, Tensor, TypedStorage
  13. from torch._C import _get_default_device
  14. from torch._prims.nvfuser_prims import register_nvprims
  15. from torch._prims_common import (
  16. check,
  17. Dim,
  18. DimsSequenceType,
  19. DimsType,
  20. IntLike,
  21. Number,
  22. NumberType,
  23. RETURN_TYPE,
  24. ShapeType,
  25. StrideType,
  26. TensorLike,
  27. TensorLikeType,
  28. type_to_dtype,
  29. )
  30. from torch._prims_common.wrappers import backwards_not_supported
  31. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  32. from torch.overrides import handle_torch_function, has_torch_function
  33. from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
# torch.library handles used below to declare the "prims" namespace schema
# ("DEF") and to register per-dispatch-key implementations ("IMPL") for the
# CompositeExplicitAutograd, BackendSelect, Autograd, and Meta keys.
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
# Experimental module containing prototype "primitive" operations.

# Public API of this module; entries past this chunk's visible definitions are
# defined later in the file.
__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "erfcx",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "ndtri",
    "neg",
    "real",
    "reciprocal",
    "round",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "rsqrt",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    #
    # Functionalized view mutations
    #
    "as_strided_scatter",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "clone",
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "to_dtype",
    "copy_strided",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "scalar_tensor",
    "iota",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "_uniform_helper",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
]
  209. def TensorMeta(
  210. tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
  211. *,
  212. shape: Optional[ShapeType] = None,
  213. strides: Optional[StrideType] = None,
  214. dtype: Optional[torch.dtype] = None,
  215. device: Optional[Union[torch.device, str]] = None,
  216. ):
  217. if isinstance(tensorlike, Number):
  218. assert not shape and (shape is None or isinstance(shape, Sequence))
  219. assert not strides and (strides is None or isinstance(strides, Sequence))
  220. inferred_shape: Tuple[int, ...] = ()
  221. inferred_strides: Tuple[int, ...] = ()
  222. inferred_dtype = type_to_dtype(type(tensorlike))
  223. inferred_device = torch.device("cpu")
  224. # TODO: This looks wrong, a number that is wrapped into a tensor
  225. # needs to behave differently than a scalar tensor for type
  226. # promotion purposes
  227. elif tensorlike is not None:
  228. assert isinstance(tensorlike, torch.Tensor)
  229. inferred_shape = tuple(tensorlike.shape)
  230. inferred_strides = tuple(tensorlike.stride())
  231. inferred_dtype = tensorlike.dtype
  232. inferred_device = tensorlike.device
  233. else:
  234. # If no tensorlike "example" is given then all metadata
  235. # must be provided explicitly
  236. assert shape is not None
  237. assert strides is not None
  238. assert dtype is not None
  239. assert device is not None
  240. shape = inferred_shape if shape is None else tuple(shape)
  241. strides = inferred_strides if strides is None else tuple(strides)
  242. dtype = inferred_dtype if dtype is None else dtype
  243. device = inferred_device if device is None else device
  244. if isinstance(device, str):
  245. device = torch.device(device)
  246. return torch.empty_strided(shape, strides, dtype=dtype, device=device)
def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
):
    """
    Creates a primitive operation.

    Defines the op's schema in the "prims" torch.library namespace, then
    registers implementations for the CompositeExplicitAutograd, Autograd,
    Meta, and (for tensor-free-argument ops) BackendSelect dispatch keys.
    Returns the default overload of the newly created OpOverloadPacket, with
    ``return_type``/``schema``/impl attributes attached to both the packet
    and the overload.
    """
    prim.define(schema)

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.) Because we
    # don't have derivative formulas, we must setup a custom autograd function
    # that raises an error if backwards is invoked
    def _autograd_impl(*args, **kwargs):
        return backwards_not_supported(_prim)(*args, **kwargs)

    def _backend_select_impl(*args, **kwargs):
        # Factory-style calls with device="meta" are routed straight to the
        # meta function instead of the aten implementation.
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    # The op name is everything before the argument list in the schema.
    name = schema.split("(")[0]
    prim_impl.impl(name, _prim_impl)
    prim_autograd_impl.impl(name, _autograd_impl)
    prim_meta_impl.impl(name, meta)

    _prim_packet = getattr(torch._ops.ops.prims, name)
    _prim = _prim_packet.default

    from torch._subclasses.fake_tensor import contains_tensor_types

    # Only ops with no tensor arguments (e.g. factory functions) get a
    # BackendSelect kernel; others dispatch normally on their tensor inputs.
    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments):
        prim_backend_select_impl.impl(name, _backend_select_impl)

    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta
        p.impl_aten = impl_aten

    return _prim
class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
    """How an elementwise prim's output dtype relates to its input dtype."""

    # NOTE: the 1-tuple values are a quirk of the original definition; members
    # are only ever compared by identity/equality, never by value.
    DEFAULT = (0,)
    ALWAYS_BOOL = (2,)
    COMPLEX_TO_FLOAT = (3,)
# TODO: implement dtype validation here, too, or on the corresponding refs
def _elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
    args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
    """
    Meta function for elementwise operations that produce outputs in the same dtype
    as their inputs.

    Stride logic is currently incorrect.

    ``args`` are the user-visible operands (tensors and/or numbers);
    ``args_with_fixed_dtypes`` are extra tensors that participate in
    device/shape/stride checks but not in dtype determination.
    """
    assert len(args) > 0

    utils.check_same_dtype(*args)

    args_ = list(args)
    if args_with_fixed_dtypes is not None:
        args_ = list(args_with_fixed_dtypes) + args_
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    strides = utils.compute_elementwise_output_strides(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype
    dtype = None
    scalar_type = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if not utils.is_cpu_scalar_tensor(arg):
                # A non-scalar tensor's dtype wins outright.
                dtype = arg.dtype
                break
            else:
                # A CPU scalar tensor's dtype is only a fallback; keep looking.
                dtype = arg.dtype
        elif isinstance(arg, Number):
            scalar_type = type(arg)

    if dtype is None and scalar_type is not None:
        dtype = utils.type_to_dtype(scalar_type)

    # Acquires the device (if it exists) or number
    device = None
    number = None
    for arg in args_:
        if isinstance(arg, TensorLike):
            if utils.is_cpu_scalar_tensor(arg):
                if device is None:
                    device = arg.device
                # keep going, in case there is a cuda tensor later
            else:
                device = arg.device
                break

        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # NOTE: type promotion behavior here is mostly hidden from tests because
    # references will typically handle the type promotion properly even if this doesn't
    # (but getting it wrong will cause too many casts to be inserted in traces!)
    if device is not None:
        assert dtype is not None
        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
            dtype = dtype
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
            dtype = torch.bool
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
            if utils.is_complex_dtype(dtype):
                dtype = utils.corresponding_real_dtype(dtype)
            else:
                dtype = dtype

        return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)

    # Number case
    # TODO: fix number type promotion (bool, complex->float)

    # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
    seen_float = False
    if isinstance(number, (torch.SymInt, torch.SymFloat)):
        for a in args:
            assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
            seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
        if seen_float:
            # Any float operand promotes a symbolic int result to sym_float.
            number = sym_float(number)

    return TensorMeta(number)  # type: ignore[arg-type]
  372. def _complex_only_elementwise_meta(*args, **kwargs):
  373. utils.check(
  374. utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
  375. )
  376. return _elementwise_meta(*args, **kwargs)
  377. def _make_elementwise_unary_prim(
  378. name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
  379. ):
  380. """
  381. Creates an elementwise unary prim.
  382. """
  383. return _make_prim(
  384. schema=f"{name}(Tensor self) -> Tensor",
  385. meta=partial(_elementwise_meta, type_promotion=type_promotion),
  386. return_type=RETURN_TYPE.NEW,
  387. **kwargs,
  388. )
  389. def _make_elementwise_binary_prim(
  390. name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
  391. ):
  392. """
  393. Creates an elementwise binary prim.
  394. """
  395. return _make_prim(
  396. schema=f"{name}(Tensor self, Tensor other) -> Tensor",
  397. meta=partial(_elementwise_meta, type_promotion=type_promotion),
  398. return_type=RETURN_TYPE.NEW,
  399. **kwargs,
  400. )
  401. def _not_impl(*args, **kwargs):
  402. raise NotImplementedError
#
# Elementwise unary operations
#

abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    # abs maps complex inputs to their (real) magnitude
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  502. def _cbrt_aten(a: torch.Tensor) -> Tensor:
  503. utils.check(
  504. not a.is_complex(),
  505. lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
  506. )
  507. # Returns the real cubic root of the number.
  508. # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number
  509. # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
  510. # which is a complex number.
  511. # For more info see the section Note in
  512. # https://en.cppreference.com/w/cpp/numeric/math/cbrt
  513. return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
cbrt = _make_elementwise_unary_prim(
    "cbrt",
    # uses the custom aten fallback above so negative reals stay real
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  526. def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
  527. if not input.dtype.is_complex:
  528. raise RuntimeError("prims.conj_physical is only defined for complex dtypes")
  529. strides = utils.compute_elementwise_output_strides(input)
  530. return TensorMeta(input, strides=strides)
# Uses _make_prim directly because its meta rejects non-complex inputs.
conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)
  538. def _clone_meta(
  539. input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
  540. ) -> TensorLikeType:
  541. if memory_format != torch.preserve_format:
  542. return torch.empty(
  543. input.shape,
  544. dtype=input.dtype,
  545. layout=input.layout,
  546. device=input.device,
  547. requires_grad=input.requires_grad,
  548. memory_format=memory_format,
  549. )
  550. # memory_format == torch.preserve_format
  551. strides = utils.compute_elementwise_output_strides(input)
  552. return torch.empty_strided(
  553. input.shape,
  554. strides,
  555. dtype=input.dtype,
  556. layout=input.layout,
  557. device=input.device,
  558. requires_grad=input.requires_grad,
  559. )
# Uses _make_prim directly because of the memory_format keyword argument.
clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
    return_type=RETURN_TYPE.NEW,
)

digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfcx = _make_elementwise_unary_prim(
    "erfcx",
    impl_aten=torch.special.erfcx,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  615. def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
  616. return _elementwise_meta(
  617. a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
  618. )
# NOTE: fill uses _make_prim directly because it has a value parameter
fill = _make_prim(
    schema="fill(Tensor self, Scalar value) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_fill_meta,
    impl_aten=torch.fill,
    doc="",
)
floor = _make_elementwise_unary_prim(
    "floor",
    impl_aten=torch.floor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
# imag (and real, below) are views into the components of a complex tensor;
# COMPLEX_TO_FLOAT maps the complex input dtype to its real counterpart
imag = _make_prim(
    schema="imag(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.imag,
    doc="",
)
# isfinite is a predicate, so its result is always boolean
isfinite = _make_elementwise_unary_prim(
    "isfinite",
    impl_aten=torch.isfinite,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
lgamma = _make_elementwise_unary_prim(
    "lgamma",
    impl_aten=torch.lgamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
log = _make_elementwise_unary_prim(
    "log",
    impl_aten=torch.log,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
log1p = _make_elementwise_unary_prim(
    "log1p",
    impl_aten=torch.log1p,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
log2 = _make_elementwise_unary_prim(
    "log2",
    impl_aten=torch.log2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
log10 = _make_elementwise_unary_prim(
    "log10",
    impl_aten=torch.log10,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
real = _make_prim(
    schema="real(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.real,
    doc="",
)
reciprocal = _make_elementwise_unary_prim(
    "reciprocal",
    impl_aten=torch.reciprocal,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
ndtri = _make_elementwise_unary_prim(
    "ndtri",
    impl_aten=torch.special.ndtri,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
neg = _make_elementwise_unary_prim(
    "neg",
    impl_aten=torch.neg,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
round = _make_elementwise_unary_prim(
    "round",
    impl_aten=torch.round,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
rsqrt = _make_elementwise_unary_prim(
    "rsqrt",
    impl_aten=torch.rsqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
sign = _make_elementwise_unary_prim(
    "sign",
    impl_aten=torch.sign,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  725. signbit = _make_elementwise_unary_prim(
  726. "signbit",
  727. impl_aten=torch.signbit,
  728. doc="",
  729. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  730. )
sin = _make_elementwise_unary_prim(
    "sin",
    impl_aten=torch.sin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
sinh = _make_elementwise_unary_prim(
    "sinh",
    impl_aten=torch.sinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
spherical_bessel_j0 = _make_elementwise_unary_prim(
    "spherical_bessel_j0",
    impl_aten=torch.special.spherical_bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
sqrt = _make_elementwise_unary_prim(
    "sqrt",
    impl_aten=torch.sqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
tan = _make_elementwise_unary_prim(
    "tan",
    impl_aten=torch.tan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
tanh = _make_elementwise_unary_prim(
    "tanh",
    impl_aten=torch.tanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
trunc = _make_elementwise_unary_prim(
    "trunc",
    impl_aten=torch.trunc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

#
# Elementwise binary operations
#

add = _make_elementwise_binary_prim(
    name="add",
    impl_aten=torch.add,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
atan2 = _make_elementwise_binary_prim(
    name="atan2",
    impl_aten=torch.atan2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
bitwise_and = _make_elementwise_binary_prim(
    "bitwise_and",
    impl_aten=torch.bitwise_and,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
bitwise_or = _make_elementwise_binary_prim(
    "bitwise_or",
    impl_aten=torch.bitwise_or,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
bitwise_xor = _make_elementwise_binary_prim(
    "bitwise_xor",
    impl_aten=torch.bitwise_xor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
# TODO: complex needs a special meta to account for its float -> complex behavior
# complex = _make_elementwise_binary_prim(
#   impl_aten=torch.complex,
#   doc="",
# )
  811. # div prim performs truncation division on integer inputs
  812. # and true division for floating and complex inputs
  813. def _div_aten(a, b):
  814. is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
  815. isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
  816. )
  817. if is_integral:
  818. return torch.div(a, b, rounding_mode="trunc")
  819. else:
  820. return torch.true_divide(a, b)
div = _make_elementwise_binary_prim(
    "div",
    impl_aten=_div_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
# Comparison prims (eq/ge/gt/le/lt/ne) are predicates and always return bool
eq = _make_elementwise_binary_prim(
    "eq",
    impl_aten=torch.eq,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
fmax = _make_elementwise_binary_prim(
    "fmax",
    impl_aten=torch.fmax,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
fmin = _make_elementwise_binary_prim(
    "fmin",
    impl_aten=torch.fmin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
fmod = _make_elementwise_binary_prim(
    "fmod",
    impl_aten=torch.fmod,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
gcd = _make_elementwise_binary_prim(
    "gcd",
    impl_aten=torch.gcd,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
ge = _make_elementwise_binary_prim(
    "ge",
    impl_aten=torch.ge,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
gt = _make_elementwise_binary_prim(
    "gt",
    impl_aten=torch.gt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
hypot = _make_elementwise_binary_prim(
    "hypot",
    impl_aten=torch.hypot,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
igamma = _make_elementwise_binary_prim(
    "igamma",
    impl_aten=torch.special.gammainc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
igammac = _make_elementwise_binary_prim(
    "igammac",
    impl_aten=torch.special.gammaincc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
le = _make_elementwise_binary_prim(
    "le",
    impl_aten=torch.le,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
lt = _make_elementwise_binary_prim(
    "lt",
    impl_aten=torch.lt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
# Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs
  900. def _maximum_aten(
  901. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  902. ) -> TensorLikeType:
  903. if isinstance(a, TensorLike) and isinstance(b, Number):
  904. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  905. elif isinstance(b, TensorLike) and isinstance(a, Number):
  906. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  907. return torch.maximum(a, b) # type: ignore[arg-type]
# maximum delegates to the scalar-tolerant wrapper above
maximum = _make_elementwise_binary_prim(
    "maximum",
    impl_aten=_maximum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
  914. def _minimum_aten(
  915. a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
  916. ) -> TensorLikeType:
  917. if isinstance(a, TensorLike) and isinstance(b, Number):
  918. b = scalar_tensor(b, dtype=a.dtype, device=a.device)
  919. elif isinstance(b, TensorLike) and isinstance(a, Number):
  920. a = scalar_tensor(a, dtype=b.dtype, device=b.device)
  921. return torch.minimum(a, b) # type: ignore[arg-type]
# minimum delegates to the scalar-tolerant wrapper above
minimum = _make_elementwise_binary_prim(
    "minimum",
    impl_aten=_minimum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
mul = _make_elementwise_binary_prim(
    "mul",
    impl_aten=torch.mul,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
ne = _make_elementwise_binary_prim(
    "ne",
    impl_aten=torch.ne,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
nextafter = _make_elementwise_binary_prim(
    "nextafter",
    impl_aten=torch.nextafter,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
pow = _make_elementwise_binary_prim(
    "pow",
    impl_aten=torch.pow,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
remainder = _make_elementwise_binary_prim(
    "remainder",
    impl_aten=torch.remainder,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
shift_left = _make_elementwise_binary_prim(
    "shift_left",
    impl_aten=torch.bitwise_left_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
shift_right_arithmetic = _make_elementwise_binary_prim(
    "shift_right_arithmetic",
    impl_aten=torch.bitwise_right_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
# Logical (zero-filling) right shift is not implemented
shift_right_logical = _not_impl
sub = _make_elementwise_binary_prim(
    "sub",
    impl_aten=torch.sub,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
zeta = _make_elementwise_binary_prim(
    "zeta",
    impl_aten=torch.special.zeta,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

#
# View operations
  985. def _as_strided_meta(
  986. a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
  987. ) -> TensorLikeType:
  988. assert len(size) == len(stride)
  989. assert storage_offset >= 0
  990. utils.validate_strides(stride)
  991. utils.validate_shape(size)
  992. if reduce(operator.mul, size) == 0:
  993. # NOTE: This special case is to avoid having to acquire the storage below
  994. # as_strided to shapes with no elements are trivially valid, so it's OK
  995. pass
  996. elif isinstance(a, torch.Tensor):
  997. utils.check_in_bounds_for_storage(
  998. a._typed_storage(), size, stride, storage_offset
  999. )
  1000. return torch.as_strided(a, size, stride, storage_offset)
  1001. def _as_strided_aten(
  1002. a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
  1003. ) -> Tensor:
  1004. return torch.as_strided(a, size, stride, storage_offset)
  1005. _as_strided_doc = """
  1006. Creates a view of the tensor with the given shape (size), strides (stride) and
  1007. storage offset (storage_offset).
  1008. """
  1009. as_strided = _make_prim(
  1010. schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
  1011. meta=_as_strided_meta,
  1012. impl_aten=_as_strided_aten,
  1013. return_type=RETURN_TYPE.VIEW,
  1014. doc=_as_strided_doc,
  1015. )
def _broadcast_in_dim_meta(
    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
    """Meta function for broadcast_in_dim.

    Validates that broadcast_dimensions maps every dimension of a, in
    ascending order, to a dimension of the (weakly larger) target shape, then
    returns a view of a with the target shape. Broadcast and newly-added
    dimensions receive a stride of zero so the view repeats data.
    """
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, Dim)
        assert x > acc
        assert x < len(shape)

        return x

    reduce(lambda acc, x: _greater_than_reduce(acc, x), broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if a.shape[original_idx] != shape[idx]:
                new_strides.append(0)
            else:
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            # Newly-added dimension (not mapped from a): length != 1 repeats
            # data (stride 0); a trailing length-1 dim gets stride 1; otherwise
            # the stride is derived from the next original dimension.
            if shape[idx] != 1:
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())
  1058. def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
  1059. s = list(shape)
  1060. for broadcast_dimension in broadcast_dimensions:
  1061. s[broadcast_dimension] = -1
  1062. v = a
  1063. for idx, x in enumerate(s):
  1064. if x != -1:
  1065. v = v.unsqueeze(idx)
  1066. return v.expand(shape)
  1067. _broadcast_in_dim_doc = """
  1068. Creates a view of a with the specified shape.
  1069. Allows adding dimensions of any length and broadcasting
  1070. dimensions of length one in a to any length.
  1071. The location of the broadcast dimensions must be specified
  1072. using the broadcast_dimensions argument. Changing the
  1073. relative order of dimensions is not supported.
  1074. """
  1075. broadcast_in_dim = _make_prim(
  1076. schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
  1077. meta=_broadcast_in_dim_meta,
  1078. impl_aten=_broadcast_in_dim_aten,
  1079. return_type=RETURN_TYPE.VIEW,
  1080. doc=_broadcast_in_dim_doc,
  1081. )
def _collapse_view_helper(
    a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    """Computes the shape and strides of a viewing collapse of dims
    [start, end) of a, or returns (None, None) when no such view exists
    (i.e. when the span's dimensions are not contiguously "nested").

    Raises ValueError if end <= start (the span must be non-empty).
    """
    assert isinstance(a, TensorLike)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1,)
        strides = (1,)
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()  # type: ignore[assignment]

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    if end <= start:
        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
            end, start
        )
        raise ValueError(msg)

    # A single-dim span (or rank-0 input) collapses to itself
    if a.ndim == 0 or (end - 1 == start):
        return shape, strides

    # Walk the span from the innermost dim outwards, accumulating the merged
    # length and tracking the smallest stride; bail out (None, None) if any
    # adjacent pair is not nested (stride[i] != stride[i+1] * shape[i+1]).
    length = shape[end - 1]
    stride = strides[end - 1]
    for idx in reversed(range(start, end - 1)):
        if shape[idx] == 0 or shape[idx + 1] == 0:
            length = 0
            stride = 0
            break

        if shape[idx] == 1:
            continue

        length = length * shape[idx]
        stride = min(stride, strides[idx])

        if (
            a.numel() > 0
            and shape[idx + 1] != 1
            and not (strides[idx] == strides[idx + 1] * shape[idx + 1])
        ):
            return None, None

    new_shape = shape[:start] + (length,) + shape[end:]
    new_strides = strides[:start] + (stride,) + strides[end:]

    # NOTE: when the input has no elements it's restrided as if it were contiguous
    if a.numel() == 0:
        new_strides = utils.make_contiguous_strides_for(new_shape)

    return new_shape, new_strides
  1127. def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
  1128. new_shape, new_strides = _collapse_view_helper(a, start, end)
  1129. if new_shape is None:
  1130. msg = "Attempting to view a collapsed tensor, but no such view exists!"
  1131. raise ValueError(msg)
  1132. if new_strides is None:
  1133. return a.view(new_shape)
  1134. else:
  1135. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1136. def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
  1137. # Special-cases zero-dim tensors
  1138. if a.ndim == 0:
  1139. shape = (1,)
  1140. else:
  1141. shape = a.shape # type: ignore[assignment]
  1142. dim_length = 1
  1143. for idx in range(start, end):
  1144. dim_length = dim_length * shape[idx]
  1145. new_shape = shape[0:start] + (dim_length,) + shape[end:]
  1146. return a.view(new_shape)
  1147. _collapse_view_doc = """
  1148. Creates a view of a with the dimensions between
  1149. start (inclusive) and end (exclusive) merged into a
  1150. single dimension.
  1151. If it's not possible to take such a view then an error
  1152. is thrown. See collapse instead.
  1153. The dimensions can be merged if and only if
  1154. they are all "nested" with each other. That is, they all
  1155. have the property that
  1156. stride[i] = stride[i+1] * shape[i+1]
  1157. for all i in [start, end - 1).
  1158. """
  1159. collapse_view = _make_prim(
  1160. schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
  1161. meta=_collapse_view_meta,
  1162. impl_aten=_collapse_view_aten,
  1163. return_type=RETURN_TYPE.VIEW,
  1164. doc=_collapse_view_doc,
  1165. )
  1166. def _conj_meta(a: TensorLikeType) -> TensorLikeType:
  1167. if not a.dtype.is_complex:
  1168. raise RuntimeError("Expected complex dtype in prims.conj")
  1169. return a.as_strided(a.shape, a.stride(), a.storage_offset())
  1170. _conj_doc = """
  1171. Returns a conjugated view of the original tensor
  1172. """
  1173. conj = _make_prim(
  1174. schema="conj(Tensor(a) a) -> Tensor(a)",
  1175. meta=_conj_meta,
  1176. impl_aten=torch.conj,
  1177. return_type=RETURN_TYPE.VIEW,
  1178. doc=_conj_doc,
  1179. )
  1180. def expand_dims(
  1181. a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
  1182. ) -> TensorLikeType:
  1183. """
  1184. Creates a view of a with a.ndim + len(dimensions) dimensions, with new
  1185. dimensions of length one at the dimensions specified by dimensions.
  1186. """
  1187. if ndim is not None:
  1188. # TODO: this is only here to support the unsqueeze ref
  1189. dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type]
  1190. else:
  1191. dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
  1192. if len(set(dims)) != len(dims):
  1193. msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
  1194. raise ValueError(msg)
  1195. new_shape = list(a.shape)
  1196. for idx in dims:
  1197. new_shape.insert(idx, 1)
  1198. broadcast_dimensions = [
  1199. idx for idx in range(len(new_shape)) if idx not in dimensions
  1200. ]
  1201. return broadcast_in_dim(a, new_shape, broadcast_dimensions)
# Note: saves the Python slice object because we're about to clobber its name with the slice prim
pyslice: Type[slice] = slice  # type: ignore[has-type]
  1204. def _slice_meta(
  1205. a: TensorLikeType,
  1206. start_indices: DimsSequenceType,
  1207. limit_indices: DimsSequenceType,
  1208. strides: Optional[StrideType] = None,
  1209. ) -> TensorLikeType:
  1210. _strides = strides if strides is not None else [1] * len(start_indices)
  1211. if a.ndim != len(start_indices):
  1212. msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
  1213. a.ndim, len(start_indices)
  1214. )
  1215. raise ValueError(msg)
  1216. if a.ndim != len(limit_indices):
  1217. msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
  1218. a.ndim, len(limit_indices)
  1219. )
  1220. raise ValueError(msg)
  1221. if a.ndim != len(_strides):
  1222. msg = (
  1223. "Attempting to slice tensor of rank {0} with strides of length {1}!".format(
  1224. a.ndim, len(limit_indices)
  1225. )
  1226. )
  1227. raise ValueError(msg)
  1228. for x, y in zip(start_indices, a.shape):
  1229. if x < 0:
  1230. msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
  1231. x
  1232. )
  1233. raise ValueError(msg)
  1234. if x > y:
  1235. msg = (
  1236. "Attempting to slice a tensor but a start index in {0} is greater than"
  1237. " the length of its corresponding dimension in shape {1}".format(
  1238. start_indices, a.shape
  1239. )
  1240. )
  1241. raise ValueError(msg)
  1242. for x, y, z in zip(limit_indices, a.shape, start_indices):
  1243. if x < 0:
  1244. msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
  1245. x
  1246. )
  1247. raise ValueError(msg)
  1248. if x > y:
  1249. msg = (
  1250. "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
  1251. " its corresponding dimension in shape {1}".format(
  1252. limit_indices, a.shape
  1253. )
  1254. )
  1255. raise ValueError(msg)
  1256. if x < z:
  1257. msg = (
  1258. "Attempting to slice a tensor but a start index in {0} is greater than "
  1259. " its corresponding stop index {1}".format(x, z)
  1260. )
  1261. for x in _strides:
  1262. if x <= 0:
  1263. msg = (
  1264. "Attempting to slice a tensor with a non-positive step of {0}!".format(
  1265. x
  1266. )
  1267. )
  1268. raise ValueError(msg)
  1269. new_shape = []
  1270. for x, y, z in zip(start_indices, limit_indices, _strides):
  1271. new_shape.append(math.floor((y - x) / z))
  1272. new_strides = []
  1273. for x, y in zip(a.stride(), _strides):
  1274. new_strides.append(x * y)
  1275. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1276. def _slice_aten(
  1277. a: Tensor,
  1278. start_indices: DimsSequenceType,
  1279. limit_indices: DimsSequenceType,
  1280. strides: Optional[StrideType] = None,
  1281. ) -> Tensor:
  1282. _strides = strides if strides is not None else [1] * len(start_indices)
  1283. slices = []
  1284. for start, stop, step in zip(start_indices, limit_indices, _strides):
  1285. slices.append(pyslice(start, stop, step))
  1286. return operator.getitem(a, slices) # type: ignore[call-overload]
  1287. _slice_doc = """
  1288. Creates a view of a "bounding box" within the tensor.
  1289. The bounding box is specified independently in each of the tensor's dimensions.
  1290. start_indices and limit_indices describe the box's boundaries for their corresponding
  1291. dimensions. If strides is specified then they specify the step size between elements
  1292. in their corresponding dimension.
  1293. This operation is analogous to slicing in NumPy, but does not permit slices where
  1294. the stop indices are less than the start indices.
  1295. """
  1296. slice = _make_prim(
  1297. schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
  1298. meta=_slice_meta,
  1299. impl_aten=_slice_aten,
  1300. return_type=RETURN_TYPE.VIEW,
  1301. doc=_slice_doc,
  1302. )
  1303. def _slice_in_dim_meta(
  1304. a: TensorLikeType,
  1305. start_index: int,
  1306. limit_index: int,
  1307. stride: int = 1,
  1308. axis: int = 0,
  1309. ) -> TensorLikeType:
  1310. if axis < 0:
  1311. msg = "slice_in_dim: received a negative axis {0}".format(axis)
  1312. raise ValueError(msg)
  1313. if axis >= a.ndim:
  1314. msg = "slice_in_dim: axis {0} is greater or equal to the rank {1} of the tensor".format(
  1315. axis, a.ndim
  1316. )
  1317. raise ValueError(msg)
  1318. if start_index < 0:
  1319. msg = "slice_in_dim: received a negative start_index {0}".format(start_index)
  1320. raise ValueError(msg)
  1321. if start_index > a.shape[axis]:
  1322. msg = "slice_in_dim: start_index is greater than the length {0} of dimension {1}".format(
  1323. start_index, axis
  1324. )
  1325. raise ValueError(msg)
  1326. if limit_index > a.shape[axis]:
  1327. msg = "slice_in_dim: limit_index is greater than the length {0} of dimension {1}".format(
  1328. limit_index, axis
  1329. )
  1330. raise ValueError(msg)
  1331. if limit_index < start_index:
  1332. msg = "slice_in_dim: received a limit_index {0} less than the start_index {1}".format(
  1333. limit_index, start_index
  1334. )
  1335. raise ValueError(msg)
  1336. if stride < 0:
  1337. msg = "slice_in_dim: received a non-positive stride of {0}!".format(stride)
  1338. raise ValueError(msg)
  1339. start_indices = [0] * a.ndim
  1340. limit_indices = list(a.shape)
  1341. strides = [1] * a.ndim
  1342. start_indices[axis] = start_index
  1343. limit_indices[axis] = limit_index
  1344. strides[axis] = stride
  1345. return _slice_meta(a, start_indices, limit_indices, strides)
  1346. def _slice_in_dim_aten(
  1347. a: Tensor,
  1348. start_index: int,
  1349. limit_index: int,
  1350. stride: int = 1,
  1351. axis: int = 0,
  1352. ) -> Tensor:
  1353. start_indices = [0] * a.ndim
  1354. limit_indices = list(a.shape)
  1355. strides = [1] * a.ndim
  1356. start_indices[axis] = start_index
  1357. limit_indices[axis] = limit_index
  1358. strides[axis] = stride
  1359. return slice(a, start_indices, limit_indices, strides)
  1360. _slice_in_dim_doc = """
  1361. Convenience wrapper for slicing just one dimension using slice.
  1362. """
  1363. # TODO: make stride SymInt
  1364. slice_in_dim = _make_prim(
  1365. schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
  1366. meta=_slice_in_dim_meta,
  1367. impl_aten=_slice_in_dim_aten,
  1368. return_type=RETURN_TYPE.VIEW,
  1369. doc=_slice_in_dim_doc,
  1370. )
  1371. def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
  1372. assert isinstance(a, TensorLike)
  1373. utils.validate_idx(a.ndim, dim)
  1374. utils.validate_dim_length(outer_length)
  1375. # Verifies the dim can be split with the specified lhs_length
  1376. inner_length = a.shape[dim] // outer_length
  1377. if (a.shape[dim] % outer_length) != 0:
  1378. msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
  1379. a.shape[dim], outer_length
  1380. )
  1381. raise ValueError(msg)
  1382. new_shape: List[int] = []
  1383. new_strides: List[int] = []
  1384. for idx in range(a.ndim):
  1385. if idx == dim:
  1386. new_shape.extend((outer_length, inner_length))
  1387. new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
  1388. else:
  1389. new_shape.append(a.shape[idx])
  1390. new_strides.append(a.stride()[idx])
  1391. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1392. def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
  1393. inner_length = a.shape[dim] // outer_length
  1394. new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
  1395. return a.view(new_shape)
  1396. _split_dim_doc = """
  1397. Creates a view of a with the given dimension (of length l) split
  1398. into two dimensions, with the outer of the two having
  1399. length outer_length and the inner of the two having computed
  1400. length inner_length such outer_length * inner_length = l.
  1401. """
  1402. # TODO: consider renaming split_dim_view
  1403. split_dim = _make_prim(
  1404. schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
  1405. meta=_split_dim_meta,
  1406. impl_aten=_split_dim_aten,
  1407. return_type=RETURN_TYPE.VIEW,
  1408. doc=_split_dim_doc,
  1409. )
  1410. # Note: allows dimensions to be specified redundantly
  1411. def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
  1412. assert isinstance(a, TensorLike)
  1413. for idx in dimensions:
  1414. utils.validate_idx(a.ndim, idx)
  1415. assert a.shape[idx] == 1
  1416. new_shape = []
  1417. new_strides = []
  1418. for idx in range(len(a.shape)):
  1419. if idx in dimensions:
  1420. continue
  1421. new_shape.append(a.shape[idx])
  1422. new_strides.append(a.stride()[idx])
  1423. return a.as_strided(new_shape, new_strides, a.storage_offset())
  1424. _squeeze_doc = """
  1425. Creates a view of the tensor with the specified dimensions removed.
  1426. The removed dimensions must each have length one.
  1427. """
  1428. squeeze = _make_prim(
  1429. schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
  1430. meta=_squeeze_meta,
  1431. impl_aten=torch.squeeze,
  1432. return_type=RETURN_TYPE.VIEW,
  1433. doc=_squeeze_doc,
  1434. )
  1435. def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
  1436. if a.ndim != len(permutation):
  1437. msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
  1438. a.ndim, len(permutation)
  1439. )
  1440. raise ValueError(msg)
  1441. if not utils.is_valid_permutation(a.ndim, permutation):
  1442. msg = "Received an invalid permutation, {0}!".format(permutation)
  1443. raise ValueError(msg)
  1444. new_shape = [0] * a.ndim
  1445. new_strides = [0] * a.ndim
  1446. for idx, dim in enumerate(permutation):
  1447. new_shape[idx] = a.shape[dim]
  1448. new_strides[idx] = a.stride()[dim]
  1449. return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
  1450. def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
  1451. return torch.permute(a, permutation)
  1452. _transpose_doc = """
  1453. Creates a view of the tensor with its dimensions permuted.
  1454. The length of the permutation must be the rank of the tensor,
  1455. and each element of the permutation specifies the new order
  1456. for the corresponding dimension.
  1457. """
  1458. transpose = _make_prim(
  1459. schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
  1460. meta=_transpose_meta,
  1461. impl_aten=_transpose_aten,
  1462. return_type=RETURN_TYPE.VIEW,
  1463. doc=_transpose_doc,
  1464. )
  1465. def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
  1466. return a.as_strided(a.shape, a.stride(), a.storage_offset())
  1467. def _view_of_aten(a: Tensor) -> Tensor:
  1468. return a.view(a.shape)
  1469. _view_of_doc = """
  1470. Creates a view of the tensor.
  1471. """
  1472. view_of = _make_prim(
  1473. schema="view_of(Tensor(a) a) -> Tensor",
  1474. meta=_view_of_meta,
  1475. impl_aten=_view_of_aten,
  1476. return_type=RETURN_TYPE.VIEW,
  1477. doc=_view_of_doc,
  1478. )
  1479. #
  1480. # Functionalized view mutations
  1481. #
  1482. def _as_strided_scatter_meta(
  1483. input: TensorLikeType,
  1484. src: TensorLikeType,
  1485. size: ShapeType,
  1486. stride: StrideType,
  1487. storage_offset: int,
  1488. ) -> TensorLikeType:
  1489. utils.validate_shape(size)
  1490. utils.validate_strides(stride)
  1491. required_size = utils.compute_required_storage_length(size, stride, storage_offset)
  1492. utils.check(
  1493. input.numel() >= required_size,
  1494. lambda: (
  1495. f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
  1496. f" and itemsize {input.element_size()} requiring a storage size of "
  1497. f"{required_size * input.element_size()} are out of bounds "
  1498. f"for storage of size {input.numel() * input.element_size()}"
  1499. ),
  1500. )
  1501. utils.check(
  1502. utils.is_same_shape(src.shape, size),
  1503. lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
  1504. )
  1505. return utils.clone_preserve_strides(input)
  1506. _as_strided_scatter_doc = """
  1507. Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
  1508. ``out.as_strided(size, stride, storage_offset).copy_(src)``.
  1509. """
  1510. as_strided_scatter = _make_prim(
  1511. schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
  1512. meta=_as_strided_scatter_meta,
  1513. impl_aten=torch.as_strided_scatter,
  1514. return_type=RETURN_TYPE.NEW,
  1515. doc=_as_strided_scatter_doc,
  1516. )
  1517. #
  1518. # Shape operations
  1519. #
  1520. def collapse(a: Tensor, start: int, end: int) -> Tensor:
  1521. """
  1522. Wrapper around reshape that collapses a span of dimensions.
  1523. See collapse_view for the corresponding view operation.
  1524. """
  1525. dim_length = 1
  1526. for idx in range(start, end):
  1527. dim_length = dim_length * a.shape[idx]
  1528. new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
  1529. return reshape(a, new_shape)
  1530. # TODO: review stride logic
  1531. def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
  1532. # Verifies same shape (except in the concat dimension)
  1533. shape = tensors[0].shape
  1534. concat_length = 0
  1535. for tensor_idx, tensor in enumerate(tensors):
  1536. for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
  1537. if idx == dim:
  1538. concat_length = concat_length + length
  1539. elif length != common_length:
  1540. raise RuntimeError(
  1541. f"Sizes of tensors must match except in dimension {dim}. "
  1542. f"Expected {common_length} but got {length} for tensor number "
  1543. f"{tensor_idx} in the list"
  1544. )
  1545. new_shape = list(tensors[0].shape).copy()
  1546. new_shape[dim] = concat_length
  1547. return TensorMeta(
  1548. tensors[0],
  1549. shape=new_shape,
  1550. strides=utils.make_contiguous_strides_for(new_shape),
  1551. )
  1552. def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
  1553. return torch.cat(tensors, dim)
_cat_doc = """
Concatenates tensors along the specified dimension.
The tensors' shapes must have the same rank and same length for other dimensions.
"""

cat = _make_prim(
    schema="cat(Tensor[] tensors, int dim) -> Tensor",
    meta=_cat_meta,
    impl_aten=_cat_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_cat_doc,
)
  1565. def _reshape_meta(a: TensorLikeType, shape: ShapeType):
  1566. assert isinstance(a, TensorLike)
  1567. utils.validate_shape(shape)
  1568. # Validates the tensor and the requested shape have the
  1569. # same number of elements
  1570. numel = reduce(operator.mul, shape)
  1571. if numel != a.numel():
  1572. msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
  1573. a.numel(), numel
  1574. )
  1575. raise ValueError(msg)
  1576. return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
  1577. def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
  1578. return a.reshape(shape).contiguous().clone()
_reshape_doc = """
Creates a contiguous tensor with the specified shape
containing a copy of the data in a.
"""

# NEW return type: reshape always copies (unlike Tensor.reshape, which may view).
reshape = _make_prim(
    schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
    meta=_reshape_meta,
    impl_aten=_reshape_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_reshape_doc,
)
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Meta for rev: validates `dims`, output preserves a's memory format."""
    utils.validate_dimension_indices(a.ndim, dims)
    out = torch.empty_like(a, memory_format=torch.preserve_format)
    return TensorMeta(out)


_rev_doc = """
Reverses the order of elements along the given dimensions.
"""

rev = _make_prim(
    schema="rev(Tensor a, int[] dims) -> Tensor",
    meta=_rev_meta,
    impl_aten=torch.flip,
    return_type=RETURN_TYPE.NEW,
    doc=_rev_doc,
)
  1604. #
  1605. # Conditional prims
  1606. #
  1607. def _where_meta(
  1608. pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
  1609. ) -> TensorLikeType:
  1610. return _elementwise_meta(
  1611. a,
  1612. b,
  1613. type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
  1614. args_with_fixed_dtypes=(pred,),
  1615. )
  1616. _where_doc = """
  1617. Selects elements from a and b according to pred.
  1618. Where pred is true the result contains the element from a, and
  1619. where pred is false the result contains the element from b.
  1620. """
  1621. where = _make_prim(
  1622. schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
  1623. meta=_where_meta,
  1624. impl_aten=torch.where,
  1625. return_type=RETURN_TYPE.NEW,
  1626. doc=_where_doc,
  1627. )
  1628. #
  1629. # Type conversions
  1630. #
  1631. def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
  1632. # Type checks
  1633. assert isinstance(a, TensorLike)
  1634. assert isinstance(dtype, torch.dtype)
  1635. # dtype conversion preserves dense strides
  1636. if torch._prims_common.is_non_overlapping_and_dense(a):
  1637. strides = a.stride()
  1638. else:
  1639. strides = utils.compute_elementwise_output_strides(a)
  1640. return TensorMeta(a, strides=strides, dtype=dtype)
  1641. def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
  1642. # Propagates requires grad when possible
  1643. if not utils.is_grad_dtype(dtype):
  1644. requires_grad = False
  1645. else:
  1646. # TODO: update meta objects so this can be acquired directly
  1647. try:
  1648. requires_grad = a.requires_grad
  1649. except Exception as e:
  1650. requires_grad = False
  1651. result = torch.empty_like(
  1652. a, device=a.device, dtype=dtype, requires_grad=requires_grad
  1653. )
  1654. with torch.no_grad():
  1655. return copy_to(result, a)
_convert_element_type_doc = """
Creates a copy of a tensor with the given dtype.
"""

convert_element_type = _make_prim(
    schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
    meta=_convert_element_type_meta,
    impl_aten=_convert_element_type_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_convert_element_type_doc,
)
  1666. def _device_put_meta(
  1667. a: TensorLikeType, device: Union[str, torch.device]
  1668. ) -> TensorLikeType:
  1669. assert isinstance(a, TensorLike)
  1670. assert isinstance(device, (str, torch.device))
  1671. return TensorMeta(a, device=utils.canonicalize_device(device))
  1672. def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
  1673. return a.to(device)
_device_put_doc = """
Creates a copy of a tensor on the given device.
"""

device_put = _make_prim(
    schema="device_put(Tensor a, Device device) -> Tensor",
    meta=_device_put_meta,
    impl_aten=_device_put_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_device_put_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
    """Meta for item: a placeholder scalar of the Python type matching a.dtype."""
    number_type = utils.dtype_to_type(a.dtype)
    return TensorMeta(number_type(-1))


_item_doc = """
Converts a tensor with one element to a Python number.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
item = _make_prim(
    schema="item(Tensor a) -> Scalar",
    meta=_item_meta,
    impl_aten=torch.Tensor.item,
    return_type=RETURN_TYPE.NEW,
    doc=_item_doc,
)
  1702. # NOTE: need to model meta scalars
  1703. # See https://github.com/pytorch/pytorch/issues/78070
  1704. def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
  1705. number_type = utils.dtype_to_type(dtype)
  1706. return TensorMeta(number_type(-1))
  1707. def _maximum_value_aten(dtype: torch.dtype):
  1708. if dtype == torch.bool:
  1709. return True
  1710. elif dtype.is_complex or dtype.is_floating_point:
  1711. return torch.finfo(dtype).max
  1712. else:
  1713. return torch.iinfo(dtype).max
_maximum_value_doc = """
Return the maximum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
maximum_value = _make_prim(
    schema="maximum_value(ScalarType dtype) -> Scalar",
    meta=_maximum_value_meta,
    impl_aten=_maximum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_maximum_value_doc,
)
  1727. # NOTE: need to model meta scalars
  1728. # See https://github.com/pytorch/pytorch/issues/78070
  1729. def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
  1730. number_type = utils.dtype_to_type(dtype)
  1731. return TensorMeta(number_type(-1))
  1732. def _minimum_value_aten(dtype: torch.dtype):
  1733. if dtype == torch.bool:
  1734. return False
  1735. elif dtype.is_complex or dtype.is_floating_point:
  1736. return torch.finfo(dtype).min
  1737. else:
  1738. return torch.iinfo(dtype).min
  1739. _minimum_value_doc = """
  1740. Return the mimimum finite value for a dtype.
  1741. """
  1742. # TODO: create a new return type for scalars?
  1743. # FIXME: currently returns integers for boolean tensors
  1744. # https://github.com/pytorch/pytorch/issues/78071
  1745. minimum_value = _make_prim(
  1746. schema="minium_value(ScalarType dtype) -> Scalar",
  1747. meta=_minimum_value_meta,
  1748. impl_aten=_minimum_value_aten,
  1749. return_type=RETURN_TYPE.NEW,
  1750. doc=_minimum_value_doc,
  1751. )
  1752. #
  1753. # Inplace operators
  1754. #
  1755. def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
  1756. assert isinstance(a, TensorLike)
  1757. assert isinstance(b, TensorLike)
  1758. # Validates the cast is safe
  1759. # TODO: move this as an option on the reference
  1760. # a_typ = utils.dtype_to_type(a.dtype)
  1761. # b_typ = utils.dtype_to_type(b.dtype)
  1762. # if a_typ is not utils.get_higher_type(a_typ, b_typ):
  1763. # raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")
  1764. # Validates the tensors have the same number of elements
  1765. if a.numel() != b.numel():
  1766. msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
  1767. b.numel(), a.numel()
  1768. )
  1769. raise RuntimeError(msg)
  1770. return a
  1771. def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
  1772. return a.copy_(b)
  1773. _copy_to_doc = """
  1774. Copies the data in b to a and returns the modified a.
  1775. """
  1776. # TODO: Remove safe casting and implement on reference instead
  1777. copy_to = _make_prim(
  1778. schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
  1779. meta=_copy_to_meta,
  1780. impl_aten=_copy_to_aten,
  1781. return_type=RETURN_TYPE.INPLACE,
  1782. doc=_copy_to_doc,
  1783. )
  1784. def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
  1785. assert isinstance(a, TensorLike)
  1786. return torch.empty_strided(
  1787. a.shape,
  1788. stride,
  1789. dtype=a.dtype,
  1790. layout=a.layout,
  1791. device=a.device,
  1792. requires_grad=a.requires_grad,
  1793. )
  1794. def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
  1795. out = torch.empty_strided(
  1796. a.size(),
  1797. stride=stride,
  1798. dtype=a.dtype,
  1799. layout=a.layout,
  1800. device=a.device,
  1801. requires_grad=a.requires_grad,
  1802. )
  1803. out.copy_(a)
  1804. return out
_copy_strided_doc = """
Copies the data in a to a new tensor, the new tensor has same shape with a size, but has different stride.
"""

copy_strided = _make_prim(
    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
    meta=_copy_strided_meta,
    impl_aten=_copy_strided_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_copy_strided_doc,
)
  1815. def _resize_meta(a: TensorLikeType, shape: ShapeType):
  1816. return a.resize_(shape)
  1817. def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
  1818. return a.resize_(shape)
  1819. _resize_doc = """
  1820. Gives a tensor with no elements a new shape, returning the modified tensor.
  1821. The tensor's strides are contiguous and its values are unitialized.
  1822. """
  1823. # TODO: review support arbitrary resizes
  1824. resize = _make_prim(
  1825. schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
  1826. meta=_resize_meta,
  1827. impl_aten=_resize_aten,
  1828. return_type=RETURN_TYPE.INPLACE,
  1829. doc=_resize_doc,
  1830. )
  1831. def _reduction_meta(inp, dims, *, output_dtype=None):
  1832. """
  1833. Meta function for single output reduction operations
  1834. Stride logic is incorrect
  1835. """
  1836. assert isinstance(inp, TensorLike)
  1837. if output_dtype is None:
  1838. output_dtype = inp.dtype
  1839. output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
  1840. return TensorMeta(
  1841. shape=output_shape,
  1842. strides=utils.make_contiguous_strides_for(output_shape),
  1843. dtype=output_dtype,
  1844. device=inp.device,
  1845. )
  1846. def _var_reduction_meta(inp, dims, *, correction):
  1847. if utils.is_complex_dtype(inp.dtype):
  1848. output_dtype = utils.corresponding_real_dtype(inp.dtype)
  1849. else:
  1850. output_dtype = inp.dtype
  1851. return _reduction_meta(inp, dims, output_dtype=output_dtype)
_sum_doc = """
Computes the sum of elements in the input tensor over the list of dimensions
specified in the dim argument
"""
_prod_doc = """
Computes the product of elements in the input tensor over the list of dimensions
specified in the dim argument
"""
_amax_doc = """
Computes the maximum value of elements in the input tensor over the list of dimensions
specified in the dim argument
"""
_amin_doc = """
Computes the minimum value of elements in the input tensor over the list of dimensions
specified in the dim argument
"""
_var_doc = """
Computes the biased variance of x over the list of dimensions specified in the dim argument
"""


def _make_reduction_prim(name: str, impl_aten, doc):
    """Creates a reduction prim."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
        meta=_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


def _make_var_reduction_prim(name: str, impl_aten, doc):
    """Creates a variance-style reduction prim (adds a `correction` argument)."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, int correction, ScalarType? output_dtype=None) -> Tensor",
        meta=_var_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


# NOTE: intentionally shadows the builtin `sum` within this module's namespace.
sum = _make_reduction_prim(
    name="sum",
    impl_aten=torch.sum,
    doc=_sum_doc,
)
  1894. def _prod_aten(
  1895. inp: TensorLikeType,
  1896. dims: Optional[DimsSequenceType],
  1897. *,
  1898. dtype: Optional[torch.dtype] = None,
  1899. ) -> Tensor:
  1900. if dims is not None:
  1901. for d in sorted(dims, reverse=True):
  1902. assert d >= 0
  1903. inp = torch.prod(inp, d, dtype=dtype)
  1904. return inp
  1905. else:
  1906. return torch.prod(inp, dims, dtype=dtype)
# Reduction prim registrations (see _make_reduction_prim / _make_var_reduction_prim).
prod = _make_reduction_prim(
    name="prod",
    impl_aten=_prod_aten,
    doc=_prod_doc,
)

var = _make_var_reduction_prim(
    name="var",
    impl_aten=torch.var,
    doc=_var_doc,
)

amax = _make_reduction_prim(
    name="amax",
    impl_aten=torch.amax,
    doc=_amax_doc,
)

amin = _make_reduction_prim(
    name="amin",
    impl_aten=torch.amin,
    doc=_amin_doc,
)
_iota_doc = """
Constructs a 1-D tensor t where ``t[i] == start + i * step``.
"""

# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
  1932. def _iota_meta(
  1933. length: int,
  1934. *,
  1935. start: int,
  1936. step: int,
  1937. dtype: torch.dtype,
  1938. device: torch.device,
  1939. requires_grad: bool,
  1940. ) -> TensorLikeType:
  1941. utils.check(
  1942. utils.is_integer_dtype(dtype),
  1943. lambda: "prims.iota only supports integer dtypes",
  1944. )
  1945. utils.check(step != 0, lambda: "step must be nonzero")
  1946. return torch.empty(
  1947. length,
  1948. dtype=dtype,
  1949. device=device,
  1950. requires_grad=requires_grad,
  1951. )
  1952. def _iota_aten(
  1953. length: int,
  1954. *,
  1955. start: int,
  1956. step: int,
  1957. dtype: torch.dtype,
  1958. device: torch.device,
  1959. requires_grad: bool,
  1960. ) -> TensorLikeType:
  1961. end = start + length * step
  1962. return torch.arange(
  1963. start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
  1964. )
iota = _make_prim(
    schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_iota_meta,
    impl_aten=_iota_aten,
    doc=_iota_doc,
)
  1972. # TODO: layout, pin_memory, memory_format
  1973. # TODO: model requires_grad on TensorMeta
  1974. def _empty_meta(
  1975. shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
  1976. ) -> TensorLikeType:
  1977. strides = utils.make_contiguous_strides_for(shape)
  1978. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  1979. def _empty_aten(
  1980. shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
  1981. ) -> Tensor:
  1982. return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
_empty_doc = """
Creates a tensor with uninitialized values and the specified shape, dtype, and device.
"""

empty = _make_prim(
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_empty_meta,
    impl_aten=_empty_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_doc,
)


def _empty_strided_meta(
    shape: ShapeType,
    strides: StrideType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """Meta for empty_strided: uses the caller-provided strides verbatim."""
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


_empty_strided_doc = """
Creates a tensor with uninitialized values.
"""

# TODO: add layout, pin_memory
empty_strided = _make_prim(
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_empty_strided_meta,
    impl_aten=torch.empty_strided,
    doc=_empty_strided_doc,
)
  2013. def _full_meta(
  2014. shape: ShapeType,
  2015. fill_value: NumberType,
  2016. *,
  2017. dtype: torch.dtype,
  2018. device: torch.device,
  2019. requires_grad: bool,
  2020. ) -> TensorLikeType:
  2021. strides = utils.make_contiguous_strides_for(shape)
  2022. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2023. def _full_aten(
  2024. shape: ShapeType,
  2025. fill_value: NumberType,
  2026. *,
  2027. dtype: torch.dtype,
  2028. device: torch.device,
  2029. requires_grad: bool,
  2030. ) -> Tensor:
  2031. # Note that Mypy thinks torch.full can't accept a complex fill_value
  2032. return torch.full(
  2033. shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
  2034. )
_full_doc = """
Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
"""

# TODO: add layout
full = _make_prim(
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_meta,
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)
  2046. def _full_like_meta(
  2047. a: TensorLikeType,
  2048. fill_value: NumberType,
  2049. *,
  2050. dtype: torch.dtype,
  2051. device: torch.device,
  2052. requires_grad: bool,
  2053. ) -> TensorLikeType:
  2054. strides = utils.compute_elementwise_output_strides(a)
  2055. if a.numel() == 0:
  2056. strides = a.stride()
  2057. return TensorMeta(a, strides=strides, dtype=dtype, device=device)
  2058. def _full_like_aten(
  2059. a: Tensor,
  2060. fill_value: NumberType,
  2061. *,
  2062. dtype: torch.dtype,
  2063. device: torch.device,
  2064. requires_grad: bool,
  2065. ) -> Tensor:
  2066. # Note that Mypy thinks torch.full can't accept a complex fill_value
  2067. return torch.full_like(
  2068. a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type]
  2069. )
_full_like_doc = """
Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
given tensor by default. The dtype and device settings can be overridden
by specifying them explicitly.
"""

full_like = _make_prim(
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_like_meta,
    impl_aten=_full_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_like_doc,
)
  2082. def _scalar_tensor_meta(
  2083. scalar: NumberType,
  2084. *,
  2085. dtype: torch.dtype,
  2086. device: torch.device,
  2087. ) -> TensorLikeType:
  2088. shape: ShapeType = []
  2089. strides = utils.make_contiguous_strides_for(shape)
  2090. return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)
  2091. def _scalar_tensor_aten(
  2092. scalar: NumberType,
  2093. *,
  2094. dtype: torch.dtype,
  2095. device: torch.device,
  2096. ) -> Tensor:
  2097. if isinstance(scalar, complex) and (
  2098. dtype is None or not utils.is_complex_dtype(dtype)
  2099. ):
  2100. raise TypeError("Complex scalar requires complex tensor dtype.")
  2101. # Note that Mypy thinks torch.scalar can't accept a complex scalar
  2102. return torch.scalar_tensor(scalar, dtype=dtype, device=device) # type: ignore[arg-type]
_scalar_tensor_doc = """
Wraps a Number into a Tensor with the specified dtype and device.
"""

# TODO: add layout and pin_memory support
scalar_tensor = _make_prim(
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
    meta=_scalar_tensor_meta,
    impl_aten=_scalar_tensor_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_scalar_tensor_doc,
)
  2114. #
  2115. # Linear algebra (linalg) prims
  2116. #
  2117. def _svd_meta(
  2118. A: TensorLikeType, *, full_matrices: bool
  2119. ) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
  2120. utils.check_is_matrix(A, "linalg.svd")
  2121. utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)
  2122. A_shape = A.shape
  2123. batch = A_shape[:-2]
  2124. m, n = A_shape[-2:]
  2125. k = min(m, n)
  2126. shape_U = batch + (m, m if full_matrices else k)
  2127. strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
  2128. U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)
  2129. shape_S = batch + (k,)
  2130. strides_S = utils.make_contiguous_strides_for(shape_S)
  2131. S = TensorMeta(
  2132. shape=shape_S,
  2133. strides=strides_S,
  2134. dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
  2135. device=A.device,
  2136. )
  2137. shape_Vh = batch + (n if full_matrices else k, n)
  2138. # The CPU backend returns V, but the cuSolver backend returns V^H
  2139. # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
  2140. is_cuda = A.device.type == "cuda"
  2141. strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
  2142. Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
  2143. return U, S, Vh
  2144. def _svd_aten(
  2145. A: TensorLikeType, *, full_matrices: bool
  2146. ) -> Tuple[Tensor, Tensor, Tensor]:
  2147. return torch.linalg.svd(A, full_matrices=full_matrices)
_svd_doc = """
Returns the SVD of a matrix or batch of matrices.
The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
"""

# Multi-output prim: one RETURN_TYPE entry per output (U, S, Vh).
svd = _make_prim(
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
    meta=_svd_meta,
    impl_aten=_svd_aten,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    doc=_svd_doc,
)
  2159. #
  2160. # Randomness Prims
  2161. #
  2162. # TODO: add generator support
  2163. # NOTE: there is currently no way of acquiring the "default" torch generator
  2164. def _normal_meta(
  2165. shape: ShapeType,
  2166. *,
  2167. mean: Union[float, complex],
  2168. std: float,
  2169. dtype: torch.dtype,
  2170. device: torch.device,
  2171. requires_grad: bool,
  2172. ) -> TensorLikeType:
  2173. utils.check(
  2174. std >= 0.0,
  2175. lambda: f"expected non-negative standard deviation, but got std={std}",
  2176. )
  2177. utils.check(
  2178. utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
  2179. lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
  2180. )
  2181. strides = utils.make_contiguous_strides_for(shape)
  2182. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2183. def _normal_aten(
  2184. shape: ShapeType,
  2185. *,
  2186. mean: Union[float, complex],
  2187. std: float,
  2188. dtype: torch.dtype,
  2189. device: torch.device,
  2190. requires_grad: bool,
  2191. ) -> Tensor:
  2192. a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
  2193. with torch.no_grad():
  2194. # NOTE: normal_ is incorrectly annotated to expect mean to be a float
  2195. a.normal_(mean, std) # type: ignore[arg-type]
  2196. return a
  2197. _normal_doc = """
  2198. Constructs a tensor filled with values drawn from a normal distribution with the specified mean
  2199. and standard deviation.
  2200. Only supports floating-point types.
  2201. """
  2202. normal = _make_prim(
  2203. schema=(
  2204. "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
  2205. ),
  2206. return_type=RETURN_TYPE.NEW,
  2207. meta=_normal_meta,
  2208. impl_aten=_normal_aten,
  2209. doc=_normal_doc,
  2210. )
  2211. def _uniform_meta(
  2212. shape: ShapeType,
  2213. *,
  2214. low: float,
  2215. high: float,
  2216. dtype: torch.dtype,
  2217. device: torch.device,
  2218. ) -> TensorLikeType:
  2219. strides = utils.make_contiguous_strides_for(shape)
  2220. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
  2221. def _uniform_aten(
  2222. shape: ShapeType,
  2223. *,
  2224. low: float,
  2225. high: float,
  2226. dtype: torch.dtype,
  2227. device: torch.device,
  2228. ) -> Tensor:
  2229. a = torch.empty(shape, dtype=dtype, device=device)
  2230. a.uniform_(low, high)
  2231. return a
_uniform_doc = """
Constructs a tensor filled with values drawn uniformly from low to high.
"""

# TODO: we should more seriously review randomness modeling and prims
_uniform_helper = _make_prim(
    schema=(
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_uniform_meta,
    impl_aten=_uniform_aten,
    doc=_uniform_doc,
)
  2245. #
  2246. # FFT prims
  2247. #
  2248. def _fft_r2c_meta(
  2249. input: TensorLike,
  2250. *,
  2251. dim: DimsSequenceType,
  2252. onesided: bool,
  2253. ) -> TensorLikeType:
  2254. dim = utils.canonicalize_dims(input.ndim, dim)
  2255. utils.validate_no_repeating_dims(dim)
  2256. shape = list(input.shape)
  2257. if onesided:
  2258. last_dim = dim[-1]
  2259. shape[last_dim] = shape[last_dim] // 2 + 1
  2260. dtype = utils.corresponding_complex_dtype(input.dtype)
  2261. strides = utils.make_contiguous_strides_for(shape)
  2262. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
  2263. def _fft_r2c_aten(
  2264. input: TensorLike,
  2265. *,
  2266. dim: DimsSequenceType,
  2267. onesided: bool,
  2268. ) -> TensorLikeType:
  2269. normalization = 0 # No normalization
  2270. return torch._fft_r2c(input, dim, normalization, onesided)
_fft_r2c_doc = """
Performs a real to complex Fast Fourier Transform
"""

fft_r2c = _make_prim(
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
    meta=_fft_r2c_meta,
    impl_aten=_fft_r2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_r2c_doc,
)
  2281. def _fft_c2c_meta(
  2282. input: TensorLike,
  2283. *,
  2284. dim: DimsSequenceType,
  2285. forward: bool,
  2286. ) -> TensorLikeType:
  2287. dim = utils.canonicalize_dims(input.ndim, dim)
  2288. utils.validate_no_repeating_dims(dim)
  2289. shape = input.shape
  2290. strides = utils.make_contiguous_strides_for(shape)
  2291. return TensorMeta(
  2292. shape=shape, strides=strides, dtype=input.dtype, device=input.device
  2293. )
  2294. def _fft_c2c_aten(
  2295. input: TensorLike,
  2296. *,
  2297. dim: DimsSequenceType,
  2298. forward: bool,
  2299. ) -> TensorLikeType:
  2300. normalization = 0 # No normalization
  2301. return torch._fft_c2c(input, dim, normalization, forward)
_fft_c2c_doc = """
Performs either a Fast Fourier Transform, or its inverse
"""

fft_c2c = _make_prim(
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
    meta=_fft_c2c_meta,
    impl_aten=_fft_c2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2c_doc,
)
  2312. def _fft_c2r_meta(
  2313. input: TensorLike,
  2314. *,
  2315. dim: DimsSequenceType,
  2316. last_dim_size: int,
  2317. ) -> TensorLikeType:
  2318. dim = utils.canonicalize_dims(input.ndim, dim)
  2319. utils.validate_no_repeating_dims(dim)
  2320. shape = list(input.shape)
  2321. shape[dim[-1]] = last_dim_size
  2322. dtype = utils.corresponding_real_dtype(input.dtype)
  2323. strides = utils.make_contiguous_strides_for(shape)
  2324. return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
  2325. def _fft_c2r_aten(
  2326. input: TensorLike,
  2327. *,
  2328. dim: DimsSequenceType,
  2329. last_dim_size: int,
  2330. ) -> TensorLikeType:
  2331. normalization = 0 # No normalization
  2332. return torch._fft_c2r(input, dim, normalization, last_dim_size)
_fft_c2r_doc = """
Performs a complex to real Inverse Fast Fourier Transform
"""

fft_c2r = _make_prim(
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
    meta=_fft_c2r_meta,
    impl_aten=_fft_c2r_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2r_doc,
)

# Registers the nvFuser prim implementations (defined elsewhere in this module).
register_nvprims()